# CultriX's picture
# Update app.py
# 6c1d0b9 verified
import os
# Suppress tokenizer warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import gradio as gr
from transformers import AutoProcessor, AutoModelForImageTextToText
import torch
import spaces # Required for Zero-GPU
# --- CONFIGURATION ---
# Hub checkpoint used for both the processor and the model loaded below.
MODEL_ID: str = "google/medgemma-4b-it"
# Token budget for the clinical-history text sent alongside an X-ray; longer
# text is cut to this many tokenizer tokens before generation.
MAX_CLINICAL_TOKENS: int = 256
# --- PROMPTS ---
# System prompt for the X-ray tab. It fixes the report layout (the "###"
# section headers the model is told to emit) and instructs the model not to
# invent findings or measurements. Written for deterministic (greedy)
# decoding, which is how generate_xray_report calls the model (do_sample=False).
SYSTEM_PROMPT_XRAY = """You are an AI assistant specialized in radiological image interpretation. Your role is to provide a structured, professional analysis to assist qualified healthcare professionals.
**⚠️ CRITICAL DISCLAIMERS:**
- You are an AI, NOT a radiologist. This analysis is for **educational/decision-support only**.
- All findings must be verified by a qualified radiologist.
- **Anti-Hallucination Protocol:** Do NOT hallucinate findings to match the provided clinical history if they are not clearly visible. Do NOT invent specific measurements (e.g., "2cm") unless a scale is clearly visible.
**ANALYSIS APPROACH:**
Analyze the image systematically using standard radiological methodology:
1. **Image Technical Quality:** Assess view, positioning, exposure, and limitations.
2. **Systematic Review:**
- **Bones:** Cortex, medulla, alignment, fractures, lesions.
- **Soft Tissues/Organs:** Swelling, masses, calcifications, organ silhouettes.
- **Spaces/Joints:** Joint alignment, effusions, pneumothorax/air-fluid levels.
- **Support Devices:** Tubes, lines, hardware (if present).
3. **Clinical Integration:** specifically search for correlates to the provided history, but report **only** what is visible.
**OUTPUT FORMAT (Use Markdown `###` Headers):**
### 1. Technique & Quality
- View(s) obtained and technical limitations.
### 2. Findings
- Describe observations systematically by anatomical region.
- Report **both** abnormal and pertinent normal findings.
- Use precise anatomical terminology.
- **Support Devices:** (Location of tubes/lines if present).
### ⚠️ CRITICAL ALERTS (If Applicable)
- **Only** include this section for time-sensitive/life-threatening findings (e.g., Pneumothorax, Free Air).
### 3. Impression
- Concise summary of key findings.
- **Confidence Qualifier:** (e.g., "Findings are highly suggestive of...", "Probable...", "Cannot exclude...").
### 4. Differential Diagnosis
- List alternative considerations in order of likelihood.
- Briefly explain the reasoning (features that favor or argue against each).
### 5. Recommendations
- Follow-up imaging or clinical correlation.
- **Urgency:** (Stat, Urgent, or Routine).
- *Explicit Statement:* Must end with: "Clinical correlation is essential."
"""
# System prompt for the text-only "Medical Assistant" tab. medical_chat places
# it as the first (system) message whenever a new conversation starts.
SYSTEM_PROMPT_CHAT = """You are a knowledgeable medical assistant providing information and support to healthcare professionals and patients.
**YOUR CAPABILITIES:**
- Answer medical questions with evidence-based information
- Explain diagnoses, treatments, and procedures in clear language
- Help interpret medical terminology and reports
- Provide general health education and wellness guidance
- Assist with clinical decision support and differential diagnosis considerations
**IMPORTANT LIMITATIONS:**
- You do NOT provide definitive diagnoses or replace professional medical evaluation
- You cannot prescribe medications or create treatment plans
- Your knowledge has a cutoff date—always note when current information may have changed
- You do not have access to individual patient records or test results unless explicitly shared
**COMMUNICATION PRINCIPLES:**
- Use clear, accessible language—adjust complexity based on the user (clinician vs. patient)
- Provide evidence-based information with appropriate caveats about uncertainty
- Be empathetic and professional, especially when discussing sensitive topics
- Cite sources or note when recommendations are based on standard guidelines
**SAFETY PROTOCOLS:**
- For medical emergencies: immediately advise seeking emergency care (911/ER)
- For urgent symptoms: recommend prompt evaluation by a healthcare provider
- When uncertain: acknowledge limitations and suggest consulting with a specialist
- Never discourage someone from seeking professional medical attention
Adapt your tone and detail level based on whether you're speaking with healthcare professionals or patients."""
# --- GLOBAL MODEL LOADING ---
# The processor and model are loaded once at import time and shared by every
# handler below (count_tokens, model_inference, generate_xray_report, ...).
print(f"⏳ Loading processor for {MODEL_ID}...")
# use_fast=False requests the non-"fast" processor classes.
# NOTE(review): the reason for this choice is not stated in this file.
processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=False)
print("⏳ Loading model components...")
try:
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_ID,
        dtype=torch.bfloat16,  # model_inference casts its inputs to bfloat16 to match
        device_map="auto",  # automatic device placement; inputs follow model.device
        low_cpu_mem_usage=True,
    )
