"""ChatDOC: a Gradio chat UI around a local vision-to-sequence model.

The assistant interviews the patient one question at a time and, after enough
turns (or when the patient asks for it), produces a list of potential causes.
Speech input and output are delegated to the project's STT/TTS helpers.
"""

import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq
import torch

from PaitentVoiceToText import record_and_transcribe  # Your STT function
from DocVoice import text_to_speech  # Your TTS function


# -------------------
# 1️⃣ Load Model & Processor
# -------------------
def load_model():
    """Load the processor and model from the local checkpoint directory.

    Returns:
        A ``(processor, model, device)`` tuple, where ``device`` is ``"cuda"``
        when ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.
    """
    local_dir = r"C:\Users\JAY\Downloads\model\CHATDOCMODEL"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision only on GPU; the CPU path stays in float32.
    dtype = torch.float16 if device == "cuda" else torch.float32

    processor = AutoProcessor.from_pretrained(local_dir, trust_remote_code=True)
    model = AutoModelForVision2Seq.from_pretrained(
        local_dir,
        dtype=dtype,
        device_map=None,
    )
    model.to(device)
    return processor, model, device


processor, model, device = load_model()


# -------------------
# 2️⃣ Chat Logic Functions
# -------------------

# Number of patient messages after which an analysis is produced automatically.
ANALYSIS_THRESHOLD = 6

# Substrings (matched case-insensitively) that request an analysis right away.
ANALYSIS_KEYWORDS = ("analysis", "diagnose", "what do you think", "causes")

ANALYSIS_PROMPT = (
    "You are a medical doctor. Based on the patient's responses, provide a comprehensive analysis "
    "of potential causes for their symptoms. Start with 'Based on the information provided by the patient, "
    "potential causes of [symptoms] could include:' and list 3-4 possible diagnoses with brief explanations. "
    "Format as numbered list with diagnosis name and short explanation."
)

INTERVIEW_PROMPT = (
    "You are a medical doctor conducting a patient interview. Ask ONE specific, direct medical question "
    "to gather important diagnostic information. Keep it brief - just ask the question without explanations. "
    "Focus on key areas like: age, medical history, medications, lifestyle, family history, or symptom details."
)

# Lead-in phrases stripped from the front of a generated question.
CLEANUP_STARTS = (
    "I need to ask",
    "Let me ask",
    "I would like to know",
    "Can you tell me",
    "It would help if",
)


def _should_analyze(message, question_count):
    """Return True when this turn should produce an analysis, not a question."""
    if question_count >= ANALYSIS_THRESHOLD:
        return True
    lowered = message.lower()
    return any(keyword in lowered for keyword in ANALYSIS_KEYWORDS)


def _build_prompt(system_prompt, previous_turns, message):
    """Render the system prompt, prior turns and the new message as one prompt."""
    dialogue = []
    for user_msg, bot_msg in previous_turns:
        if user_msg:
            dialogue.append(f"Patient: {user_msg}")
        if bot_msg:
            dialogue.append(f"Doctor: {bot_msg}")
    dialogue.append(f"Patient: {message}")
    conversation = "\n".join(dialogue)
    return f"{system_prompt}\n\nConversation:\n{conversation}\nDoctor:"


def _tidy_question(response):
    """Reduce a generated interview reply to a single question ending in '?'."""
    # Keep only the text up to (and including) the first question mark.
    sentences = response.split("?")
    if len(sentences) > 1:
        response = sentences[0].strip() + "?"

    # Drop a polite lead-in such as "Let me ask, ..." when a comma follows it.
    for phrase in CLEANUP_STARTS:
        if response.startswith(phrase):
            parts = response.split(",", 1)
            if len(parts) > 1:
                response = parts[1].strip()

    if not response.endswith("?"):
        response += "?"
    return response


def process_message(message, history, question_count):
    """Append the patient's message to the chat and generate the doctor's reply.

    Args:
        message: The patient's text. ``None``, empty or whitespace-only input
            is ignored.
        history: Chat history as a list of ``[patient, doctor]`` pairs; it is
            extended in place. ``None`` is treated as an empty history.
        question_count: Number of patient messages since the last analysis.

    Returns:
        ``(history, history, question_count)``. The history is returned twice
        because the callers feed two output slots from it. The counter is
        reset to 0 after an analysis turn.
    """
    if history is None:
        history = []
    # Speech transcription may yield nothing at all, so guard None as well.
    if not message or not message.strip():
        return history, history, question_count

    history.append([message, None])
    question_count += 1

    should_analyze = _should_analyze(message, question_count)
    system_prompt = ANALYSIS_PROMPT if should_analyze else INTERVIEW_PROMPT
    prompt = _build_prompt(system_prompt, history[:-1], message)

    inputs = processor(text=prompt, images=None, return_tensors="pt").to(device)
    # An analysis needs room for a numbered list; a question is kept short.
    max_tokens = 1000 if should_analyze else 25

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=processor.tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    input_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[:, input_length:]
    response = processor.batch_decode(generated_tokens, skip_special_tokens=True)[0].strip()

    if response.lower().startswith("doctor:"):
        response = response[len("doctor:"):].strip()

    if not should_analyze:
        response = _tidy_question(response)

    history[-1][1] = response

    if should_analyze:
        question_count = 0

    return history, history, question_count


