Spaces:
Sleeping
Sleeping
from pathlib import Path

import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq

from PaitentVoiceToText import record_and_transcribe
from DocVoice import text_to_speech  # Your TTS function
| # ------------------- | |
| # 1️⃣ Load Model & Processor | |
| # ------------------- | |
def load_model():
    """Load the RaiyaChatDoc processor and model from the Hugging Face Hub.

    The model is placed on a single device: ``"cuda"`` when
    ``torch.cuda.is_available()`` is true, otherwise ``"cpu"``. Half precision
    is used on GPU and float32 on CPU.

    Returns:
        tuple: ``(processor, model, device)`` where ``device`` is the string
        name of the device the model was moved to.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    # Load directly from Hugging Face
    processor = AutoProcessor.from_pretrained("Muhammadidrees/RaiyaChatDoc", trust_remote_code=True)

    # Do not combine device_map="auto" with an explicit .to(device): a model
    # dispatched by accelerate must not be moved afterwards (it can raise when
    # modules are offloaded). The rest of this file sends inputs to the single
    # `device` returned here, so the model is placed there explicitly.
    model = AutoModelForVision2Seq.from_pretrained(
        "Muhammadidrees/RaiyaChatDoc",
        torch_dtype=dtype,
    )
    model.to(device)
    return processor, model, device


processor, model, device = load_model()
| # ------------------- | |
| # 2️⃣ Chat Logic Functions | |
| # ------------------- | |
# Number of patient messages after which an analysis is produced automatically.
_QUESTIONS_BEFORE_ANALYSIS = 6

# Substrings in a patient message that request an analysis immediately.
_ANALYSIS_KEYWORDS = ("analysis", "diagnose", "what do you think", "causes")

_ANALYSIS_PROMPT = (
    "You are a medical doctor. Based on the patient's responses, provide a comprehensive analysis "
    "of potential causes for their symptoms. Start with 'Based on the information provided by the patient, "
    "potential causes of [symptoms] could include:' and list 3-4 possible diagnoses with brief explanations. "
    "Format as numbered list with diagnosis name and short explanation."
)

_INTERVIEW_PROMPT = (
    "You are a medical doctor conducting a patient interview. Ask ONE specific, direct medical question "
    "to gather important diagnostic information. Keep it brief - just ask the question without explanations. "
    "Focus on key areas like: age, medical history, medications, lifestyle, family history, or symptom details."
)

# Lead-in phrases stripped (up to the first comma) from generated questions.
_QUESTION_LEAD_INS = (
    "I need to ask",
    "Let me ask",
    "I would like to know",
    "Can you tell me",
    "It would help if",
)


def _wants_analysis(message, question_count):
    """Return True when the reply should be an analysis instead of a question."""
    lowered = message.lower()
    return (
        question_count >= _QUESTIONS_BEFORE_ANALYSIS
        or any(word in lowered for word in _ANALYSIS_KEYWORDS)
    )


def _build_prompt(system_prompt, previous_turns, message):
    """Render the system prompt, prior turns and the new message as one prompt."""
    dialogue = []
    for user_msg, bot_msg in previous_turns:
        if user_msg:
            dialogue.append(f"Patient: {user_msg}")
        if bot_msg:
            dialogue.append(f"Doctor: {bot_msg}")
    dialogue.append(f"Patient: {message}")
    conversation = "\n".join(dialogue)
    return f"{system_prompt}\n\nConversation:\n{conversation}\nDoctor:"


def _tidy_question(response):
    """Trim a generated reply down to a single question ending in '?'."""
    # Keep only the text up to the first question mark.
    sentences = response.split('?')
    if len(sentences) > 1:
        response = sentences[0].strip() + '?'
    # Drop a conversational lead-in when it is followed by a comma.
    for phrase in _QUESTION_LEAD_INS:
        if response.startswith(phrase):
            parts = response.split(',', 1)
            if len(parts) > 1:
                response = parts[1].strip()
    if not response.endswith('?'):
        response += '?'
    return response


def process_message(message, history, question_count):
    """Generate the doctor's reply to one patient message.

    The message is appended to ``history`` (in place) as ``[message, reply]``.
    The reply is a single follow-up question, or an analysis once
    ``question_count`` reaches 6 or the message contains an analysis keyword.

    Args:
        message: Patient text. ``None``, non-string or blank input is ignored.
        history: List of ``[patient_text, doctor_text]`` pairs; ``None`` is
            treated as an empty list.
        question_count: Number of patient messages since the last analysis.

    Returns:
        tuple: ``(history, history, question_count)``. The counter is
        incremented for a question turn and reset to 0 after an analysis.
        For ignored input the arguments are returned unchanged.
    """
    if history is None:
        history = []
    # A failed transcription may yield None; treat it like an empty message
    # instead of crashing on .strip().
    if not isinstance(message, str) or not message.strip():
        return history, history, question_count

    history.append([message, None])
    question_count += 1

    should_analyze = _wants_analysis(message, question_count)
    system_prompt = _ANALYSIS_PROMPT if should_analyze else _INTERVIEW_PROMPT
    prompt = _build_prompt(system_prompt, history[:-1], message)

    inputs = processor(text=prompt, images=None, return_tensors="pt").to(device)
    max_tokens = 1000 if should_analyze else 25
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=processor.tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    input_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[:, input_length:]
    response = processor.batch_decode(generated_tokens, skip_special_tokens=True)[0].strip()

    if response.lower().startswith("doctor:"):
        response = response[len("doctor:"):].strip()
    if not should_analyze:
        response = _tidy_question(response)

    history[-1][1] = response
    if should_analyze:
        question_count = 0
    return history, history, question_count
def force_analysis(history, question_count):
    """Return the history unchanged together with a question count of 10.

    The incoming ``question_count`` is ignored. Because ``process_message``
    analyzes once the count reaches 6, the next patient message will produce
    an analysis.
    """
    forced_count = 10
    return history, forced_count
def clear_chat():
    """Return reset values: two separate empty histories and a zero question count."""
    fresh_history, fresh_display = [], []
    return fresh_history, fresh_display, 0
| # ------------------- | |
| # 3️⃣ TTS Helper | |
| # ------------------- | |
def play_assistant_audio(response_text):
    """Pass ``response_text`` to ``text_to_speech`` when it is non-empty.

    Always returns ``None``; empty or ``None`` text is skipped.
    """
    if not response_text:
        return None
    text_to_speech(response_text)
    return None
| # ------------------- | |
| # 4️⃣ Gradio Interface | |
| # ------------------- | |
with gr.Blocks(title="ChatDOC", theme=gr.themes.Soft()) as demo:
    # Per-session state: patient messages since the last analysis, and the
    # doctor replies produced so far (the last one is used for audio playback).
    question_count_state = gr.State(0)
    assistant_responses_state = gr.State([])

    gr.Markdown(
        """
        # 🩺 Chat with ChatDOC
        Welcome! I'm your AI medical assistant. Please describe your symptoms and I'll ask relevant questions to help understand your condition better.
        """
    )

    # The avatar files live at machine-specific paths; use each one only when
    # it exists so the app still starts on hosts that do not have them.
    avatar_candidates = (
        Path(r"C:\Users\JAY\Downloads\model\user_msg.png"),
        Path(r"C:\Users\JAY\Downloads\model\bot_msg.jpg"),
    )
    avatar_images = tuple(
        str(path) if path.is_file() else None for path in avatar_candidates
    )

    chatbot = gr.Chatbot(
        value=[],
        height=400,
        show_label=False,
        avatar_images=avatar_images,
        bubble_full_width=False,
    )

    with gr.Row():
        msg = gr.Textbox(
            placeholder="Describe your symptoms...",
            scale=4,
            container=False,
            show_label=False,
        )
        send_btn = gr.Button("Send", variant="primary", scale=1)
        mic_btn = gr.Button("🎤 Speak", variant="secondary", scale=1)

    with gr.Row():
        analysis_btn = gr.Button("Request Analysis", variant="secondary")
        clear_btn = gr.Button("Clear Chat", variant="stop")
        play_audio_btn = gr.Button("🔊 Play Assistant Response", variant="secondary")

    # -------------------
    # Update assistant responses
    # -------------------
    def update_assistant_responses(history, assistant_responses):
        """Append the latest doctor reply in ``history`` to ``assistant_responses``."""
        if history and history[-1][1]:
            assistant_responses.append(history[-1][1])
        return assistant_responses

    # -------------------
    # Submit handlers
    # -------------------
    def _run_turn(message, history, question_count, assistant_responses):
        """Run one chat turn and record the reply only if a turn was added."""
        history = history or []
        turns_before = len(history)
        history, _, question_count = process_message(message, history, question_count)
        # process_message returns the history untouched for blank input; in
        # that case the previous reply must not be recorded a second time.
        if len(history) > turns_before:
            assistant_responses = update_assistant_responses(history, assistant_responses)
        return history, question_count, assistant_responses

    def user_submit(message, history, question_count, assistant_responses):
        """Handle a typed message; returns (history, question_count, responses)."""
        return _run_turn(message, history, question_count, assistant_responses)

    def mic_submit(history, question_count, assistant_responses):
        """Record 5 seconds of audio, transcribe it, and handle it as a message."""
        user_text = record_and_transcribe(duration=5)
        # process_message appends the patient turn itself, so it must not be
        # appended here as well (that produced a duplicated message).
        return _run_turn(user_text or "", history, question_count, assistant_responses)

    def clear_input():
        """Return an empty string to clear the message textbox."""
        return ""

    def reset_session():
        """Reset the chat, the question counter and the stored doctor replies."""
        history, _, question_count = clear_chat()
        return history, question_count, []

    def play_latest_response(assistant_responses):
        """Speak the most recent doctor reply, if there is one."""
        if assistant_responses:
            return play_assistant_audio(assistant_responses[-1])
        return None

    # -------------------
    # Connect buttons
    # -------------------
    turn_outputs = [chatbot, question_count_state, assistant_responses_state]
    text_inputs = [msg, chatbot, question_count_state, assistant_responses_state]

    send_btn.click(
        user_submit,
        inputs=text_inputs,
        outputs=turn_outputs,
    ).then(clear_input, outputs=[msg])

    msg.submit(
        user_submit,
        inputs=text_inputs,
        outputs=turn_outputs,
    ).then(clear_input, outputs=[msg])

    mic_btn.click(
        mic_submit,
        inputs=[chatbot, question_count_state, assistant_responses_state],
        outputs=turn_outputs,
    )

    analysis_btn.click(
        force_analysis,
        inputs=[chatbot, question_count_state],
        outputs=[chatbot, question_count_state],
    )

    # Clearing also empties the stored replies so "Play" cannot speak a
    # response from a conversation that is no longer shown.
    clear_btn.click(
        reset_session,
        outputs=turn_outputs,
    )

    play_audio_btn.click(
        play_latest_response,
        inputs=[assistant_responses_state],
        outputs=[],
    )
| # ------------------- | |
| # 5️⃣ Launch | |
| # ------------------- | |
if __name__ == "__main__":
    # Host and port are intentionally not hard-coded. Gradio then honors the
    # GRADIO_SERVER_NAME / GRADIO_SERVER_PORT environment variables (needed
    # when running in a container or hosted Space) and otherwise falls back to
    # 127.0.0.1 with the first free port starting at 7860.
    demo.launch(
        share=False,
        debug=True,
    )