import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq

# -------------------
# 1️⃣ Load Model & Processor (Now from Hugging Face)
# -------------------


def load_model():
    """Load the ChatDOC model and its processor from the Hugging Face Hub.

    Returns:
        tuple: (processor, model, device) where *device* is ``"cuda"``
        when a GPU is available, otherwise ``"cpu"``.
    """
    model_id = "Muhammadidrees/RaiyaChatDoc"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Half precision saves GPU memory; CPU inference requires full precision.
    dtype = torch.float16 if device == "cuda" else torch.float32

    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForVision2Seq.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map="auto",  # Let HF handle device placement
    )
    return processor, model, device


# Load model once at startup so every request reuses the same weights.
processor, model, device = load_model()

# -------------------
# 2️⃣ Chat Logic Functions
# -------------------

# Keywords in a patient message that immediately trigger a full analysis.
_ANALYSIS_KEYWORDS = ("analysis", "diagnose", "what do you think", "causes")

# After this many patient turns, switch from interviewing to analysis.
_QUESTIONS_BEFORE_ANALYSIS = 6

# System prompt used when delivering the structured analysis.
_ANALYSIS_PROMPT = (
    "You are a highly experienced medical expert who combines the roles of a medical doctor, specialist, nutritionist, and medical teacher.\n"
    "Based only on the patient's provided information, give a clear and structured analysis:\n\n"
    "1. Possible health issues or conditions the patient might have (3–4 points).\n"
    "2. Dietary and lifestyle recommendations specific to the patient’s situation.\n"
    "3. Guidance on which type of doctor or specialist the patient should consult.\n\n"
    "Be concise, professional, and easy to understand for a non-medical person. "
    "If you mention complex medical terms, briefly explain them in simple language."
)

# System prompt used during the question-asking interview phase.
_INTERVIEW_PROMPT = (
    "You are a medical expert conducting a patient interview. Follow these rules:\n"
    "1. If the user simply shares symptoms or health info, ask ONE direct and specific medical question "
    "to gather diagnostic details (e.g., age, medical history, medications, lifestyle, family history, or symptoms). "
    "Do not explain, just ask the question.\n"
    "2. If the user explicitly asks for a diet plan, provide a complete, practical diet plan. "
    "Avoid unnecessary disclaimers, but keep it safe and balanced.\n"
    "3. If the user asks about a complex medical term, give a clear and simple explanation.\n\n"
    "Always keep responses brief, clear, and professional."
)


def _build_prompt(system_prompt, history, message):
    """Flatten the chat history plus the new message into one text prompt."""
    dialogue = []
    # history[-1] is the just-appended [message, None] pair; skip it here and
    # append the new message explicitly below instead.
    for user_msg, bot_msg in history[:-1]:
        if user_msg:
            dialogue.append(f"Patient: {user_msg}")
        if bot_msg:
            dialogue.append(f"Doctor: {bot_msg}")
    dialogue.append(f"Patient: {message}")

    conversation = "\n".join(dialogue)
    return f"{system_prompt}\n\nConversation:\n{conversation}\nDoctor:"


def _generate(prompt, max_tokens):
    """Run the model on *prompt* and return only the decoded continuation."""
    inputs = processor(text=prompt, images=None, return_tensors="pt").to(device)

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=processor.tokenizer.eos_token_id,
        )

    # Drop the echoed prompt tokens so only newly generated text is decoded.
    input_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[:, input_length:]
    return processor.batch_decode(generated_tokens, skip_special_tokens=True)[0].strip()


def _tidy_question(response):
    """Trim an interview-mode reply down to a single, direct question."""
    # Keep only the first question if the model produced several.
    sentences = response.split('?')
    if len(sentences) > 1:
        response = sentences[0].strip() + '?'

    # Strip boilerplate lead-ins such as "Let me ask, ..." down to the
    # question itself.
    cleanup_starts = [
        "I need to ask",
        "Let me ask",
        "I would like to know",
        "Can you tell me",
        "It would help if",
    ]
    for phrase in cleanup_starts:
        if response.startswith(phrase):
            parts = response.split(',', 1)
            if len(parts) > 1:
                response = parts[1].strip()
                if not response.endswith('?'):
                    response += '?'
    return response