except Exception as e:
    print(f"❌ Failed to load model: {e}")
    # Bare ``raise`` re-raises the active exception unchanged.
    raise
else:
    print("✅ Model loaded successfully!")
# --- UTILITIES ---
def count_tokens(text):
    """Return how many tokenizer tokens ``text`` encodes to.

    Empty or ``None`` input counts as 0. No special tokens are added, so the
    count reflects only the text itself.
    """
    if text:
        token_ids = processor.tokenizer.encode(text, add_special_tokens=False)
        return len(token_ids)
    return 0
def update_token_counter(clinical_info):
    """Build the token-budget label and warning for the clinical-info textbox.

    Args:
        clinical_info: Current contents of the clinical-information textbox.

    Returns:
        ``(label, warning)``. ``label`` reads "<icon> <used> / <limit> tokens",
        with 🔴 when over MAX_CLINICAL_TOKENS, 🟡 when above 80% of it and 🟢
        otherwise. ``warning`` is the matching message, or "" when within budget.
    """
    used = count_tokens(clinical_info)
    limit = MAX_CLINICAL_TOKENS
    if used > limit:
        icon, warning = "🔴", "⚠️ Text will be truncated!"
    elif used > limit * 0.8:
        icon, warning = "🟡", "⚠️ Approaching token limit"
    else:
        icon, warning = "🟢", ""
    return f"{icon} {used} / {limit} tokens", warning
# --- INFERENCE FUNCTIONS ---
@spaces.GPU(duration=30)
def model_inference(messages, max_tokens=2048, temperature=0.4, do_sample=True, top_p=0.9, top_k=50):
    """Run one generation pass of the global ``model`` over a chat history.

    Args:
        messages: Chat history in the processor's multimodal format, i.e.
            ``[{"role": "...", "content": [{"type": ...}, ...]}, ...]``.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature; only used when ``do_sample`` is True.
        do_sample: If False, decode greedily and ignore all sampling settings.
        top_p: Nucleus-sampling threshold; only used when ``do_sample`` is True.
        top_k: Top-k cutoff; only used when ``do_sample`` is True.

    Returns:
        The newly generated text (prompt tokens excluded), with surrounding
        whitespace stripped.

    Raises:
        gr.Error: If templating, generation or decoding fails. The original
            exception is chained as ``__cause__`` so it is not lost.
    """
    try:
        # Tokenize the conversation and move it to the model's device; floating
        # point inputs are cast to bfloat16 to match the weights' dtype.
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device, dtype=torch.bfloat16)
        # Prompt length, used to slice the prompt off the generated sequence.
        input_len = inputs["input_ids"].shape[-1]

        gen_kwargs = {"max_new_tokens": max_tokens, "do_sample": do_sample}
        # Sampling parameters are only passed when sampling is enabled; they have
        # no meaning for greedy decoding.
        if do_sample:
            gen_kwargs.update(temperature=temperature, top_p=top_p, top_k=top_k)

        with torch.inference_mode():
            output = model.generate(**inputs, **gen_kwargs)

        # Single conversation -> take the first sequence and drop the prompt.
        new_tokens = output[0][input_len:]
        return processor.decode(new_tokens, skip_special_tokens=True).strip()
    except Exception as e:
        raise gr.Error(f"Generation failed: {e}") from e
# --- X-RAY TAB LOGIC ---
def _clip_clinical_info(text):
    """Cut ``text`` down to at most MAX_CLINICAL_TOKENS tokenizer tokens.

    Falsy input and text already within the budget are returned unchanged;
    longer text is re-decoded from its first MAX_CLINICAL_TOKENS token ids.
    """
    if not text:
        return text
    token_ids = processor.tokenizer.encode(text, add_special_tokens=False)
    if len(token_ids) <= MAX_CLINICAL_TOKENS:
        return text
    return processor.tokenizer.decode(token_ids[:MAX_CLINICAL_TOKENS])


def generate_xray_report(image, clinical_info, history_state):
    """Generate the initial structured report for an uploaded X-ray.

    Args:
        image: The uploaded image (PIL, per the ``gr.Image(type="pil")`` input),
            or None if nothing was uploaded.
        clinical_info: Optional free-text clinical history; clipped to
            MAX_CLINICAL_TOKENS tokens before being sent to the model.
        history_state: Previous X-ray conversation. Unused: a new report
            always starts a fresh conversation.

    Returns:
        ``(ui_history, messages, clinical_info)``: chatbot rows holding only
        the report, the multimodal message list used for follow-up questions,
        and the (possibly clipped) clinical text to write back to its textbox.

    Raises:
        gr.Error: If no image was uploaded, or if generation fails.
    """
    if image is None:
        raise gr.Error("Please upload an X-ray image first.")

    clinical_info = _clip_clinical_info(clinical_info)

    has_history = bool(clinical_info and clinical_info.strip())
    prompt_parts = (
        [{"type": "text", "text": f"Patient info: {clinical_info}"}] if has_history else []
    )
    prompt_parts += [
        {"type": "text", "text": "Describe this X-ray image."},
        {"type": "image", "image": image},
    ]
    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT_XRAY}]},
        {"role": "user", "content": prompt_parts},
    ]

    # Greedy decoding (do_sample=False) for consistent, grounded reports.
    report = model_inference(messages, max_tokens=1280, do_sample=False)
    messages.append({"role": "model", "content": [{"type": "text", "text": report}]})

    # The report is a bot-only first row; clinical_info is returned so the
    # textbox keeps its contents instead of being cleared.
    return [[None, report]], messages, clinical_info
def chat_about_xray(user_text, history_state, ui_history):
    """Answer a follow-up question about the current X-ray report.

    Args:
        user_text: The question typed by the user.
        history_state: Multimodal message list produced by
            ``generate_xray_report`` (system prompt, image turn, report) plus
            any earlier follow-up turns.
        ui_history: Chatbot rows as ``[user, bot]`` pairs.

    Returns:
        ``(ui_history, history_state, "")`` with the new exchange appended to
        both histories; the empty string clears the input box. Blank input
        returns the histories unchanged.

    Raises:
        gr.Error: If no report has been generated yet, or if generation fails.
    """
    if not user_text or not user_text.strip():
        return ui_history, history_state, ""
    if not history_state:
        raise gr.Error("Please generate a report first.")

    # Build a new list instead of appending to history_state in place.
    # history_state may be the stored gr.State object itself. If the question
    # were appended before model_inference raised, an unanswered user turn
    # would remain, and every later request would carry two consecutive user
    # turns, which alternating-role chat templates reject.
    messages = [*history_state, {"role": "user", "content": [{"type": "text", "text": user_text}]}]

    # Sampling at a low temperature (0.4) keeps explanations conversational
    # while staying close to the facts.
    response_text = model_inference(messages, max_tokens=1024, temperature=0.4, do_sample=True)

    messages.append({"role": "model", "content": [{"type": "text", "text": response_text}]})
    rows = [*(ui_history or []), [user_text, response_text]]
    return rows, messages, ""
# --- TEXT CHAT TAB LOGIC ---
def medical_chat(user_text, history_state, ui_history):
    """Handle one turn of the text-only medical assistant chat.

    Args:
        user_text: The message typed by the user.
        history_state: Message list from earlier turns. When empty, a new
            conversation is started from SYSTEM_PROMPT_CHAT.
        ui_history: Chatbot rows as ``[user, bot]`` pairs.

    Returns:
        ``(ui_history, history_state, "")`` with the new exchange appended to
        both histories; the empty string clears the input box. Blank input
        returns the histories unchanged.

    Raises:
        gr.Error: If generation fails (raised by ``model_inference``).
    """
    if not user_text or not user_text.strip():
        return ui_history, history_state, ""

    # Start a new conversation from the system prompt on the first turn.
    prior = history_state or [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT_CHAT}]}
    ]
    # Build a new list instead of appending to history_state in place: if
    # model_inference raised after an in-place append, the stored history
    # would keep an unanswered user turn, and later requests would contain two
    # consecutive user turns, which alternating-role chat templates reject.
    messages = [*prior, {"role": "user", "content": [{"type": "text", "text": user_text}]}]

    # Sampling enabled with temperature lowered to 0.4.
    response_text = model_inference(messages, max_tokens=1024, temperature=0.4, do_sample=True)

    messages.append({"role": "model", "content": [{"type": "text", "text": response_text}]})
    rows = [*(ui_history or []), [user_text, response_text]]
    return rows, messages, ""
