"""MedGemma Gradio app: X-ray report generation with follow-up chat, plus a text-only medical assistant."""

import os

# Suppress tokenizer warning (set before transformers/tokenizers are imported).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import gradio as gr
from transformers import AutoProcessor, AutoModelForImageTextToText
import torch
import spaces  # Required for Zero-GPU

# --- CONFIGURATION ---
MODEL_ID = "google/medgemma-4b-it"
# Maximum number of tokens of free-text clinical history forwarded to the model.
MAX_CLINICAL_TOKENS = 256

# --- PROMPTS ---
# OPTIMIZED HYBRID PROMPT (Best for Deterministic generation)
SYSTEM_PROMPT_XRAY = """You are an AI assistant specialized in radiological image interpretation. Your role is to provide a structured, professional analysis to assist qualified healthcare professionals.

**⚠️ CRITICAL DISCLAIMERS:**
- You are an AI, NOT a radiologist. This analysis is for **educational/decision-support only**.
- All findings must be verified by a qualified radiologist.
- **Anti-Hallucination Protocol:** Do NOT hallucinate findings to match the provided clinical history if they are not clearly visible. Do NOT invent specific measurements (e.g., "2cm") unless a scale is clearly visible.

**ANALYSIS APPROACH:**
Analyze the image systematically using standard radiological methodology:
1. **Image Technical Quality:** Assess view, positioning, exposure, and limitations.
2. **Systematic Review:**
   - **Bones:** Cortex, medulla, alignment, fractures, lesions.
   - **Soft Tissues/Organs:** Swelling, masses, calcifications, organ silhouettes.
   - **Spaces/Joints:** Joint alignment, effusions, pneumothorax/air-fluid levels.
   - **Support Devices:** Tubes, lines, hardware (if present).
3. **Clinical Integration:** specifically search for correlates to the provided history, but report **only** what is visible.

**OUTPUT FORMAT (Use Markdown `###` Headers):**

### 1. Technique & Quality
- View(s) obtained and technical limitations.

### 2. Findings
- Describe observations systematically by anatomical region.
- Report **both** abnormal and pertinent normal findings.
- Use precise anatomical terminology.
- **Support Devices:** (Location of tubes/lines if present).

### ⚠️ CRITICAL ALERTS (If Applicable)
- **Only** include this section for time-sensitive/life-threatening findings (e.g., Pneumothorax, Free Air).

### 3. Impression
- Concise summary of key findings.
- **Confidence Qualifier:** (e.g., "Findings are highly suggestive of...", "Probable...", "Cannot exclude...").

### 4. Differential Diagnosis
- List alternative considerations in order of likelihood.
- Briefly explain the reasoning (features that favor or argue against each).

### 5. Recommendations
- Follow-up imaging or clinical correlation.
- **Urgency:** (Stat, Urgent, or Routine).
- *Explicit Statement:* Must end with: "Clinical correlation is essential."
"""

SYSTEM_PROMPT_CHAT = """You are a knowledgeable medical assistant providing information and support to healthcare professionals and patients.

**YOUR CAPABILITIES:**
- Answer medical questions with evidence-based information
- Explain diagnoses, treatments, and procedures in clear language
- Help interpret medical terminology and reports
- Provide general health education and wellness guidance
- Assist with clinical decision support and differential diagnosis considerations

**IMPORTANT LIMITATIONS:**
- You do NOT provide definitive diagnoses or replace professional medical evaluation
- You cannot prescribe medications or create treatment plans
- Your knowledge has a cutoff date—always note when current information may have changed
- You do not have access to individual patient records or test results unless explicitly shared

**COMMUNICATION PRINCIPLES:**
- Use clear, accessible language—adjust complexity based on the user (clinician vs. patient)
- Provide evidence-based information with appropriate caveats about uncertainty
- Be empathetic and professional, especially when discussing sensitive topics
- Cite sources or note when recommendations are based on standard guidelines

**SAFETY PROTOCOLS:**
- For medical emergencies: immediately advise seeking emergency care (911/ER)
- For urgent symptoms: recommend prompt evaluation by a healthcare provider
- When uncertain: acknowledge limitations and suggest consulting with a specialist
- Never discourage someone from seeking professional medical attention

Adapt your tone and detail level based on whether you're speaking with healthcare professionals or patients."""

# --- GLOBAL MODEL LOADING ---
print(f"⏳ Loading processor for {MODEL_ID}...")
processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=False)

print("⏳ Loading model components...")
try:
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_ID,
        dtype=torch.bfloat16,
        device_map="auto",
        low_cpu_mem_usage=True,
    )
    print("✅ Model loaded successfully!")
except Exception as e:
    print(f"❌ Failed to load model: {e}")
    raise  # Re-raise with the original traceback; the app cannot run without a model.


# --- UTILITIES ---
def _encode(text):
    """Return the tokenizer ids for *text* (no special tokens); empty list for empty/None text."""
    if not text:
        return []
    return processor.tokenizer.encode(text, add_special_tokens=False)


