Spaces:
Sleeping
Sleeping
| import os | |
| # Suppress tokenizer warning | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| import gradio as gr | |
| from transformers import AutoProcessor, AutoModelForImageTextToText | |
| import torch | |
| import spaces # Required for Zero-GPU | |
# --- CONFIGURATION ---
# Hugging Face model repo ID for the multimodal medical model.
MODEL_ID = "google/medgemma-4b-it"
# Hard cap on tokens accepted from the "Clinical Information" textbox;
# longer input is truncated before being sent to the model.
MAX_CLINICAL_TOKENS = 256
# --- PROMPTS ---
# OPTIMIZED HYBRID PROMPT (Best for Deterministic generation)
# System prompt for the X-ray tab: enforces a structured Markdown report
# format and explicit anti-hallucination / disclaimer rules.
SYSTEM_PROMPT_XRAY = """You are an AI assistant specialized in radiological image interpretation. Your role is to provide a structured, professional analysis to assist qualified healthcare professionals.
**⚠️ CRITICAL DISCLAIMERS:**
- You are an AI, NOT a radiologist. This analysis is for **educational/decision-support only**.
- All findings must be verified by a qualified radiologist.
- **Anti-Hallucination Protocol:** Do NOT hallucinate findings to match the provided clinical history if they are not clearly visible. Do NOT invent specific measurements (e.g., "2cm") unless a scale is clearly visible.
**ANALYSIS APPROACH:**
Analyze the image systematically using standard radiological methodology:
1. **Image Technical Quality:** Assess view, positioning, exposure, and limitations.
2. **Systematic Review:**
   - **Bones:** Cortex, medulla, alignment, fractures, lesions.
   - **Soft Tissues/Organs:** Swelling, masses, calcifications, organ silhouettes.
   - **Spaces/Joints:** Joint alignment, effusions, pneumothorax/air-fluid levels.
   - **Support Devices:** Tubes, lines, hardware (if present).
3. **Clinical Integration:** specifically search for correlates to the provided history, but report **only** what is visible.
**OUTPUT FORMAT (Use Markdown `###` Headers):**
### 1. Technique & Quality
- View(s) obtained and technical limitations.
### 2. Findings
- Describe observations systematically by anatomical region.
- Report **both** abnormal and pertinent normal findings.
- Use precise anatomical terminology.
- **Support Devices:** (Location of tubes/lines if present).
### ⚠️ CRITICAL ALERTS (If Applicable)
- **Only** include this section for time-sensitive/life-threatening findings (e.g., Pneumothorax, Free Air).
### 3. Impression
- Concise summary of key findings.
- **Confidence Qualifier:** (e.g., "Findings are highly suggestive of...", "Probable...", "Cannot exclude...").
### 4. Differential Diagnosis
- List alternative considerations in order of likelihood.
- Briefly explain the reasoning (features that favor or argue against each).
### 5. Recommendations
- Follow-up imaging or clinical correlation.
- **Urgency:** (Stat, Urgent, or Routine).
- *Explicit Statement:* Must end with: "Clinical correlation is essential."
"""
# System prompt for the text-only "Medical Assistant" tab: defines scope,
# limitations, communication style, and safety escalation rules.
SYSTEM_PROMPT_CHAT = """You are a knowledgeable medical assistant providing information and support to healthcare professionals and patients.
**YOUR CAPABILITIES:**
- Answer medical questions with evidence-based information
- Explain diagnoses, treatments, and procedures in clear language
- Help interpret medical terminology and reports
- Provide general health education and wellness guidance
- Assist with clinical decision support and differential diagnosis considerations
**IMPORTANT LIMITATIONS:**
- You do NOT provide definitive diagnoses or replace professional medical evaluation
- You cannot prescribe medications or create treatment plans
- Your knowledge has a cutoff date—always note when current information may have changed
- You do not have access to individual patient records or test results unless explicitly shared
**COMMUNICATION PRINCIPLES:**
- Use clear, accessible language—adjust complexity based on the user (clinician vs. patient)
- Provide evidence-based information with appropriate caveats about uncertainty
- Be empathetic and professional, especially when discussing sensitive topics
- Cite sources or note when recommendations are based on standard guidelines
**SAFETY PROTOCOLS:**
- For medical emergencies: immediately advise seeking emergency care (911/ER)
- For urgent symptoms: recommend prompt evaluation by a healthcare provider
- When uncertain: acknowledge limitations and suggest consulting with a specialist
- Never discourage someone from seeking professional medical attention
Adapt your tone and detail level based on whether you're speaking with healthcare professionals or patients."""
# --- GLOBAL MODEL LOADING ---
# Runs once at import time so the whole app shares a single processor/model.
print(f"⏳ Loading processor for {MODEL_ID}...")
# use_fast=False selects the slow (Python) processor implementation.
processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=False)
print(f"⏳ Loading model components...")
try:
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_ID,
        dtype=torch.bfloat16,      # load weights in bfloat16
        device_map="auto",         # automatic device placement
        low_cpu_mem_usage=True
    )
    print("✅ Model loaded successfully!")
