Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoProcessor, AutoModelForVision2Seq | |
| import torch | |
# -------------------
# 1️⃣ Load model & processor from the Hugging Face Hub
# -------------------
MODEL_ID = "Muhammadidrees/RaiyaChatDoc"


def load_model(model_id=MODEL_ID):
    """Load the processor and vision-to-sequence model for ``model_id``.

    Args:
        model_id: Hugging Face Hub repo id (or local directory) to load.
            Defaults to the ChatDOC checkpoint this app was built around.

    Returns:
        ``(processor, model, device)``. ``device`` is the string ``"cuda"``
        when a GPU is visible and ``"cpu"`` otherwise. Callers move the
        tokenized inputs onto it before calling ``model.generate``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # fp16 on GPU (half the memory of fp32); full fp32 on CPU.
    dtype = torch.float16 if device == "cuda" else torch.float32
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForVision2Seq.from_pretrained(
        model_id,
        torch_dtype=dtype,
        # Placement is delegated to transformers/accelerate.
        # NOTE(review): on a multi-GPU host this may shard the model, and
        # `device` ("cuda") is then only where the inputs go. Confirm the
        # first layer lives there.
        device_map="auto",
    )
    return processor, model, device


# Load once at import time so every chat request reuses the same weights.
processor, model, device = load_model()
# -------------------
# 2️⃣ Chat Logic Functions
# -------------------
# Once the patient has sent this many messages since the last analysis, the
# next reply is a full analysis instead of another interview question.
_ANALYSIS_AFTER_TURNS = 6
# Substrings (matched case-insensitively) that let the patient ask for the
# analysis early.
_ANALYSIS_TRIGGERS = ("analysis", "diagnose", "what do you think", "causes")
# Token budgets: a long structured analysis vs. one short question.
_ANALYSIS_MAX_TOKENS = 400
_QUESTION_MAX_TOKENS = 25
# Filler openings the model tends to put before its real question. When a
# reply starts with one, everything up to the first comma is dropped.
_LEAD_IN_PHRASES = (
    "I need to ask",
    "Let me ask",
    "I would like to know",
    "Can you tell me",
    "It would help if",
)
# Shown instead of an empty bubble (or a bare "?") when generation produces
# nothing usable.
_QUESTION_FALLBACK = "Could you tell me a bit more about your symptoms?"
_ANALYSIS_FALLBACK = (
    "Sorry, I couldn't generate an analysis just now. Please try again."
)

_ANALYSIS_SYSTEM_PROMPT = (
    "You are a highly experienced medical expert who combines the roles of a medical doctor, specialist, nutritionist, and medical teacher.\n"
    "Based only on the patient's provided information, give a clear and structured analysis:\n\n"
    "1. Possible health issues or conditions the patient might have (3–4 points).\n"
    "2. Dietary and lifestyle recommendations specific to the patient’s situation.\n"
    "3. Guidance on which type of doctor or specialist the patient should consult.\n\n"
    "Be concise, professional, and easy to understand for a non-medical person. "
    "If you mention complex medical terms, briefly explain them in simple language."
)

_INTERVIEW_SYSTEM_PROMPT = (
    "You are a medical expert conducting a patient interview. Follow these rules:\n"
    "1. If the user simply shares symptoms or health info, ask ONE direct and specific medical question "
    "to gather diagnostic details (e.g., age, medical history, medications, lifestyle, family history, or symptoms). "
    "Do not explain, just ask the question.\n"
    "2. If the user explicitly asks for a diet plan, provide a complete, practical diet plan. "
    "Avoid unnecessary disclaimers, but keep it safe and balanced.\n"
    "3. If the user asks about a complex medical term, give a clear and simple explanation.\n\n"
    "Always keep responses brief, clear, and professional."
)


def _build_prompt(system_prompt, prior_turns, message):
    """Render the system prompt and transcript into one text prompt.

    ``prior_turns`` holds the earlier ``[patient, doctor]`` pairs. Empty or
    ``None`` sides are skipped. The new ``message`` is the final patient line,
    and the prompt ends with ``"Doctor:"`` so the model answers as the doctor.
    """
    lines = []
    for user_msg, bot_msg in prior_turns:
        if user_msg:
            lines.append(f"Patient: {user_msg}")
        if bot_msg:
            lines.append(f"Doctor: {bot_msg}")
    lines.append(f"Patient: {message}")
    conversation = "\n".join(lines)
    return f"{system_prompt}\n\nConversation:\n{conversation}\nDoctor:"


def _generate_reply(prompt, max_new_tokens):
    """Sample a continuation of ``prompt`` with the global model and return only the new text."""
    inputs = processor(text=prompt, images=None, return_tensors="pt").to(device)
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=processor.tokenizer.eos_token_id,
        )
    # The output is assumed to start with the prompt tokens; slice them off.
    prompt_length = inputs["input_ids"].shape[1]
    new_tokens = outputs[:, prompt_length:]
    return processor.batch_decode(new_tokens, skip_special_tokens=True)[0]


def _clean_reply(raw, should_analyze):
    """Turn raw generated text into the doctor's chat reply.

    In analysis mode this returns the trimmed text, or a fallback if the text
    is empty. In interview mode it keeps only the first question, drops filler
    openings, and guarantees a trailing "?".
    """
    reply = raw.strip()
    # The prompt is a Patient/Doctor transcript, so the model may carry on and
    # invent the patient's next turn. Keep only the doctor's part.
    reply = reply.split("\nPatient:", 1)[0].strip()
    if reply.lower().startswith("doctor:"):
        reply = reply[len("doctor:"):].strip()
    if should_analyze:
        return reply or _ANALYSIS_FALLBACK
    # Interview mode: a single question only.
    question, mark, _ = reply.partition("?")
    if mark:
        reply = question.strip() + "?"
    if reply.startswith(_LEAD_IN_PHRASES):
        _, comma, rest = reply.partition(",")
        if comma:
            rest = rest.strip()
            # "Let me ask, how old..." -> "How old..."
            reply = rest[:1].upper() + rest[1:]
    if not reply.strip("? "):
        # Nothing usable was generated; avoid showing a lone "?".
        return _QUESTION_FALLBACK
    if not reply.endswith("?"):
        reply += "?"
    return reply


def process_message(message, history, question_count):
    """Process user message and generate doctor response.

    Args:
        message: Text the patient just submitted. Blank or ``None`` input is
            ignored.
        history: Transcript as a list of ``[patient, doctor]`` pairs. May be
            empty or ``None``. It is not mutated; a new list is returned.
        question_count: Patient messages since the last analysis.

    Returns:
        ``(history, history, question_count)``. The updated transcript appears
        twice because the UI binds it to two outputs. The counter resets to 0
        after an analysis reply.
    """
    # Shallow copy: only a new turn is appended, existing turns are untouched.
    history = list(history or [])
    if not message or not message.strip():
        return history, history, question_count

    question_count += 1
    lowered = message.lower()
    should_analyze = question_count >= _ANALYSIS_AFTER_TURNS or any(
        trigger in lowered for trigger in _ANALYSIS_TRIGGERS
    )

    system_prompt = (
        _ANALYSIS_SYSTEM_PROMPT if should_analyze else _INTERVIEW_SYSTEM_PROMPT
    )
    prompt = _build_prompt(system_prompt, history, message)
    max_tokens = _ANALYSIS_MAX_TOKENS if should_analyze else _QUESTION_MAX_TOKENS
    reply = _clean_reply(_generate_reply(prompt, max_tokens), should_analyze)

    history.append([message, reply])
    if should_analyze:
        question_count = 0
    return history, history, question_count
def force_analysis(history, question_count):
    """Arrange for the patient's next message to get a full analysis.

    The transcript passes through unchanged. The incoming ``question_count``
    is discarded and replaced with 10, which is comfortably above the turn
    threshold that ``process_message`` checks.
    """
    armed_count = 10
    return history, armed_count
def clear_chat():
    """Reset the chat: an empty transcript for each chatbot output, counter back to 0."""
    transcript, mirror = [], []
    return transcript, mirror, 0
# -------------------
# 3️⃣ Gradio Interface
# -------------------
with gr.Blocks(title="ChatDOC", theme=gr.themes.Soft()) as demo:
    # Per-session number of patient messages since the last analysis.
    question_count_state = gr.State(0)

    intro_markdown = (
        "\n"
        "# 🩺 Chat with ChatDOC\n"
        "Welcome! I'm your AI medical assistant. Please describe your symptoms and I'll ask relevant questions to help understand your condition better.\n"
    )
    gr.Markdown(intro_markdown)

    chatbot = gr.Chatbot(
        value=[],
        height=400,
        show_label=False,
        avatar_images=("user_msg.png", "bot_msg.jpg"),
        bubble_full_width=False,
    )

    with gr.Row():
        msg = gr.Textbox(
            placeholder="Describe your symptoms...",
            scale=4,
            container=False,
            show_label=False,
        )
        send_btn = gr.Button("Send", variant="primary", scale=1)

    with gr.Row():
        analysis_btn = gr.Button("Request Analysis", variant="secondary")
        clear_btn = gr.Button("Clear Chat", variant="stop")

    def user_submit(message, history, question_count):
        """Hand a submitted message to ``process_message``."""
        return process_message(message, history, question_count)

    def clear_input():
        """Empty the textbox once the message has been handled."""
        return ""

    # Components shared by every chat-turn event. The transcript is bound to
    # the chatbot twice to match the 3-tuple the handlers return.
    chat_inputs = (msg, chatbot, question_count_state)
    chat_outputs = (chatbot, chatbot, question_count_state)

    # The Send button and pressing Enter run the same chain:
    # produce the reply, then clear the input box.
    send_event = send_btn.click(
        user_submit, inputs=list(chat_inputs), outputs=list(chat_outputs)
    ).then(clear_input, outputs=[msg])
    msg.submit(
        user_submit, inputs=list(chat_inputs), outputs=list(chat_outputs)
    ).then(clear_input, outputs=[msg])

    analysis_btn.click(
        force_analysis,
        inputs=[chatbot, question_count_state],
        outputs=[chatbot, question_count_state],
    )
    clear_btn.click(clear_chat, outputs=list(chat_outputs))
if __name__ == "__main__":
    # Listen on all interfaces at port 7860, without a public share link.
    demo.launch(share=False, server_port=7860, server_name="0.0.0.0")