def process_message(message, history, question_count):
    """Handle one patient turn: update history and generate the doctor reply.

    Args:
        message: The patient's latest text input.
        history: Chatbot history as a list of ``[user, bot]`` pairs.
        question_count: Number of interview turns taken so far.

    Returns:
        tuple: ``(history, history, question_count)`` — history is returned
        twice to match the two chatbot outputs wired in the UI; the counter
        resets to 0 after a full analysis has been delivered.
    """
    # Ignore empty or whitespace-only submissions.
    if not message.strip():
        return history, history, question_count

    history.append([message, None])
    question_count += 1

    # Deliver a full analysis either after enough interview turns or when the
    # patient explicitly asks for one.
    should_analyze = (
        question_count >= _QUESTIONS_BEFORE_ANALYSIS
        or any(word in message.lower() for word in _ANALYSIS_KEYWORDS)
    )

    system_prompt = _ANALYSIS_PROMPT if should_analyze else _INTERVIEW_PROMPT
    prompt = _build_prompt(system_prompt, history, message)

    # A full analysis needs room to write; a single follow-up question is kept
    # deliberately short.
    max_tokens = 400 if should_analyze else 25
    response = _generate(prompt, max_tokens)

    # The model sometimes echoes the role label; strip it.
    if response.lower().startswith("doctor:"):
        response = response[len("doctor:"):].strip()

    if not should_analyze:
        response = _tidy_question(response)

    history[-1][1] = response

    # Start a fresh interview cycle after delivering an analysis.
    if should_analyze:
        question_count = 0

    return history, history, question_count


def force_analysis(history, question_count):
    """Prime the counter so the NEXT patient message triggers a full analysis.

    NOTE(review): despite the button label, this does not generate an
    analysis immediately — it only pushes the counter past the threshold,
    so the patient must send one more message. The *question_count* argument
    is accepted for wiring symmetry but intentionally ignored.
    """
    return history, 10


def clear_chat():
    """Reset both chatbot outputs and the question counter."""
    return [], [], 0


# -------------------
# 3️⃣ Gradio Interface
# -------------------
with gr.Blocks(title="ChatDOC", theme=gr.themes.Soft()) as demo:
    # Per-session interview-turn counter.
    question_count_state = gr.State(0)

    gr.Markdown(
        """
        # 🩺 Chat with ChatDOC
        Welcome! I'm your AI medical assistant. Please describe your symptoms and I'll ask relevant questions to help understand your condition better.
        """
    )

    chatbot = gr.Chatbot(
        value=[],
        height=400,
        show_label=False,
        avatar_images=(
            r"user_msg.png",
            r"bot_msg.jpg"
        ),
        bubble_full_width=False
    )

    with gr.Row():
        msg = gr.Textbox(
            placeholder="Describe your symptoms...",
            scale=4,
            container=False,
            show_label=False
        )
        send_btn = gr.Button("Send", variant="primary", scale=1)

    with gr.Row():
        analysis_btn = gr.Button("Request Analysis", variant="secondary")
        clear_btn = gr.Button("Clear Chat", variant="stop")

    def clear_input():
        """Empty the textbox after a message has been sent."""
        return ""

    # Send button and Enter key share the same pipeline: generate the reply,
    # then clear the input box. process_message is wired directly (the old
    # user_submit wrapper added nothing).
    send_btn.click(
        process_message,
        inputs=[msg, chatbot, question_count_state],
        outputs=[chatbot, chatbot, question_count_state],
    ).then(
        clear_input,
        outputs=[msg],
    )

    msg.submit(
        process_message,
        inputs=[msg, chatbot, question_count_state],
        outputs=[chatbot, chatbot, question_count_state],
    ).then(
        clear_input,
        outputs=[msg],
    )

    analysis_btn.click(
        force_analysis,
        inputs=[chatbot, question_count_state],
        outputs=[chatbot, question_count_state],
    )

    clear_btn.click(
        clear_chat,
        outputs=[chatbot, chatbot, question_count_state],
    )


if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )