# SkinProAI/models/medgemma_agent.py
# (cgoodmaker, commit 4af4003: "Speed up CPU inference: halve token limits,
#  pre-download models, fix OMP threads")
"""
MedGemma Agent - LLM agent with tool calling and staged thinking feedback
Pipeline: MedGemma independent exam → Tools (MONET/ConvNeXt/GradCAM) → MedGemma reconciliation → Management
"""
import sys
import time
import random
import json
import os
import subprocess
import threading
from typing import Optional, Generator, Dict, Any
from PIL import Image
class MCPClient:
    """
    Minimal MCP client that communicates with a FastMCP subprocess over stdio.

    Uses raw newline-delimited JSON-RPC 2.0 so the main process (Python 3.9)
    does not need the mcp library. The subprocess is launched with python3.11
    which has mcp installed.
    """

    def __init__(self):
        self._process = None           # subprocess.Popen handle for the MCP server
        self._lock = threading.Lock()  # serializes requests over the single stdio pipe
        self._id_counter = 0           # monotonically increasing JSON-RPC request id

    def _next_id(self) -> int:
        """Return the next JSON-RPC request id (1, 2, 3, ...)."""
        self._id_counter += 1
        return self._id_counter

    def _send(self, obj: dict):
        """Write *obj* to the server as one newline-delimited JSON line and flush."""
        line = json.dumps(obj) + "\n"
        self._process.stdin.write(line)
        self._process.stdin.flush()

    def _recv(self, timeout: int = 300, expect_id: Optional[int] = None) -> dict:
        """
        Block until a JSON-RPC response arrives and return it as a dict.

        Server-initiated notifications (messages without an "id" key) are
        skipped. When ``expect_id`` is given, responses whose "id" does not
        match are skipped too, so a stale or out-of-order response cannot be
        mis-attributed to the current request — JSON-RPC 2.0 correlates
        responses to requests solely via "id".

        Raises RuntimeError on timeout, on server exit, or if the pipe closes.
        NOTE(review): select.select() on pipe objects is POSIX-only; this
        client does not support Windows.
        """
        import select
        deadline = time.time() + timeout
        while True:
            remaining = deadline - time.time()
            if remaining <= 0:
                raise RuntimeError(
                    f"MCP server did not respond within {timeout}s"
                )
            # Poll in <=5s slices so a dead subprocess is noticed promptly.
            ready, _, _ = select.select(
                [self._process.stdout], [], [], min(remaining, 5)
            )
            if not ready:
                # Check if subprocess died
                if self._process.poll() is not None:
                    raise RuntimeError(
                        f"MCP server exited with code {self._process.returncode}"
                    )
                continue
            line = self._process.stdout.readline()
            if not line:
                raise RuntimeError("MCP server closed connection unexpectedly")
            line = line.strip()
            if not line:
                continue
            msg = json.loads(line)
            # Skip server-initiated notifications (no "id" key) and, when an
            # expected id was supplied, responses for other requests.
            if "id" not in msg:
                continue
            if expect_id is not None and msg["id"] != expect_id:
                continue
            return msg

    def _initialize(self):
        """Send the MCP initialize handshake, then the 'initialized' notification."""
        req_id = self._next_id()
        self._send({
            "jsonrpc": "2.0",
            "id": req_id,
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {"name": "SkinProAI", "version": "1.0.0"},
            },
        })
        self._recv(expect_id=req_id)  # consume initialize response
        # Confirm initialization (a notification: carries no "id", gets no reply)
        self._send({
            "jsonrpc": "2.0",
            "method": "notifications/initialized",
            "params": {},
        })

    def start(self):
        """Spawn the MCP server subprocess and complete the handshake."""
        root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        server_script = os.path.join(root, "mcp_server", "server.py")
        # Force MCP tool models (MONET, ConvNeXt) onto CPU so they don't
        # compete with MedGemma for GPU VRAM (T4 has only 16 GB).
        env = os.environ.copy()
        env["SKINPRO_TOOL_DEVICE"] = "cpu"
        # stderr inherits parent's stderr (not PIPE) to avoid deadlock:
        # when the subprocess downloads models, progress output fills the
        # 64KB pipe buffer and blocks the subprocess while the main process
        # is blocked waiting on stdout — classic subprocess deadlock.
        self._process = subprocess.Popen(
            [sys.executable, server_script],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=None,  # inherit parent stderr
            text=True,
            bufsize=1,    # line-buffered: matches the newline-delimited protocol
            env=env,
        )
        self._initialize()

    def call_tool_sync(self, tool_name: str, arguments: dict) -> dict:
        """
        Call a tool synchronously and return the parsed result dict.

        Raises RuntimeError for both protocol-level errors (JSON-RPC "error"
        member) and tool-level errors (result carries isError=True).
        """
        with self._lock:
            req_id = self._next_id()
            self._send({
                "jsonrpc": "2.0",
                "id": req_id,
                "method": "tools/call",
                "params": {"name": tool_name, "arguments": arguments},
            })
            response = self._recv(expect_id=req_id)
        # Protocol-level error (e.g. unknown method)
        if "error" in response:
            raise RuntimeError(
                f"MCP tool '{tool_name}' failed: {response['error']}"
            )
        result = response["result"]
        content_text = result["content"][0]["text"]
        # Tool-level error (isError=True means the tool itself raised an exception)
        if result.get("isError"):
            raise RuntimeError(f"MCP tool '{tool_name}' error: {content_text}")
        return json.loads(content_text)

    def stop(self):
        """Terminate the MCP server subprocess (best-effort; never raises)."""
        if self._process:
            try:
                self._process.stdin.close()
                self._process.terminate()
                self._process.wait(timeout=5)
            except Exception:
                pass
            self._process = None