def count_tokens(text):
    """Return the number of tokens in *text* (0 for empty/None)."""
    return len(_encode(text))


def truncate_to_tokens(text, max_tokens=MAX_CLINICAL_TOKENS):
    """Return *text* unchanged if it fits in *max_tokens*, else the decoded first *max_tokens* tokens."""
    ids = _encode(text)
    if len(ids) <= max_tokens:
        return text
    return processor.tokenizer.decode(ids[:max_tokens])


def update_token_counter(clinical_info):
    """Return (counter label, warning markdown) for the clinical-info textbox."""
    tokens = count_tokens(clinical_info)
    label_suffix = f"{tokens} / {MAX_CLINICAL_TOKENS} tokens"
    if tokens > MAX_CLINICAL_TOKENS:
        return f"🔴 {label_suffix}", "⚠️ Text will be truncated!"
    if tokens > MAX_CLINICAL_TOKENS * 0.8:
        return f"🟡 {label_suffix}", "⚠️ Approaching token limit"
    return f"🟢 {label_suffix}", ""


def _text_message(role, text):
    """Build a single chat message whose content is one text part."""
    return {"role": role, "content": [{"type": "text", "text": text}]}


# --- INFERENCE FUNCTIONS ---
# NOTE(review): 30 s may be tight for the 1280-token report generation below — confirm on target hardware.
@spaces.GPU(duration=30)
def model_inference(messages, max_tokens=2048, temperature=0.4, do_sample=True):
    """
    Generic inference function.

    NOTE: 'messages' must strictly follow the
    [{"role": "...", "content": [{"type":...}]}] format.

    Args:
        messages: Chat history passed to ``processor.apply_chat_template``.
        max_tokens: ``max_new_tokens`` for generation.
        temperature: Sampling temperature (only used when ``do_sample`` is True).
        do_sample: False for greedy decoding; True enables temperature/top-p/top-k sampling.

    Returns:
        The decoded newly generated text, stripped of surrounding whitespace.

    Raises:
        gr.Error: If templating or generation fails (shown to the user by Gradio).
    """
    try:
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device, dtype=torch.bfloat16)
        # Length of the prompt, so only newly generated tokens are decoded below.
        input_len = inputs["input_ids"].shape[-1]

        gen_kwargs = {"max_new_tokens": max_tokens, "do_sample": do_sample}
        # Sampling parameters are only meaningful (and only passed) when sampling is enabled.
        if do_sample:
            gen_kwargs.update(temperature=temperature, top_p=0.9, top_k=50)

        with torch.inference_mode():
            output = model.generate(**inputs, **gen_kwargs)

        decoded = processor.decode(output[0][input_len:], skip_special_tokens=True)
        return decoded.strip()
    except Exception as e:
        raise gr.Error(f"Generation failed: {e}") from e


def _run_chat_turn(user_text, history_state, ui_history):
    """Run one sampled chat turn on top of *history_state*.

    The stored history is never mutated in place: a new list is built and only
    returned once generation succeeds. If ``model_inference`` raises, the caller's
    state is untouched, so the chat does not end up with two consecutive user turns
    (which the chat template rejects on every subsequent request).

    Returns:
        (new UI history, new model message history, "" to clear the input box).
    """
    messages = list(history_state) + [_text_message("user", user_text)]
    # Sampling enabled, but temperature lowered to 0.4: conversational, yet close to the facts.
    response_text = model_inference(messages, max_tokens=1024, temperature=0.4, do_sample=True)
    messages.append(_text_message("assistant", response_text))
    new_ui_history = list(ui_history or []) + [[user_text, response_text]]
    return new_ui_history, messages, ""


# --- X-RAY TAB LOGIC ---
def generate_xray_report(image, clinical_info, history_state):
    """Generate a structured radiology report for *image*, starting a fresh conversation.

    Args:
        image: Uploaded image (PIL, per the ``gr.Image(type="pil")`` input).
        clinical_info: Optional free-text history; truncated to MAX_CLINICAL_TOKENS tokens.
        history_state: Previous conversation state; unused, because a new report resets it.

    Returns:
        (chatbot history, new message history, clinical_info). clinical_info is the
        possibly truncated text, written back so it stays in the textbox.
    """
    if image is None:
        raise gr.Error("Please upload an X-ray image first.")

    # 1. Truncate clinical info (token safe).
    if clinical_info:
        clinical_info = truncate_to_tokens(clinical_info)

    # 2. Build the initial user message: optional history, instruction, then the image.
    user_content = []
    if clinical_info and clinical_info.strip():
        user_content.append({"type": "text", "text": f"Patient info: {clinical_info}"})
    user_content.append({"type": "text", "text": "Describe this X-ray image."})
    user_content.append({"type": "image", "image": image})

    messages = [
        _text_message("system", SYSTEM_PROMPT_XRAY),
        {"role": "user", "content": user_content},
    ]

    # 3. Greedy decoding (do_sample=False) for consistent, grounded clinical reports.
    response_text = model_inference(messages, max_tokens=1280, do_sample=False)

    # 4. Commit the assistant turn only after successful generation.
    messages.append(_text_message("assistant", response_text))

    ui_history = [[None, response_text]]
    return ui_history, messages, clinical_info


