# app.py
import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq
import torch

# -------------------
# 1️⃣ Load Model
# -------------------
def load_model():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    # Load model and processor from Hugging Face
    processor = AutoProcessor.from_pretrained("Muhammadidrees/RaiyaChatDoc", trust_remote_code=True)
    model = AutoModelForVision2Seq.from_pretrained(
        "Muhammadidrees/RaiyaChatDoc",
        torch_dtype=dtype,
        device_map="auto"  # automatically assigns to GPU if available
    )
    # device_map="auto" already places the model, so an explicit model.to(device) is not needed
    return processor, model, device

processor, model, device = load_model()
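
# Note: device_map="auto" requires the `accelerate` package; on Spaces, list
# gradio, torch, transformers, and accelerate in requirements.txt so the model loads.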

# -------------------
# 2️⃣ Chat Logic
# -------------------
def process_message(message, history, question_count):
    if not message.strip():
        return history, history, question_count

    history.append([message, None])
    question_count += 1

    # Decide if analysis is needed
    should_analyze = question_count >= 6 or any(
        word in message.lower() for word in ["analysis", "diagnose", "what do you think", "causes"]
    )

    # System prompt
    system_prompt = (
        "You are a medical doctor. "
        "Provide a comprehensive analysis of potential causes for symptoms."
        if should_analyze else
        "You are a medical doctor conducting a patient interview. Ask ONE specific question."
    )

    # Build conversation context
    dialogue = []
    for user_msg, bot_msg in history[:-1]:
        if user_msg: dialogue.append(f"Patient: {user_msg}")
        if bot_msg: dialogue.append(f"Doctor: {bot_msg}")
    dialogue.append(f"Patient: {message}")

    prompt = f"{system_prompt}\n\nConversation:\n" + "\n".join(dialogue) + "\nDoctor:"
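
    # For illustration, an assembled prompt in interview mode looks roughly like:
    #   You are a medical doctor conducting a patient interview. Ask ONE specific question.
    #
    #   Conversation:
    #   Patient: I have had a headache for two days
    #   Doctor: Is the pain constant or does it come and go?
    #   Patient: It comes and goes
    #   Doctor: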

    # Prepare input
    inputs = processor(text=prompt, images=None, return_tensors="pt").to(device)
    max_tokens = 400 if should_analyze else 25

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=processor.tokenizer.eos_token_id
        )

    # Decode response
    input_length = inputs["input_ids"].shape[1]
    response = processor.batch_decode(outputs[:, input_length:], skip_special_tokens=True)[0].strip()
    if response.lower().startswith("doctor:"):
        response = response[7:].strip()

    # Concise question formatting
    if not should_analyze:
        response = response.split('?')[0].strip() + '?'

    history[-1][1] = response
    if should_analyze: question_count = 0
    return history, history, question_count

def force_analysis(history, question_count):
    return history, 10

def clear_chat():
    return [], [], 0
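
# Quick local check (illustrative): the chat logic can be exercised without the UI, e.g.
#   history, _, count = process_message("I have a fever and a dry cough", [], 0)
#   print(history[-1][1])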

# -------------------
# 3️⃣ Gradio Interface
# -------------------
with gr.Blocks(title="ChatDOC") as demo:
    question_count_state = gr.State(0)

    gr.Markdown("# 🩺 Chat with ChatDOC\nDescribe your symptoms and get guidance.")
    chatbot = gr.Chatbot(value=[], height=400, show_label=False)

    with gr.Row():
        msg = gr.Textbox(placeholder="Describe your symptoms...", scale=4, container=False, show_label=False)
        send_btn = gr.Button("Send", variant="primary", scale=1)

    with gr.Row():
        analysis_btn = gr.Button("Request Analysis", variant="secondary")
        clear_btn = gr.Button("Clear Chat", variant="stop")

    send_btn.click(
        process_message, inputs=[msg, chatbot, question_count_state], outputs=[chatbot, chatbot, question_count_state]
    ).then(lambda: "", outputs=[msg])
    msg.submit(
        process_message, inputs=[msg, chatbot, question_count_state], outputs=[chatbot, chatbot, question_count_state]
    ).then(lambda: "", outputs=[msg])
    analysis_btn.click(force_analysis, inputs=[chatbot, question_count_state], outputs=[chatbot, question_count_state])
    clear_btn.click(clear_chat, outputs=[chatbot, chatbot, question_count_state])

# -------------------
# 4️⃣ Launch
# -------------------
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)
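    # server_name="0.0.0.0" exposes the app outside the container, and 7860 is the
    # port Hugging Face Spaces expects by default.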