# Verbs rotated at random to animate the "thinking" spinner in the UI.
ANALYSIS_VERBS = [
    "Analyzing", "Examining", "Processing", "Inspecting", "Evaluating",
    "Scanning", "Assessing", "Reviewing", "Studying", "Interpreting",
]

# Single-shot visual exam prompt (merges what used to be 4 separate stages).
COMPREHENSIVE_EXAM_PROMPT = """Perform a systematic dermoscopic examination of this skin lesion. Assess ALL of the following in a SINGLE concise analysis:
1. PATTERN: Overall architecture, symmetry (symmetric/asymmetric), organization
2. COLORS: List all colors present (brown, black, blue, white, red, pink) and distribution
3. BORDER: Sharp vs gradual, regular vs irregular, any disruptions
4. STRUCTURES: Pigment network, dots/globules, streaks, blue-white veil, regression, vessels
Then provide:
- Top 3 differential diagnoses with brief reasoning
- Concern level (1-5, where 5=urgent)
- Single most important feature driving your assessment
Be CONCISE - focus on clinically relevant findings only."""


def get_verb():
    """Pick one analysis verb at random for the spinner display."""
    return ANALYSIS_VERBS[random.randrange(len(ANALYSIS_VERBS))]
class MedGemmaAgent:
    """
    Medical image analysis agent with:
    - Staged thinking display (no emojis)
    - Tool calling (MONET, ConvNeXt, Grad-CAM)
    - Streaming responses

    Pipeline (analyze_image_stream): MedGemma independent visual exam ->
    classification tools via the MCP subprocess -> reconciliation ->
    management guidance. All streaming methods are generators that yield
    text chunks wrapped in bracket markers ([STAGE:...], [THINKING],
    [RESPONSE], [ERROR], ...) that the UI layer parses.
    """

    def __init__(self, verbose: bool = True):
        self.verbose = verbose
        self.pipe = None  # transformers "image-text-to-text" pipeline, set by load_model()
        self.model_id = "google/medgemma-4b-it"
        self.loaded = False
        # Tools (legacy direct instances, kept for fallback / non-MCP use)
        self.monet_tool = None
        self.convnext_tool = None
        self.gradcam_tool = None
        self.rag_tool = None
        self.tools_loaded = False
        # MCP client
        self.mcp_client = None
        # State for confirmation flow
        self.last_diagnosis = None
        self.last_monet_result = None
        self.last_image = None
        self.last_medgemma_exam = None  # Store independent MedGemma findings
        self.last_reconciliation = None
        # NOTE(review): last_management_response is created lazily by
        # generate_management_guidance(); extract_recommendation() reads it,
        # so calling that first raises AttributeError rather than the
        # intended "[ERROR]..." message — confirm intended.

    def reset_state(self):
        """Reset analysis state for new analysis (keeps models loaded)"""
        self.last_diagnosis = None
        self.last_monet_result = None
        self.last_image = None
        self.last_medgemma_exam = None
        self.last_reconciliation = None

    def _print(self, message: str):
        """Print if verbose (flushes so progress shows through buffered stdout)."""
        if self.verbose:
            print(message, flush=True)

    def load_model(self):
        """Load MedGemma model (idempotent; heavy imports deferred to first call)."""
        if self.loaded:
            return
        self._print("Initializing MedGemma agent...")
        import os  # NOTE(review): shadows the module-level os import; harmless
        import torch
        from transformers import pipeline, AutoProcessor
        # Authenticate with HF Hub (required for gated models like MedGemma)
        hf_token = os.environ.get("HF_TOKEN")
        if hf_token:
            from huggingface_hub import login
            login(token=hf_token, add_to_git_credential=False)
            self._print("Authenticated with HF Hub")
        else:
            self._print("Warning: HF_TOKEN not set — gated models will fail")
        self._print(f"Loading model: {self.model_id}")
        if torch.cuda.is_available():
            device = "cuda"
            self._print(f"Using GPU: {torch.cuda.get_device_name(0)}")
        elif torch.backends.mps.is_available():
            device = "mps"
            self._print("Using Apple Silicon (MPS)")
        else:
            device = "cpu"
            self._print("Using CPU")
        # NOTE(review): `device` is only used for logging; actual placement is
        # delegated to device_map="auto" below.
        model_kwargs = dict(
            dtype=torch.bfloat16,  # bfloat16 on all devices (halves RAM on CPU: ~8GB vs ~16GB)
            device_map="auto",
        )
        start = time.time()
        processor = AutoProcessor.from_pretrained(
            self.model_id, use_fast=True, token=hf_token,
        )
        self.pipe = pipeline(
            "image-text-to-text",
            model=self.model_id,
            model_kwargs=model_kwargs,
            token=hf_token,  # pass explicitly in addition to login()
            image_processor=processor.image_processor,
            tokenizer=processor.tokenizer,
        )
        # Clear default max_length (20) from generation_config to avoid conflict
        # with max_new_tokens passed at inference time
        if hasattr(self.pipe.model, "generation_config"):
            self.pipe.model.generation_config.max_length = None
        self._print(f"Model loaded in {time.time() - start:.1f}s")
        self.loaded = True

    def load_tools(self):
        """Load tool models (MONET + ConvNeXt + Grad-CAM + RAG) in-process.

        Legacy direct-instance path; the MCP route is load_tools_via_mcp().
        Idempotent via the tools_loaded flag.
        """
        if self.tools_loaded:
            return
        from models.monet_tool import MonetTool
        self.monet_tool = MonetTool()
        self.monet_tool.load()
        from models.convnext_classifier import ConvNeXtClassifier
        self.convnext_tool = ConvNeXtClassifier()
        self.convnext_tool.load()
        from models.gradcam_tool import GradCAMTool
        # Grad-CAM wraps the already-loaded classifier rather than loading its own.
        self.gradcam_tool = GradCAMTool(classifier=self.convnext_tool)
        self.gradcam_tool.load()
        from models.guidelines_rag import get_guidelines_rag
        self.rag_tool = get_guidelines_rag()
        if not self.rag_tool.loaded:
            self.rag_tool.load_index()
        self.tools_loaded = True

    def load_tools_via_mcp(self):
        """Start the MCP server subprocess and mark tools as loaded."""
        if self.tools_loaded:
            return
        self.mcp_client = MCPClient()
        self.mcp_client.start()
        self._print("MCP server started successfully")
        self.tools_loaded = True

    def _multi_pass_visual_exam(self, image, question: Optional[str] = None) -> Generator[str, None, Dict[str, str]]:
        """
        MedGemma performs comprehensive visual examination BEFORE tools run.
        Single prompt covers pattern, colors, borders, structures, and differentials.
        Returns findings dict after yielding all output.

        The findings are also stored in self.last_medgemma_exam, which is how
        analyze_image_stream retrieves them (generator return values are
        awkward to capture from a plain for-loop).
        """
        findings = {}
        yield f"\n[STAGE:medgemma_exam]MedGemma Visual Examination[/STAGE]\n"
        yield f"[THINKING]Performing systematic dermoscopic assessment...[/THINKING]\n"
        # Build prompt with optional clinical question
        exam_prompt = COMPREHENSIVE_EXAM_PROMPT
        if question:
            exam_prompt += f"\n\nCLINICAL QUESTION: {question}"
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": exam_prompt}
                ]
            }
        ]
        try:
            time.sleep(0.2)
            output = self.pipe(messages, max_new_tokens=200)
            result = output[0]["generated_text"][-1]["content"]
            findings['synthesis'] = result
            yield f"[RESPONSE]\n"
            # Stream word-by-word for a typewriter effect in the UI.
            words = result.split()
            for i, word in enumerate(words):
                time.sleep(0.015)
                yield word + (" " if i < len(words) - 1 else "")
            yield f"\n[/RESPONSE]\n"
        except Exception as e:
            findings['synthesis'] = f"Analysis failed: {e}"
            yield f"[ERROR]Visual examination failed: {e}[/ERROR]\n"
        self.last_medgemma_exam = findings
        return findings

    def _reconcile_findings(
        self,
        image,
        medgemma_exam: Dict[str, str],
        monet_result: Dict[str, Any],
        convnext_result: Dict[str, Any],
        question: Optional[str] = None
    ) -> Generator[str, None, None]:
        """
        MedGemma reconciles its independent findings with tool outputs.
        Identifies agreements, disagreements, and provides integrated assessment.

        Stores the raw reconciliation text in self.last_reconciliation.
        """
        yield f"\n[STAGE:reconciliation]Reconciling MedGemma Findings with Tool Results[/STAGE]\n"
        yield f"[THINKING]Comparing independent visual assessment against AI classification tools...[/THINKING]\n"
        top = convnext_result['predictions'][0]
        runner_up = convnext_result['predictions'][1] if len(convnext_result['predictions']) > 1 else None
        # Build MONET features string: top 5 features by score, human-readable.
        monet_top = sorted(monet_result["features"].items(), key=lambda x: x[1], reverse=True)[:5]
        monet_str = ", ".join([f"{k.replace('MONET_', '').replace('_', ' ')}: {v:.0%}" for k, v in monet_top])
        # Prior exam text is truncated to 600 chars to bound prompt length.
        reconciliation_prompt = f"""You performed an independent visual examination of this lesion and concluded:
YOUR ASSESSMENT:
{medgemma_exam.get('synthesis', 'Not available')[:600]}
The AI classification tools produced these results:
- ConvNeXt classifier: {top['full_name']} ({top['probability']:.1%} confidence)
{f"- Runner-up: {runner_up['full_name']} ({runner_up['probability']:.1%})" if runner_up else ""}
- Key MONET features: {monet_str}
{f'CLINICAL QUESTION: {question}' if question else ''}
Reconcile your visual findings with the AI classification:
1. AGREEMENT/DISAGREEMENT: Do your findings support the AI diagnosis? Any conflicts?
2. INTEGRATED ASSESSMENT: Final diagnosis considering all evidence
3. CONFIDENCE (1-10): How certain? What would change your assessment?
Be concise and specific."""
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": reconciliation_prompt}
                ]
            }
        ]
        try:
            output = self.pipe(messages, max_new_tokens=150)
            reconciliation = output[0]["generated_text"][-1]["content"]
            self.last_reconciliation = reconciliation
            yield f"[RESPONSE]\n"
            # Typewriter-style streaming, matching _multi_pass_visual_exam.
            words = reconciliation.split()
            for i, word in enumerate(words):
                time.sleep(0.015)
                yield word + (" " if i < len(words) - 1 else "")
            yield f"\n[/RESPONSE]\n"
        except Exception as e:
            yield f"[ERROR]Reconciliation failed: {e}[/ERROR]\n"

    def analyze_image_stream(
        self,
        image_path: str,
        question: Optional[str] = None,
        max_tokens: int = 512,
        use_tools: bool = True
    ) -> Generator[str, None, None]:
        """
        Stream analysis with new pipeline:
        1. MedGemma independent multi-pass exam
        2. MONET + ConvNeXt + GradCAM tools
        3. MedGemma reconciliation
        4. Confirmation request

        NOTE(review): max_tokens is currently unused by this method — the
        per-stage calls hard-code their own max_new_tokens; confirm intended.
        """
        if not self.loaded:
            yield "[STAGE:loading]Initializing MedGemma...[/STAGE]\n"
            self.load_model()
        yield f"[STAGE:image]{get_verb()} image...[/STAGE]\n"
        try:
            image = Image.open(image_path).convert("RGB")
            self.last_image = image
        except Exception as e:
            yield f"[ERROR]Failed to load image: {e}[/ERROR]\n"
            return
        # Load tools early via MCP subprocess
        if use_tools and not self.tools_loaded:
            yield f"[STAGE:tools]Loading analysis tools...[/STAGE]\n"
            self.load_tools_via_mcp()
        # ===== PHASE 1: MedGemma Independent Visual Examination =====
        medgemma_exam = {}
        for chunk in self._multi_pass_visual_exam(image, question):
            yield chunk
            # NOTE(review): dead branch — the exam generator only yields
            # strings; the real findings come from last_medgemma_exam below.
            if isinstance(chunk, dict):
                medgemma_exam = chunk
        medgemma_exam = self.last_medgemma_exam or {}
        monet_result = None
        convnext_result = None
        if use_tools:
            # ===== PHASE 2: Run Classification Tools =====
            yield f"\n[STAGE:tools_run]Running AI Classification Tools[/STAGE]\n"
            yield f"[THINKING]Now running MONET and ConvNeXt to compare against visual examination...[/THINKING]\n"
            # MONET Feature Extraction
            time.sleep(0.2)
            yield f"\n[STAGE:monet]MONET Feature Extraction[/STAGE]\n"
            try:
                monet_result = self.mcp_client.call_tool_sync(
                    "monet_analyze", {"image_path": image_path}
                )
                self.last_monet_result = monet_result
                yield f"[TOOL_OUTPUT:MONET Features]\n"
                # Render each feature score as a 10-segment ASCII bar.
                for name, score in monet_result["features"].items():
                    short_name = name.replace("MONET_", "").replace("_", " ").title()
                    bar_filled = int(score * 10)
                    bar = "|" + "=" * bar_filled + "-" * (10 - bar_filled) + "|"
                    yield f"  {short_name}: {bar} {score:.0%}\n"
                yield f"[/TOOL_OUTPUT]\n"
            except Exception as e:
                yield f"[ERROR]MONET failed: {e}[/ERROR]\n"
            # ConvNeXt Classification
            time.sleep(0.2)
            yield f"\n[STAGE:convnext]ConvNeXt Classification[/STAGE]\n"
            try:
                # MONET scores (if available) are fed in as auxiliary input.
                monet_scores = monet_result["vector"] if monet_result else None
                convnext_result = self.mcp_client.call_tool_sync(
                    "classify_lesion",
                    {
                        "image_path": image_path,
                        "monet_scores": monet_scores,
                    },
                )
                self.last_diagnosis = convnext_result
                yield f"[TOOL_OUTPUT:Classification Results]\n"
                # Top-5 predictions as 20-segment probability bars.
                for pred in convnext_result["predictions"][:5]:
                    prob = pred['probability']
                    bar_filled = int(prob * 20)
                    bar = "|" + "=" * bar_filled + "-" * (20 - bar_filled) + "|"
                    yield f"  {pred['class']}: {bar} {prob:.1%}\n"
                    yield f"    {pred['full_name']}\n"
                yield f"[/TOOL_OUTPUT]\n"
                top = convnext_result['predictions'][0]
                yield f"[RESULT]ConvNeXt Primary: {top['full_name']} ({top['probability']:.1%})[/RESULT]\n"
            except Exception as e:
                yield f"[ERROR]ConvNeXt failed: {e}[/ERROR]\n"
            # Grad-CAM Visualization
            time.sleep(0.2)
            yield f"\n[STAGE:gradcam]Grad-CAM Attention Map[/STAGE]\n"
            try:
                gradcam_result = self.mcp_client.call_tool_sync(
                    "generate_gradcam", {"image_path": image_path}
                )
                gradcam_path = gradcam_result["gradcam_path"]
                yield f"[GRADCAM_IMAGE:{gradcam_path}]\n"
            except Exception as e:
                yield f"[ERROR]Grad-CAM failed: {e}[/ERROR]\n"
        # ===== PHASE 3: MedGemma Reconciliation =====
        # Only runs when both tools succeeded AND the exam produced findings.
        if convnext_result and monet_result and medgemma_exam:
            for chunk in self._reconcile_findings(
                image, medgemma_exam, monet_result, convnext_result, question
            ):
                yield chunk
        # ===== PHASE 4: Auto Management Guidance (with RAG) =====
        if convnext_result and self.mcp_client:
            for chunk in self.generate_management_guidance():
                yield chunk

    def generate_management_guidance(
        self,
        user_confirmed: bool = True,
        user_feedback: Optional[str] = None
    ) -> Generator[str, None, None]:
        """
        Generate LESION-SPECIFIC management guidance using RAG + MedGemma reasoning.
        References specific findings from this analysis, not generic textbook management.

        When the clinician rejects the AI diagnosis (user_confirmed=False with
        user_feedback set), the feedback text replaces the diagnosis for the
        guideline search and the management prompt.
        """
        if not self.last_diagnosis:
            yield "[ERROR]No diagnosis available. Please analyze an image first.[/ERROR]\n"
            return
        top = self.last_diagnosis['predictions'][0]
        runner_up = self.last_diagnosis['predictions'][1] if len(self.last_diagnosis['predictions']) > 1 else None
        diagnosis = top['full_name']
        if not user_confirmed and user_feedback:
            yield f"[THINKING]Clinician provided alternative assessment: {user_feedback}[/THINKING]\n"
            diagnosis = user_feedback
        # Stage: RAG Search
        time.sleep(0.3)
        yield f"\n[STAGE:guidelines]Searching clinical guidelines for {diagnosis}...[/STAGE]\n"
        # Get RAG context via MCP
        features_desc = self.last_monet_result.get('description', '') if self.last_monet_result else ''
        rag_data = self.mcp_client.call_tool_sync(
            "search_guidelines",
            {"query": features_desc, "diagnosis": diagnosis},
        )
        context = rag_data["context"]
        references = rag_data["references"]
        # Check guideline relevance: accept a reference if its source matches
        # the diagnosis (melanoma/actinic special cases) or scores high enough.
        has_relevant_guidelines = False
        if references:
            diagnosis_lower = diagnosis.lower()
            for ref in references:
                source_lower = ref['source'].lower()
                # NOTE(review): any(... for term in ['melanoma']) is a
                # single-element check, equivalent to 'melanoma' in diagnosis_lower.
                if any(term in diagnosis_lower for term in ['melanoma']) and 'melanoma' in source_lower:
                    has_relevant_guidelines = True
                    break
                elif 'actinic' in diagnosis_lower and 'actinic' in source_lower:
                    has_relevant_guidelines = True
                    break
                elif ref.get('score', 0) > 0.7:
                    has_relevant_guidelines = True
                    break
        if not references or not has_relevant_guidelines:
            yield f"[THINKING]No specific published guidelines for {diagnosis}. Using clinical knowledge.[/THINKING]\n"
            context = "No specific clinical guidelines available."
            references = []
        # Build MONET features for context (top 5 by score, human-readable)
        monet_features = ""
        if self.last_monet_result:
            top_features = sorted(self.last_monet_result["features"].items(), key=lambda x: x[1], reverse=True)[:5]
            monet_features = ", ".join([f"{k.replace('MONET_', '').replace('_', ' ')}: {v:.0%}" for k, v in top_features])
        # Stage: Lesion-Specific Management Reasoning
        time.sleep(0.3)
        yield f"\n[STAGE:management]Generating Lesion-Specific Management Plan[/STAGE]\n"
        yield f"[THINKING]Creating management plan tailored to THIS lesion's specific characteristics...[/THINKING]\n"
        # Guideline context truncated to 800 chars to bound prompt length.
        management_prompt = f"""Generate a CONCISE management plan for this lesion:
DIAGNOSIS: {diagnosis} ({top['probability']:.1%})
{f"Alternative: {runner_up['full_name']} ({runner_up['probability']:.1%})" if runner_up else ""}
KEY FEATURES: {monet_features}
{f"GUIDELINES: {context[:800]}" if context else ""}
Provide:
1. RECOMMENDED ACTION: Biopsy, excision, monitoring, or discharge - with specific reasoning
2. URGENCY: Routine vs urgent vs same-day referral
3. KEY CONCERNS: What features drive this recommendation
Be specific to THIS lesion. 3-5 sentences maximum."""
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": self.last_image},
                    {"type": "text", "text": management_prompt}
                ]
            }
        ]
        # Generate response
        start = time.time()
        try:
            output = self.pipe(messages, max_new_tokens=150)
            response = output[0]["generated_text"][-1]["content"]
            yield f"[RESPONSE]\n"
            words = response.split()
            for i, word in enumerate(words):
                time.sleep(0.015)
                yield word + (" " if i < len(words) - 1 else "")
            yield f"\n[/RESPONSE]\n"
        except Exception as e:
            yield f"[ERROR]Management generation failed: {e}[/ERROR]\n"
        # Output references (pre-formatted by MCP server)
        if references:
            yield rag_data["references_display"]
        yield f"\n[COMPLETE]Lesion-specific management plan generated in {time.time() - start:.1f}s[/COMPLETE]\n"
        # Store response for recommendation extraction
        # NOTE(review): if self.pipe raised above, `response` was never bound
        # and this line raises NameError after yielding the [ERROR] chunk.
        self.last_management_response = response

    def extract_recommendation(self) -> Generator[str, None, Dict[str, Any]]:
        """
        Extract structured recommendation from management guidance.
        Determines: BIOPSY, EXCISION, FOLLOWUP, or DISCHARGE
        For BIOPSY/EXCISION, gets coordinates from MedGemma.

        The structured dict is the generator's *return* value (arrives via
        StopIteration.value for callers that drive the generator manually).
        """
        if not self.last_management_response or not self.last_image:
            yield "[ERROR]No management guidance available[/ERROR]\n"
            return {"action": "UNKNOWN"}
        yield f"\n[STAGE:recommendation]Extracting Clinical Recommendation[/STAGE]\n"
        # Ask MedGemma to classify the recommendation
        classification_prompt = f"""Based on the management plan you just provided:
{self.last_management_response[:1000]}
Classify the PRIMARY recommended action into exactly ONE of these categories:
- BIOPSY: If punch biopsy, shave biopsy, or incisional biopsy is recommended
- EXCISION: If complete surgical excision is recommended
- FOLLOWUP: If monitoring with repeat photography/dermoscopy is recommended
- DISCHARGE: If the lesion is clearly benign and no follow-up needed
Respond with ONLY the category name (BIOPSY, EXCISION, FOLLOWUP, or DISCHARGE) on the first line.
Then on the second line, provide a brief (1 sentence) justification."""
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": self.last_image},
                    {"type": "text", "text": classification_prompt}
                ]
            }
        ]
        try:
            output = self.pipe(messages, max_new_tokens=100)
            response = output[0]["generated_text"][-1]["content"].strip()
            lines = response.split('\n')
            action = lines[0].strip().upper()
            justification = lines[1].strip() if len(lines) > 1 else ""
            # Validate action
            valid_actions = ["BIOPSY", "EXCISION", "FOLLOWUP", "DISCHARGE"]
            if action not in valid_actions:
                # Try to extract from response
                for valid in valid_actions:
                    if valid in response.upper():
                        action = valid
                        break
                else:
                    action = "FOLLOWUP"  # Default to safe option
            yield f"[RESULT]Recommended Action: {action}[/RESULT]\n"
            yield f"[OBSERVATION]{justification}[/OBSERVATION]\n"
            result = {
                "action": action,
                "justification": justification
            }
            return result
        except Exception as e:
            yield f"[ERROR]Failed to extract recommendation: {e}[/ERROR]\n"
            return {"action": "UNKNOWN", "error": str(e)}

    def compare_followup_images(
        self,
        previous_image_path: str,
        current_image_path: str
    ) -> Generator[str, None, None]:
        """
        Compare a follow-up image with the previous one.
        Runs full analysis pipeline on current image, then compares findings.
        """
        yield f"\n[STAGE:comparison]Follow-up Comparison Analysis[/STAGE]\n"
        try:
            current_image = Image.open(current_image_path).convert("RGB")
        except Exception as e:
            yield f"[ERROR]Failed to load images: {e}[/ERROR]\n"
            return
        # Store previous analysis state
        prev_exam = self.last_medgemma_exam
        # Generate comparison image and MONET deltas via MCP
        yield f"\n[STAGE:current_analysis]Analyzing Current Image[/STAGE]\n"
        if self.tools_loaded:
            try:
                compare_data = self.mcp_client.call_tool_sync(
                    "compare_images",
                    {
                        "image1_path": previous_image_path,
                        "image2_path": current_image_path,
                    },
                )
                yield f"[COMPARISON_IMAGE:{compare_data['comparison_path']}]\n"
                # Side-by-side GradCAM comparison if both paths available
                prev_gc = compare_data.get("prev_gradcam_path")
                curr_gc = compare_data.get("curr_gradcam_path")
                if prev_gc and curr_gc:
                    yield f"[GRADCAM_COMPARE:{prev_gc}:{curr_gc}]\n"
                # Display MONET feature deltas
                if compare_data["monet_deltas"]:
                    yield f"[TOOL_OUTPUT:Feature Comparison]\n"
                    for name, delta_info in compare_data["monet_deltas"].items():
                        prev_val = delta_info["previous"]
                        curr_val = delta_info["current"]
                        diff = delta_info["delta"]
                        short_name = name.replace("MONET_", "").replace("_", " ").title()
                        # NOTE(review): a zero delta renders as "↓0%".
                        direction = "↑" if diff > 0 else "↓"
                        yield f"  {short_name}: {prev_val:.0%} → {curr_val:.0%} ({direction}{abs(diff):.0%})\n"
                    yield f"[/TOOL_OUTPUT]\n"
            except Exception as e:
                yield f"[ERROR]MCP comparison failed: {e}[/ERROR]\n"
        # MedGemma comparison analysis (model only sees the CURRENT image;
        # the previous exam is provided as text, truncated to 500 chars).
        comparison_prompt = f"""You are comparing TWO images of the same skin lesion taken at different times.
PREVIOUS ANALYSIS:
{prev_exam.get('synthesis', 'Not available')[:500] if prev_exam else 'Not available'}
Now examine the CURRENT image and compare to your memory of the previous findings.
Assess for changes in:
1. SIZE: Has the lesion grown, shrunk, or stayed the same?
2. COLOR: Any new colors appeared? Any colors faded?
3. SHAPE/SYMMETRY: Has the shape changed? More or less symmetric?
4. BORDERS: Sharper, more irregular, or unchanged?
5. STRUCTURES: New dermoscopic structures? Lost structures?
Provide your assessment:
- CHANGE_LEVEL: SIGNIFICANT_CHANGE / MINOR_CHANGE / STABLE / IMPROVED
- Specific changes observed
- Clinical recommendation based on changes"""
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": current_image},
                    {"type": "text", "text": comparison_prompt}
                ]
            }
        ]
        try:
            yield f"[THINKING]Comparing current image to previous findings...[/THINKING]\n"
            output = self.pipe(messages, max_new_tokens=200)
            comparison_result = output[0]["generated_text"][-1]["content"]
            yield f"[RESPONSE]\n"
            words = comparison_result.split()
            for i, word in enumerate(words):
                time.sleep(0.02)
                yield word + (" " if i < len(words) - 1 else "")
            yield f"\n[/RESPONSE]\n"
            # Extract change level: first listed level found in the response
            # wins (checked in severity order).
            change_level = "UNKNOWN"
            for level in ["SIGNIFICANT_CHANGE", "MINOR_CHANGE", "STABLE", "IMPROVED"]:
                if level in comparison_result.upper():
                    change_level = level
                    break
            if change_level == "SIGNIFICANT_CHANGE":
                yield f"[RESULT]⚠️ SIGNIFICANT CHANGES DETECTED - Further evaluation recommended[/RESULT]\n"
            elif change_level == "IMPROVED":
                yield f"[RESULT]✓ LESION IMPROVED - Continue monitoring[/RESULT]\n"
            elif change_level == "STABLE":
                yield f"[RESULT]✓ LESION STABLE - Continue scheduled follow-up[/RESULT]\n"
            else:
                yield f"[RESULT]Minor changes noted - Clinical correlation recommended[/RESULT]\n"
        except Exception as e:
            yield f"[ERROR]Comparison analysis failed: {e}[/ERROR]\n"
        yield f"\n[COMPLETE]Follow-up comparison complete[/COMPLETE]\n"

    def chat(self, message: str, image_path: Optional[str] = None) -> str:
        """Simple chat interface: one user turn, optional image, returns model text."""
        if not self.loaded:
            self.load_model()
        content = []
        if image_path:
            image = Image.open(image_path).convert("RGB")
            content.append({"type": "image", "image": image})
        content.append({"type": "text", "text": message})
        messages = [{"role": "user", "content": content}]
        output = self.pipe(messages, max_new_tokens=200)
        return output[0]["generated_text"][-1]["content"]

    def chat_followup(self, message: str) -> Generator[str, None, None]:
        """
        Handle follow-up questions using the stored analysis context.
        Uses the last analyzed image and diagnosis to provide contextual responses.
        """
        if not self.loaded:
            yield "[ERROR]Model not loaded[/ERROR]\n"
            return
        if not self.last_diagnosis or not self.last_image:
            yield "[ERROR]No previous analysis context. Please analyze an image first.[/ERROR]\n"
            return
        # Build context from previous analysis
        top_diagnosis = self.last_diagnosis['predictions'][0]
        differentials = ", ".join([
            f"{p['class']} ({p['probability']:.0%})"
            for p in self.last_diagnosis['predictions'][:3]
        ])
        monet_desc = ""
        if self.last_monet_result:
            monet_desc = self.last_monet_result.get('description', '')
        context_prompt = f"""You are a dermatology assistant helping with skin lesion analysis.
PREVIOUS ANALYSIS CONTEXT:
- Primary diagnosis: {top_diagnosis['full_name']} ({top_diagnosis['probability']:.1%} confidence)
- Differential diagnoses: {differentials}
- Visual features: {monet_desc}
The user has a follow-up question about this lesion. Please provide a helpful, medically accurate response.
USER QUESTION: {message}
Provide a concise, informative response. If the question is outside your expertise or requires in-person examination, say so."""
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": self.last_image},
                    {"type": "text", "text": context_prompt}
                ]
            }
        ]
        try:
            yield f"[THINKING]Considering your question in context of the previous analysis...[/THINKING]\n"
            time.sleep(0.2)
            output = self.pipe(messages, max_new_tokens=200)
            response = output[0]["generated_text"][-1]["content"]
            yield f"[RESPONSE]\n"
            # Stream word by word for typewriter effect
            words = response.split()
            for i, word in enumerate(words):
                time.sleep(0.02)
                yield word + (" " if i < len(words) - 1 else "")
            yield f"\n[/RESPONSE]\n"
        except Exception as e:
            yield f"[ERROR]Failed to generate response: {e}[/ERROR]\n"
def main():
    """Run the interactive terminal REPL: analyze images or chat with the agent."""
    banner = "=" * 60
    print(banner)
    print(" MedGemma Agent - Medical Image Analysis")
    print(banner)

    agent = MedGemmaAgent(verbose=True)
    agent.load_model()
    print("\nCommands: analyze <path>, chat <message>, quit")

    while True:
        # One try per iteration: any error (including during input) is
        # reported and the loop continues; Ctrl-C exits cleanly.
        try:
            line = input("\n> ").strip()
            if not line:
                continue
            if line.lower() in ("quit", "exit", "q"):
                break
            cmd, *rest = line.split(maxsplit=1)
            cmd = cmd.lower()
            if cmd == "analyze" and rest:
                # Stream the analysis chunks straight to the terminal.
                for chunk in agent.analyze_image_stream(rest[0].strip()):
                    print(chunk, end="", flush=True)
            elif cmd == "chat" and rest:
                print(agent.chat(rest[0]))
            else:
                print("Unknown command")
        except KeyboardInterrupt:
            break
        except Exception as e:
            print(f"Error: {e}")


if __name__ == "__main__":
    main()