except Exception as e:
    # Surface the failure in the Space logs, then abort startup.
    print(f"❌ Failed to load model: {e}")
    raise e
# --- UTILITIES ---
def count_tokens(text):
    """Return the token count of *text* (no special tokens); 0 for empty/None."""
    if text:
        return len(processor.tokenizer.encode(text, add_special_tokens=False))
    return 0
def update_token_counter(clinical_info):
    """Build the live token-counter label for the clinical-info textbox.

    Args:
        clinical_info: Raw text currently in the "Clinical Information" box.

    Returns:
        A ``(counter_text, warning_text)`` tuple:
        - counter_text: colour-coded "<n> / <limit> tokens" status string
          (🟢 under 80%, 🟡 approaching, 🔴 over the limit).
        - warning_text: Markdown warning, or "" while comfortably under.
    """
    tokens = count_tokens(clinical_info)
    if tokens > MAX_CLINICAL_TOKENS:
        # Over the hard limit: generate_xray_report() will truncate this text.
        # (Fixed: this literal had a spurious f-string prefix with no placeholders.)
        return f"🔴 {tokens} / {MAX_CLINICAL_TOKENS} tokens", "⚠️ Text will be truncated!"
    if tokens > MAX_CLINICAL_TOKENS * 0.8:
        return f"🟡 {tokens} / {MAX_CLINICAL_TOKENS} tokens", "⚠️ Approaching token limit"
    return f"🟢 {tokens} / {MAX_CLINICAL_TOKENS} tokens", ""
# --- INFERENCE FUNCTIONS ---
def model_inference(messages, max_tokens=2048, temperature=0.4, do_sample=True):
    """Run one generation pass over a chat-formatted message list.

    NOTE: 'messages' must strictly follow the
    [{"role": "...", "content": [{"type": ...}]}] format expected by the
    processor's chat template (text and/or image parts).

    Args:
        messages: Full conversation history, including any system message.
        max_tokens: Maximum number of NEW tokens to generate.
        temperature: Sampling temperature; only applied when do_sample=True.
        do_sample: False selects greedy (deterministic) decoding.

    Returns:
        The decoded, whitespace-stripped model response (prompt removed).

    Raises:
        gr.Error: wraps any underlying failure so Gradio shows it in the UI.
    """
    try:
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(model.device, dtype=torch.bfloat16)
        # Prompt length, used to slice the echoed input off the output.
        input_len = inputs["input_ids"].shape[-1]
        # Configure generation args based on sampling mode
        gen_kwargs = {
            "max_new_tokens": max_tokens,
            "do_sample": do_sample,
        }
        # Sampling knobs are meaningless (and warned about) under greedy
        # decoding, so only add them when sampling is enabled.
        if do_sample:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = 0.9
            gen_kwargs["top_k"] = 50
        with torch.inference_mode():
            output = model.generate(
                **inputs,
                **gen_kwargs
            )
        generated_ids = output[0]
        decoded = processor.decode(generated_ids[input_len:], skip_special_tokens=True)
        return decoded.strip()
    except Exception as e:
        # FIX: chain the original exception ("from e") so the real cause
        # survives in tracebacks/logs instead of being swallowed.
        raise gr.Error(f"Generation failed: {str(e)}") from e