def force_analysis(history, question_count):
    """Arm the analysis so that the patient's next message triggers it.

    The history is returned unchanged; the counter is set to 10, which is
    above the automatic analysis threshold. No reply is generated here.
    """
    return history, 10


def clear_chat():
    """Return empty values for the two chat outputs and a zeroed counter."""
    return [], [], 0


# -------------------
# 3️⃣ TTS Helper
# -------------------
def play_assistant_audio(response_text):
    """Pass ``response_text`` to ``text_to_speech`` when it is non-empty.

    Always returns ``None``.
    """
    if response_text:
        text_to_speech(response_text)
    return None


# -------------------
# 4️⃣ Gradio Interface
# -------------------
with gr.Blocks(title="ChatDOC", theme=gr.themes.Soft()) as demo:
    question_count_state = gr.State(0)
    assistant_responses_state = gr.State([])

    gr.Markdown(
        """
        # 🩺 Chat with ChatDOC
        Welcome! I'm your AI medical assistant.
        Please describe your symptoms and I'll ask relevant questions to help understand your condition better.
        """
    )

    chatbot = gr.Chatbot(
        value=[],
        height=400,
        show_label=False,
        avatar_images=(
            r"C:\Users\JAY\Downloads\model\user_msg.png",
            r"C:\Users\JAY\Downloads\model\bot_msg.jpg"
        ),
        bubble_full_width=False
    )

    with gr.Row():
        msg = gr.Textbox(
            placeholder="Describe your symptoms...",
            scale=4,
            container=False,
            show_label=False
        )
        send_btn = gr.Button("Send", variant="primary", scale=1)
        mic_btn = gr.Button("🎤 Speak", variant="secondary", scale=1)

    with gr.Row():
        analysis_btn = gr.Button("Request Analysis", variant="secondary")
        clear_btn = gr.Button("Clear Chat", variant="stop")
        play_audio_btn = gr.Button("🔊 Play Assistant Response", variant="secondary")

    # -------------------
    # Update assistant responses
    # -------------------
    def update_assistant_responses(history, assistant_responses):
        """Append the latest doctor reply in ``history`` to the replay list."""
        if history and history[-1][1]:
            assistant_responses.append(history[-1][1])
        return assistant_responses

    # -------------------
    # Submit handlers
    # -------------------
    def user_submit(message, history, question_count, assistant_responses):
        """Handle a typed message and record the new reply for audio playback."""
        turns_before = len(history) if history else 0
        history, updated_history, question_count = process_message(
            message, history, question_count
        )
        # A blank submission adds no turn; do not re-record the previous reply.
        if len(history) > turns_before:
            assistant_responses = update_assistant_responses(history, assistant_responses)
        return updated_history, updated_history, question_count, assistant_responses

    def mic_submit(history, question_count, assistant_responses):
        """Transcribe speech and route the text through ``user_submit``.

        ``process_message`` appends the patient turn itself, so the text must
        not be appended here as well (that duplicated the message).
        """
        # presumably records for 5 seconds and returns the transcript — confirm
        user_text = record_and_transcribe(duration=5)
        return user_submit(user_text, history, question_count, assistant_responses)

    def clear_input():
        """Return an empty string to blank the message textbox."""
        return ""

    def reset_assistant_responses():
        """Return an empty replay list so cleared replies cannot be played."""
        return []

    # -------------------
    # Connect buttons
    # -------------------
    send_btn.click(
        user_submit,
        inputs=[msg, chatbot, question_count_state, assistant_responses_state],
        outputs=[chatbot, chatbot, question_count_state, assistant_responses_state]
    ).then(clear_input, outputs=[msg])

    msg.submit(
        user_submit,
        inputs=[msg, chatbot, question_count_state, assistant_responses_state],
        outputs=[chatbot, chatbot, question_count_state, assistant_responses_state]
    ).then(clear_input, outputs=[msg])

    mic_btn.click(
        mic_submit,
        inputs=[chatbot, question_count_state, assistant_responses_state],
        outputs=[chatbot, chatbot, question_count_state, assistant_responses_state]
    )

    analysis_btn.click(
        force_analysis,
        inputs=[chatbot, question_count_state],
        outputs=[chatbot, question_count_state]
    )

    # clear_chat keeps its three-value contract; the replay list is reset in a
    # chained step so a reply from the cleared chat cannot be played back.
    clear_btn.click(
        clear_chat,
        outputs=[chatbot, chatbot, question_count_state]
    ).then(reset_assistant_responses, outputs=[assistant_responses_state])

    play_audio_btn.click(
        lambda assistant_responses: play_assistant_audio(assistant_responses[-1]) if assistant_responses else None,
        inputs=[assistant_responses_state],
        outputs=[]
    )


# -------------------
# 5️⃣ Launch
# -------------------
if __name__ == "__main__":
    demo.launch(
        server_name="127.0.0.1",
        server_port=7860,
        share=False,
        debug=True
    )