# --- UI CONSTRUCTION ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🏥 MedGemma Medical AI")
    gr.Markdown("Powered by **Google MedGemma-4B** with Zero-GPU.")
    with gr.Tabs():
        # === TAB 1: X-RAY ANALYSIS ===
        with gr.TabItem("🩻 X-Ray Analysis"):
            with gr.Row():
                # Left column: image upload, clinical history, token budget.
                with gr.Column(scale=1):
                    xray_image = gr.Image(type="pil", label="Upload X-ray", height=300)
                    clinical_input = gr.Textbox(
                        lines=3,
                        placeholder="e.g. 65M, cough for 3 weeks...",
                        label="Clinical Information",
                    )
                    with gr.Row():
                        # Start with the same label update_token_counter shows
                        # for empty input, derived from MAX_CLINICAL_TOKENS
                        # rather than a hard-coded limit.
                        token_counter = gr.Textbox(
                            value=f"🟢 0 / {MAX_CLINICAL_TOKENS} tokens",
                            show_label=False,
                            interactive=False,
                            container=False,
                        )
                        token_warning = gr.Markdown("")
                    generate_btn = gr.Button("🔬 Generate Report", variant="primary")
                # Right column: report and follow-up discussion.
                with gr.Column(scale=2):
                    # Internal state holds the full multimodal history
                    xray_state = gr.State([])
                    xray_chatbot = gr.Chatbot(
                        label="Radiology Report & Discussion",
                        height=500,
                        bubble_full_width=False,
                    )
                    with gr.Row():
                        xray_chat_input = gr.Textbox(
                            placeholder="Ask a follow-up question about the report...",
                            show_label=False,
                            scale=4,
                        )
                        xray_send_btn = gr.Button("Send", scale=1)

            # Event Handlers
            clinical_input.change(
                fn=update_token_counter,
                inputs=[clinical_input],
                outputs=[token_counter, token_warning],
            )
            generate_btn.click(
                fn=generate_xray_report,
                inputs=[xray_image, clinical_input, xray_state],
                outputs=[xray_chatbot, xray_state, clinical_input],
            )
            # Pressing Enter and clicking "Send" run the same follow-up handler.
            for xray_trigger in (xray_chat_input.submit, xray_send_btn.click):
                xray_trigger(
                    fn=chat_about_xray,
                    inputs=[xray_chat_input, xray_state, xray_chatbot],
                    outputs=[xray_chatbot, xray_state, xray_chat_input],
                )

        # === TAB 2: MEDICAL ASSISTANT ===
        with gr.TabItem("💬 Medical Assistant"):
            gr.Markdown("Chat with a helpful medical assistant (Text only).")
            chat_state = gr.State([])
            chatbot = gr.Chatbot(height=500, bubble_full_width=False)
            with gr.Row():
                chat_input = gr.Textbox(
                    placeholder="Type your medical question here...",
                    show_label=False,
                    scale=4,
                )
                chat_send_btn = gr.Button("Send", scale=1)
            # Pressing Enter and clicking "Send" run the same chat handler.
            for chat_trigger in (chat_input.submit, chat_send_btn.click):
                chat_trigger(
                    fn=medical_chat,
                    inputs=[chat_input, chat_state, chatbot],
                    outputs=[chatbot, chat_state, chat_input],
                )

    # --- EXAMPLES ---
    # Each row fills (image file, clinical information) in the X-ray tab.
    examples = [
        ["pneumonia.jpg", "Patient presenting with high fever, cough, and shortness of breath."],
        ["normal-chest-xray.png", "Routine checkup for 30-year-old male, no symptoms."],
        ["distal-radius-fracture.jpg", "30m, trauma, injury, pain"],
        ["distal-fibula-fracture.jpg", "30m patient that got injured playing soccer, acute pain, can not walk. "],
    ]
    gr.Examples(
        examples=examples,
        inputs=[xray_image, clinical_input],
        label="Try an X-Ray Example",
    )
if __name__ == "__main__":
    # Start the Gradio app when this file is executed directly.
    demo.launch()