# --- X-RAY TAB LOGIC ---
def generate_xray_report(image, clinical_info, history_state):
    """Generate the initial radiology report for an uploaded X-ray.

    Starts a fresh conversation (system prompt + optional clinical context +
    image), runs greedy decoding for reproducible reports, and returns the
    chatbot rows, the new multimodal history, and the clinical text (so the
    textbox is not cleared).
    """
    if image is None:
        raise gr.Error("Please upload an X-ray image first.")

    # 1. Keep the clinical context within the token budget (hard truncation).
    if clinical_info:
        token_ids = processor.tokenizer.encode(clinical_info, add_special_tokens=False)
        if len(token_ids) > MAX_CLINICAL_TOKENS:
            clinical_info = processor.tokenizer.decode(token_ids[:MAX_CLINICAL_TOKENS])

    # 2. Assemble the first user turn: optional patient info, instruction, image.
    first_turn = []
    if clinical_info and clinical_info.strip():
        first_turn.append({"type": "text", "text": f"Patient info: {clinical_info}"})
    first_turn.append({"type": "text", "text": "Describe this X-ray image."})
    first_turn.append({"type": "image", "image": image})

    # 3. Fresh conversation: system prompt followed by the user turn.
    conversation = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT_XRAY}]},
        {"role": "user", "content": first_turn},
    ]

    # 4. Deterministic (greedy) decoding keeps clinical reports consistent.
    report = model_inference(conversation, max_tokens=1280, do_sample=False)

    # 5. Record the model's reply in the multimodal history.
    conversation.append({
        "role": "model",
        "content": [{"type": "text", "text": report}],
    })

    # 6. One chatbot row: no user bubble, just the generated report.
    return [[None, report]], conversation, clinical_info
def chat_about_xray(user_text, history_state, ui_history):
    """Answer a follow-up question about an already-generated X-ray report."""
    if not user_text.strip():
        # Nothing to send: leave chat/state untouched, clear the textbox.
        return ui_history, history_state, ""
    if not history_state:
        raise gr.Error("Please generate a report first.")

    # Record the user's question in the multimodal history.
    history_state.append(
        {"role": "user", "content": [{"type": "text", "text": user_text}]}
    )

    # Low-temperature sampling: conversational, but stays close to the report.
    answer = model_inference(
        history_state,
        max_tokens=1024,
        temperature=0.4,
        do_sample=True,
    )

    # Record the reply and mirror the exchange into the visible chatbot.
    history_state.append(
        {"role": "model", "content": [{"type": "text", "text": answer}]}
    )
    ui_history.append([user_text, answer])
    return ui_history, history_state, ""
# --- TEXT CHAT TAB LOGIC ---
def medical_chat(user_text, history_state, ui_history):
    """Handle one turn of the text-only medical-assistant conversation."""
    if not user_text.strip():
        return ui_history, history_state, ""

    # First turn: seed the conversation with the assistant system prompt.
    if not history_state:
        history_state = [
            {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT_CHAT}]}
        ]

    # Append the user's message to the running history.
    history_state.append(
        {"role": "user", "content": [{"type": "text", "text": user_text}]}
    )

    # Low-temperature sampling for measured, fact-oriented answers.
    reply = model_inference(
        history_state,
        max_tokens=1024,
        temperature=0.4,
        do_sample=True,
    )

    # Record the reply and mirror the exchange into the visible chatbot.
    history_state.append(
        {"role": "model", "content": [{"type": "text", "text": reply}]}
    )
    ui_history.append([user_text, reply])
    return ui_history, history_state, ""