def chat_about_xray(user_text, history_state, ui_history):
    """Answer a follow-up question about the generated X-ray report.

    Blank input is a no-op. Raises gr.Error if no report has been generated yet.
    """
    if not user_text or not user_text.strip():
        return ui_history, history_state, ""
    if not history_state:
        raise gr.Error("Please generate a report first.")
    return _run_chat_turn(user_text, history_state, ui_history)


# --- TEXT CHAT TAB LOGIC ---
def medical_chat(user_text, history_state, ui_history):
    """Answer a text-only medical question, seeding the system prompt on the first turn.

    Blank input is a no-op.
    """
    if not user_text or not user_text.strip():
        return ui_history, history_state, ""
    base_history = history_state or [_text_message("system", SYSTEM_PROMPT_CHAT)]
    return _run_chat_turn(user_text, base_history, ui_history)


# --- UI CONSTRUCTION ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🏥 MedGemma Medical AI")
    gr.Markdown("Powered by **Google MedGemma-4B** with Zero-GPU.")

    with gr.Tabs():
        # === TAB 1: X-RAY ANALYSIS ===
        with gr.TabItem("🩻 X-Ray Analysis"):
            with gr.Row():
                with gr.Column(scale=1):
                    xray_image = gr.Image(type="pil", label="Upload X-ray", height=300)
                    clinical_input = gr.Textbox(
                        lines=3,
                        placeholder="e.g. 65M, cough for 3 weeks...",
                        label="Clinical Information",
                    )
                    with gr.Row():
                        token_counter = gr.Textbox(
                            value=f"0 / {MAX_CLINICAL_TOKENS} tokens",
                            show_label=False,
                            interactive=False,
                            container=False,
                        )
                        token_warning = gr.Markdown("")
                    generate_btn = gr.Button("🔬 Generate Report", variant="primary")

                with gr.Column(scale=2):
                    # Internal state holds the full multimodal message history.
                    xray_state = gr.State([])
                    xray_chatbot = gr.Chatbot(
                        label="Radiology Report & Discussion",
                        height=500,
                        bubble_full_width=False,
                    )
                    with gr.Row():
                        xray_chat_input = gr.Textbox(
                            placeholder="Ask a follow-up question about the report...",
                            show_label=False,
                            scale=4,
                        )
                        xray_send_btn = gr.Button("Send", scale=1)

            # Event handlers
            clinical_input.change(
                fn=update_token_counter,
                inputs=[clinical_input],
                outputs=[token_counter, token_warning],
            )
            generate_btn.click(
                fn=generate_xray_report,
                inputs=[xray_image, clinical_input, xray_state],
                outputs=[xray_chatbot, xray_state, clinical_input],
            )
            # Enter key and Send button share the same follow-up handler.
            for xray_trigger in (xray_chat_input.submit, xray_send_btn.click):
                xray_trigger(
                    fn=chat_about_xray,
                    inputs=[xray_chat_input, xray_state, xray_chatbot],
                    outputs=[xray_chatbot, xray_state, xray_chat_input],
                )

            # --- EXAMPLES --- (placed in the X-ray tab, since they only fill X-ray inputs)
            examples = [
                ["pneumonia.jpg", "Patient presenting with high fever, cough, and shortness of breath."],
                ["normal-chest-xray.png", "Routine checkup for 30-year-old male, no symptoms."],
                ["distal-radius-fracture.jpg", "30m, trauma, injury, pain"],
                ["distal-fibula-fracture.jpg", "30m patient that got injured playing soccer, acute pain, can not walk. "],
            ]
            gr.Examples(
                examples=examples,
                inputs=[xray_image, clinical_input],
                label="Try an X-Ray Example",
            )

        # === TAB 2: MEDICAL ASSISTANT ===
        with gr.TabItem("💬 Medical Assistant"):
            gr.Markdown("Chat with a helpful medical assistant (Text only).")
            chat_state = gr.State([])
            chatbot = gr.Chatbot(height=500, bubble_full_width=False)
            with gr.Row():
                chat_input = gr.Textbox(
                    placeholder="Type your medical question here...",
                    show_label=False,
                    scale=4,
                )
                chat_send_btn = gr.Button("Send", scale=1)

            # Enter key and Send button share the same chat handler.
            for chat_trigger in (chat_input.submit, chat_send_btn.click):
                chat_trigger(
                    fn=medical_chat,
                    inputs=[chat_input, chat_state, chatbot],
                    outputs=[chatbot, chat_state, chat_input],
                )


if __name__ == "__main__":
    demo.launch()