# ChatDOC — Hugging Face Space app.
# (The "Spaces: / Build error" lines that preceded this file were page-scrape
# residue from the Space's status banner, not part of the source.)
| import gradio as gr | |
| from transformers import AutoProcessor, AutoModelForVision2Seq | |
| import torch | |
# -------------------
# 1️⃣ Load Model & Processor (Now from Hugging Face)
# -------------------
def load_model():
    """Load the vision-language checkpoint and its processor from the HF Hub.

    Returns:
        tuple: ``(processor, model, device)`` where ``device`` is ``"cuda"``
        when a GPU is available, else ``"cpu"``.
    """
    model_id = "Muhammadidrees/RaiyaChatDoc"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision only on GPU; fp16 on CPU is slow or unsupported for many ops.
    dtype = torch.float16 if device == "cuda" else torch.float32
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForVision2Seq.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map="auto",  # Accelerate places the weights across devices
    )
    # BUG FIX: the original called model.to(device) here. With
    # device_map="auto", Accelerate has already dispatched the weights;
    # an explicit .to() is redundant and can raise / break hook placement
    # when any module is offloaded.
    model.eval()  # inference-only app; disable dropout etc.
    return processor, model, device
# Load model once at startup.
# NOTE(review): this runs at import time, so the app blocks (and may download
# the checkpoint) before the UI below is built — confirm this is acceptable
# for the deployment's startup-timeout budget.
processor, model, device = load_model()
# -------------------
# 2️⃣ Chat Logic Functions
# -------------------

# Patient phrases that force an immediate analysis turn.
_ANALYSIS_TRIGGERS = ("analysis", "diagnose", "what do you think", "causes")

# Number of interview questions to ask before switching to analysis mode.
_QUESTIONS_BEFORE_ANALYSIS = 6


def _system_prompt(should_analyze):
    """Return the doctor-persona system prompt for the current turn type."""
    if should_analyze:
        return (
            "You are a medical doctor. Based on the patient's responses, provide a comprehensive analysis "
            "of potential causes for their symptoms. Start with 'Based on the information provided by the patient, "
            "potential causes of [symptoms] could include:' and list 3-4 possible diagnoses with brief explanations. "
            "Format as numbered list with diagnosis name and short explanation."
        )
    return (
        "You are a medical doctor conducting a patient interview. Ask ONE specific, direct medical question "
        "to gather important diagnostic information. Keep it brief - just ask the question without explanations. "
        "Focus on key areas like: age, medical history, medications, lifestyle, family history, or symptom details."
    )


def _transcript(message, history):
    """Render prior turns plus the new patient message as a flat transcript.

    ``history`` is a list of ``[user_msg, bot_msg]`` pairs; the last entry is
    the current, not-yet-answered message, so it is skipped and ``message``
    is appended explicitly.
    """
    dialogue = []
    for user_msg, bot_msg in history[:-1]:
        if user_msg:
            dialogue.append(f"Patient: {user_msg}")
        if bot_msg:
            dialogue.append(f"Doctor: {bot_msg}")
    dialogue.append(f"Patient: {message}")
    return "\n".join(dialogue)


def _tidy_question(response):
    """Clean a question-mode reply: keep only the first question, strip one
    boilerplate opener, and guarantee a trailing '?'."""
    sentences = response.split('?')
    if len(sentences) > 1:
        response = sentences[0].strip() + '?'
    cleanup_starts = [
        "I need to ask",
        "Let me ask",
        "I would like to know",
        "Can you tell me",
        "It would help if",
    ]
    for phrase in cleanup_starts:
        if response.startswith(phrase):
            parts = response.split(',', 1)
            if len(parts) > 1:
                response = parts[1].strip()
            break  # BUG FIX: only one opener can match; stop scanning
    if not response.endswith('?'):
        response += '?'
    return response