# --- UI CONSTRUCTION ---
# Declarative Gradio layout: two tabs (X-ray analysis, text chat) plus a
# shared examples strip. Event handlers wire the components to the functions
# defined above.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🏥 MedGemma Medical AI")
    gr.Markdown("Powered by **Google MedGemma-4B** with Zero-GPU.")
    with gr.Tabs():
        # === TAB 1: X-RAY ANALYSIS ===
        with gr.TabItem("🩻 X-Ray Analysis"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Left column: image upload, clinical context, token meter.
                    xray_image = gr.Image(type="pil", label="Upload X-ray", height=300)
                    clinical_input = gr.Textbox(
                        lines=3,
                        placeholder="e.g. 65M, cough for 3 weeks...",
                        label="Clinical Information"
                    )
                    with gr.Row():
                        token_counter = gr.Textbox(value="0 / 256 tokens", show_label=False, interactive=False, container=False)
                        token_warning = gr.Markdown("")
                    generate_btn = gr.Button("🔬 Generate Report", variant="primary")
                with gr.Column(scale=2):
                    # Internal state holds the full multimodal history
                    xray_state = gr.State([])
                    xray_chatbot = gr.Chatbot(label="Radiology Report & Discussion", height=500, bubble_full_width=False)
                    with gr.Row():
                        xray_chat_input = gr.Textbox(
                            placeholder="Ask a follow-up question about the report...",
                            show_label=False,
                            scale=4
                        )
                        xray_send_btn = gr.Button("Send", scale=1)
            # Event Handlers
            # Live token meter updates as the clinical text changes.
            clinical_input.change(fn=update_token_counter, inputs=[clinical_input], outputs=[token_counter, token_warning])
            generate_btn.click(
                fn=generate_xray_report,
                inputs=[xray_image, clinical_input, xray_state],
                outputs=[xray_chatbot, xray_state, clinical_input]
            )
            # Enter key and Send button share the same follow-up handler.
            xray_chat_input.submit(
                fn=chat_about_xray,
                inputs=[xray_chat_input, xray_state, xray_chatbot],
                outputs=[xray_chatbot, xray_state, xray_chat_input]
            )
            xray_send_btn.click(
                fn=chat_about_xray,
                inputs=[xray_chat_input, xray_state, xray_chatbot],
                outputs=[xray_chatbot, xray_state, xray_chat_input]
            )
        # === TAB 2: MEDICAL ASSISTANT ===
        with gr.TabItem("💬 Medical Assistant"):
            gr.Markdown("Chat with a helpful medical assistant (Text only).")
            chat_state = gr.State([])
            chatbot = gr.Chatbot(height=500, bubble_full_width=False)
            with gr.Row():
                chat_input = gr.Textbox(placeholder="Type your medical question here...", show_label=False, scale=4)
                chat_send_btn = gr.Button("Send", scale=1)
            chat_input.submit(
                fn=medical_chat,
                inputs=[chat_input, chat_state, chatbot],
                outputs=[chatbot, chat_state, chat_input]
            )
            chat_send_btn.click(
                fn=medical_chat,
                inputs=[chat_input, chat_state, chatbot],
                outputs=[chatbot, chat_state, chat_input]
            )
    # --- EXAMPLES ---
    # Each row is [image path, clinical text]; files are expected to ship
    # with the Space repository — TODO confirm they exist at the repo root.
    examples = [
        ["pneumonia.jpg", "Patient presenting with high fever, cough, and shortness of breath."],
        ["normal-chest-xray.png", "Routine checkup for 30-year-old male, no symptoms."],
        ["distal-radius-fracture.jpg", "30m, trauma, injury, pain"],
        ["distal-fibula-fracture.jpg", "30m patient that got injured playing soccer, acute pain, can not walk. "]
    ]
    gr.Examples(
        examples=examples,
        inputs=[xray_image, clinical_input],
        label="Try an X-Ray Example"
    )
if __name__ == "__main__":
    # Start the Gradio server when the file is executed directly.
    demo.launch()