def process_message(message, history, question_count):
    """Process a user message and generate the doctor's response.

    Args:
        message: raw text from the patient's textbox.
        history: Gradio chatbot history — list of ``[user, bot]`` pairs,
            mutated in place.
        question_count: questions asked since the last analysis.

    Returns:
        tuple: ``(history, history, question_count)`` — history twice because
        the UI wires the chatbot component as two outputs.
    """
    if not message.strip():
        # Ignore blank/whitespace-only submissions.
        return history, history, question_count
    history.append([message, None])
    question_count += 1
    lowered = message.lower()
    should_analyze = (
        question_count >= _QUESTIONS_BEFORE_ANALYSIS
        or any(word in lowered for word in _ANALYSIS_TRIGGERS)
    )
    prompt = (
        f"{_system_prompt(should_analyze)}\n\nConversation:\n"
        f"{_transcript(message, history)}\nDoctor:"
    )
    inputs = processor(text=prompt, images=None, return_tensors="pt").to(device)
    # Long budget for the analysis; short budget keeps questions terse.
    max_tokens = 400 if should_analyze else 25
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=processor.tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, not the echoed prompt.
    input_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[:, input_length:]
    response = processor.batch_decode(generated_tokens, skip_special_tokens=True)[0].strip()
    if response.lower().startswith("doctor:"):
        response = response[len("doctor:"):].strip()
    if not should_analyze:
        response = _tidy_question(response)
    if not response:
        # Robustness: never leave the bot bubble empty if generation
        # produced nothing usable.
        response = "Could you tell me more about your symptoms?"
    history[-1][1] = response
    if should_analyze:
        question_count = 0  # restart the interview cycle after an analysis
    return history, history, question_count
def force_analysis(history, question_count):
    """Force the next patient message to be answered with a full analysis.

    The transcript is returned untouched; the counter is pushed well past the
    analysis threshold so the next ``process_message`` call analyzes. The
    ``question_count`` argument is unused but kept for the Gradio wiring.
    """
    forced_count = 10
    return history, forced_count
def clear_chat():
    """Reset the session: empty chatbot, empty history copy, zeroed counter."""
    fresh_history = []
    return fresh_history, [], 0
# -------------------
# 3️⃣ Gradio Interface
# -------------------
with gr.Blocks(title="ChatDOC", theme=gr.themes.Soft()) as demo:
    # Per-session count of interview questions asked since the last analysis.
    question_count_state = gr.State(0)
    gr.Markdown(
        """
        # 🩺 Chat with ChatDOC
        Welcome! I'm your AI medical assistant. Please describe your symptoms and I'll ask relevant questions to help understand your condition better.
        """
    )
    chatbot = gr.Chatbot(
        value=[],
        height=400,
        show_label=False,
        # NOTE(review): relative paths — confirm both image files ship with
        # the app, otherwise the avatars silently fail to load.
        avatar_images=(
            r"user_msg.png",
            r"bot_msg.jpg"
        ),
        # NOTE(review): bubble_full_width is deprecated/removed in recent
        # Gradio releases — verify against the pinned gradio version.
        bubble_full_width=False
    )
    with gr.Row():
        msg = gr.Textbox(
            placeholder="Describe your symptoms...",
            scale=4,
            container=False,
            show_label=False
        )
        send_btn = gr.Button("Send", variant="primary", scale=1)
    with gr.Row():
        analysis_btn = gr.Button("Request Analysis", variant="secondary")
        clear_btn = gr.Button("Clear Chat", variant="stop")

    def user_submit(message, history, question_count):
        # Thin wrapper so click and Enter-submit share one handler.
        return process_message(message, history, question_count)

    def clear_input():
        # Blank the textbox after a message is sent.
        return ""

    # Send button: generate a reply, then clear the input box.
    # NOTE(review): chatbot is listed twice in outputs because
    # process_message returns (history, history, count); one output
    # would suffice.
    send_event = send_btn.click(
        user_submit,
        inputs=[msg, chatbot, question_count_state],
        outputs=[chatbot, chatbot, question_count_state]
    ).then(
        clear_input,
        outputs=[msg]
    )
    # Pressing Enter in the textbox behaves like the Send button.
    msg.submit(
        user_submit,
        inputs=[msg, chatbot, question_count_state],
        outputs=[chatbot, chatbot, question_count_state]
    ).then(
        clear_input,
        outputs=[msg]
    )
    # Pushes the counter past the analysis threshold; the analysis itself is
    # produced on the patient's next message, not immediately.
    analysis_btn.click(
        force_analysis,
        inputs=[chatbot, question_count_state],
        outputs=[chatbot, question_count_state]
    )
    clear_btn.click(
        clear_chat,
        outputs=[chatbot, chatbot, question_count_state]
    )
if __name__ == "__main__":
    # Entry point when run as a script (Spaces/containers execute this file
    # directly).
    demo.launch(
        # BUG FIX: the original bound 127.0.0.1 (loopback only), which makes
        # the app unreachable from outside a container/Spaces runtime.
        # 0.0.0.0 is the documented setting for Docker / HF Spaces.
        server_name="0.0.0.0",
        server_port=7860,  # the port HF Spaces expects by default
        share=False,
        debug=True  # verbose tracebacks in the UI; disable for production
    )