# radioflow / agents/report_generator.py
# Uploaded by SamarpeetGarad with huggingface_hub (commit 3a161e0, verified).
"""
Agent 3: Report Generator
Uses MedGemma to generate structured radiology reports
"""
import re
import time
from datetime import datetime
from typing import Any, Dict, Optional, List

from .base_agent import BaseAgent, AgentResult

# Import the unified MedGemma engine
try:
    from .medgemma_engine import get_engine, MedGemmaEngine
    ENGINE_AVAILABLE = True
except ImportError:
    ENGINE_AVAILABLE = False
class ReportGeneratorAgent(BaseAgent):
    """
    Agent 3: MedGemma Report Generator

    Generates structured radiology reports from interpreted findings
    using the unified MedGemma engine. When the engine is missing, not
    loaded, running its "demo" backend, or fails/returns nothing, a
    template-based report is assembled from the findings instead.
    """

    # Horizontal rule framing the plain-text report.
    _RULE = "================================================================================"

    # FINDINGS text used when there are no interpreted findings at all.
    _NORMAL_FINDINGS = (
        "LUNGS: Clear bilaterally. No focal consolidation or pleural effusion. Lungs are well-expanded.\n"
        "HEART: Normal cardiac silhouette. Cardiothoracic ratio within normal limits.\n"
        "MEDIASTINUM: Unremarkable. No widening or lymphadenopathy.\n"
        "BONES: No acute osseous abnormalities identified.\n"
        "SOFT TISSUES: Unremarkable."
    )

    # Anatomical region (finding["original"]["region"]) -> report category.
    # Regions not listed here fall into "OTHER".
    _REGION_CATEGORY = {
        "right_upper_lung": "LUNGS",
        "right_middle_lung": "LUNGS",
        "right_lower_lung": "LUNGS",
        "left_upper_lung": "LUNGS",
        "left_lower_lung": "LUNGS",
        "cardiac_silhouette": "HEART",
        "costophrenic_angles": "PLEURA",
        "mediastinum": "MEDIASTINUM",
        "bones": "BONES",
        "diaphragm": "OTHER",
    }

    # (category, text used when the category has no findings), in report order.
    _CATEGORY_DEFAULTS = (
        ("LUNGS", "Clear bilaterally. No focal consolidation. Lungs well-expanded."),
        ("HEART", "Normal cardiac silhouette."),
        ("PLEURA", "No pleural effusion."),
        ("MEDIASTINUM", "Unremarkable."),
        ("BONES", "No acute osseous abnormalities."),
    )

    # Section header spelling (normalised to upper case, single spaces)
    # -> key used in the template report's "sections" dict.
    _SECTION_KEYS = {
        "CLINICAL INDICATION": "clinical_indication",
        "INDICATION": "clinical_indication",
        "TECHNIQUE": "technique",
        "COMPARISON": "comparison",
        "FINDING": "findings",
        "FINDINGS": "findings",
        "IMPRESSION": "impression",
        "IMPRESSIONS": "impression",
        "RECOMMENDATION": "recommendations",
        "RECOMMENDATIONS": "recommendations",
    }

    # A header line: optional markdown '#'/'**' decoration, a known section
    # name, then a colon or end of line. Requiring ':' or EOL keeps prose such
    # as "Findings are consistent with..." from being mistaken for a header.
    _SECTION_HEADER_RE = re.compile(
        r"^[ \t]*(?:#+[ \t]*)?\**[ \t]*"
        r"(CLINICAL[ \t]+INDICATION|INDICATION|TECHNIQUE|COMPARISON"
        r"|FINDINGS?|IMPRESSIONS?|RECOMMENDATIONS?)"
        r"[ \t]*\**[ \t]*(?::|$)[ \t]*\**[ \t]*",
        re.IGNORECASE | re.MULTILINE,
    )

    def __init__(self, demo_mode: bool = False):
        super().__init__(
            name="Report Generator",
            model_name="google/medgemma-4b-it"
        )
        self.demo_mode = demo_mode
        self.engine = None

    def load_model(self) -> bool:
        """
        Load MedGemma via the unified engine.

        Always returns True. In demo mode, when the engine module could not
        be imported, or when engine creation raises, the agent marks itself
        loaded and relies on template-based report generation.
        """
        if self.demo_mode or not ENGINE_AVAILABLE:
            self.is_loaded = True
            return True
        try:
            self.engine = get_engine(force_demo=self.demo_mode)
            # Mirror the engine's state; process() re-checks it on every call.
            self.is_loaded = self.engine.is_loaded
            return True
        except Exception as e:
            print(f"Failed to load MedGemma engine: {e}")
            self.demo_mode = True
            self.is_loaded = True
            return True

    def process(self, input_data: Any, context: Optional[Dict] = None) -> AgentResult:
        """
        Generate a structured radiology report.

        Args:
            input_data: Dictionary from the Finding Interpreter agent; reads
                "interpreted_findings", "clinical_summary" and "key_concerns".
            context: Patient and study context (e.g. "clinical_history",
                "comparison").

        Returns:
            AgentResult with status "success" and the report dict, or status
            "error" when input_data is not a dict.
        """
        start_time = time.time()
        if not isinstance(input_data, dict):
            return AgentResult(
                agent_name=self.name,
                status="error",
                data={},
                processing_time_ms=(time.time() - start_time) * 1000,
                error_message="Invalid input: expected dictionary from Finding Interpreter"
            )

        # `or` also covers keys that are present but explicitly None.
        interpreted_findings = input_data.get("interpreted_findings") or []
        clinical_summary = input_data.get("clinical_summary") or ""
        key_concerns = input_data.get("key_concerns") or []

        # Use the real model whenever a non-demo engine is loaded.
        if self.engine and self.engine.is_loaded and self.engine.backend != "demo":
            report = self._run_model_inference(
                interpreted_findings, clinical_summary, key_concerns, context
            )
        else:
            report = self._simulate_report_generation(
                interpreted_findings, clinical_summary, key_concerns, context
            )

        return AgentResult(
            agent_name=self.name,
            status="success",
            data=report,
            processing_time_ms=(time.time() - start_time) * 1000
        )

    def _run_model_inference(
        self,
        interpreted_findings: List[Dict],
        clinical_summary: str,
        key_concerns: List[str],
        context: Optional[Dict]
    ) -> Dict:
        """
        Generate a report with MedGemma via the unified engine.

        Falls back to the template report when generation raises or
        returns no usable text (deliberate best-effort behaviour).
        """
        try:
            prompt = self._build_report_prompt(
                interpreted_findings, clinical_summary, key_concerns, context
            )
            report_text = self.engine.generate(prompt, max_tokens=500)
            if isinstance(report_text, str) and report_text.strip():
                return self._structure_report(report_text, interpreted_findings, context)
            print("Report generation error: model returned an empty report")
        except Exception as e:
            print(f"Report generation error: {e}")
        return self._simulate_report_generation(
            interpreted_findings, clinical_summary, key_concerns, context
        )

    def _simulate_report_generation(
        self,
        interpreted_findings: List[Dict],
        clinical_summary: str,
        key_concerns: List[str],
        context: Optional[Dict]
    ) -> Dict:
        """Assemble a template-based report from the findings (demo / fallback path)."""
        time.sleep(0.5)  # Simulate processing

        patient_info = context or {}
        # NOTE(review): the indication default is demo text and is also used
        # when the real model fails; confirm it is acceptable in reports.
        indication = patient_info.get("clinical_history") or "Chest pain, rule out pneumonia"
        comparison = patient_info.get("comparison") or "None available"

        report_sections = {
            "clinical_indication": indication,
            "technique": "Single frontal (PA) view of the chest was obtained.",
            "comparison": comparison,
            "findings": self._build_findings_section(interpreted_findings),
            "impression": self._build_impression(interpreted_findings, key_concerns),
            "recommendations": self._build_recommendations(interpreted_findings),
        }
        full_report = self._format_full_report(report_sections)

        return {
            "sections": report_sections,
            "full_report": full_report,
            "report_timestamp": datetime.now().isoformat(),
            "word_count": len(full_report.split()),
            "findings_count": len(interpreted_findings),
            "model_used": f"{self.model_name} (demo mode)"
        }

    @staticmethod
    def _original(finding: Dict) -> Dict:
        """Return the finding's nested "original" dict, or {} when missing or not a dict."""
        original = finding.get("original") if isinstance(finding, dict) else None
        return original if isinstance(original, dict) else {}

    def _build_report_prompt(
        self,
        interpreted_findings: List[Dict],
        clinical_summary: str,
        key_concerns: List[str],
        context: Optional[Dict]
    ) -> str:
        """Build the MedGemma prompt from findings, clinical summary, concerns and context."""
        finding_lines = []
        for finding in interpreted_findings:
            original = self._original(finding)
            finding_lines.append(
                f"- {original.get('type', 'Finding')}: {original.get('description', '')}"
            )
        findings_text = "\n".join(finding_lines)

        context_text = ""
        if context:
            context_text = f"Clinical History: {context.get('clinical_history') or 'Not provided'}"

        parts = [
            "Generate a structured chest X-ray radiology report based on the following findings.",
            "**Clinical Information:**",
            context_text,
        ]
        # The interpreter's summary was previously passed in but never used.
        if clinical_summary:
            parts += ["**Clinical Summary:**", str(clinical_summary)]
        parts += [
            "**Findings:**",
            findings_text if findings_text else "No significant abnormalities.",
            "**Key Concerns:**",
            ", ".join(str(c) for c in key_concerns) if key_concerns else "None",
            "Generate a complete report with sections: INDICATION, TECHNIQUE, COMPARISON, "
            "FINDINGS, IMPRESSION, and RECOMMENDATIONS.",
            "Use professional radiology terminology and standard reporting format.",
        ]
        return "\n".join(parts)

    def _build_findings_section(self, interpreted_findings: List[Dict]) -> str:
        """
        Build the FINDINGS section, grouping descriptions by anatomical category.

        Categories without findings get a normal-appearance sentence; findings
        in unmapped regions are listed under "OTHER" rather than dropped.
        """
        if not interpreted_findings:
            return self._NORMAL_FINDINGS

        by_category: Dict[str, List[str]] = {
            category: [] for category, _ in self._CATEGORY_DEFAULTS
        }
        by_category["OTHER"] = []
        for finding in interpreted_findings:
            original = self._original(finding)
            category = self._REGION_CATEGORY.get(original.get("region"), "OTHER")
            by_category[category].append(
                str(original.get("description") or "Abnormality noted")
            )

        lines = []
        for category, normal_text in self._CATEGORY_DEFAULTS:
            items = by_category[category]
            lines.append(f"{category}: " + (" ".join(items) if items else normal_text))
        if by_category["OTHER"]:
            lines.append("OTHER: " + " ".join(by_category["OTHER"]))
        return "\n".join(lines)

    def _build_impression(self, interpreted_findings: List[Dict], key_concerns: List[str]) -> str:
        """Build a numbered IMPRESSION list, one line per finding."""
        if not interpreted_findings:
            return "1. No acute cardiopulmonary abnormality."

        impressions = []
        for number, finding in enumerate(interpreted_findings, 1):
            original = self._original(finding)
            finding_type = str(original.get("type") or "Finding").capitalize()
            region = str(original.get("region") or "").replace("_", " ")
            severity = str(original.get("severity") or "").capitalize()
            # Join only non-empty parts so a missing severity leaves no double space.
            label = " ".join(part for part in (severity, finding_type) if part)
            line = f"{number}. {label}"
            if region:
                line += f" involving the {region}"
            impressions.append(line + ".")
        return "\n".join(impressions)

    def _build_recommendations(self, interpreted_findings: List[Dict]) -> str:
        """Join each finding's distinct "recommended_followup" text, in first-seen order."""
        if not interpreted_findings:
            return "No specific follow-up recommended."

        recommendations = []
        for finding in interpreted_findings:
            followup = finding.get("recommended_followup", "")
            if followup and followup not in recommendations:
                recommendations.append(followup)
        if not recommendations:
            recommendations.append("Clinical correlation recommended.")
        return " ".join(recommendations)

    def _format_full_report(self, sections: Dict) -> str:
        """Render the six template sections as a framed plain-text report."""
        lines = [
            self._RULE,
            "CHEST RADIOGRAPH REPORT",
            self._RULE,
            "CLINICAL INDICATION:",
            f"{sections['clinical_indication']}",
            "TECHNIQUE:",
            f"{sections['technique']}",
            "COMPARISON:",
            f"{sections['comparison']}",
            "FINDINGS:",
            f"{sections['findings']}",
            "IMPRESSION:",
            f"{sections['impression']}",
            "RECOMMENDATIONS:",
            f"{sections['recommendations']}",
            self._RULE,
            "Report generated by RadioFlow AI System",
            "This AI-generated report requires radiologist verification before clinical use.",
            self._RULE,
        ]
        return "\n".join(lines)

    @classmethod
    def _parse_sections(cls, report_text: str) -> Dict[str, str]:
        """
        Split model output into template section keys by header lines.

        Each section's text runs up to the next recognised header. If a
        header appears twice, the first occurrence is kept. Text before the
        first header is not assigned to any section.
        """
        matches = list(cls._SECTION_HEADER_RE.finditer(report_text))
        sections: Dict[str, str] = {}
        for index, match in enumerate(matches):
            end = matches[index + 1].start() if index + 1 < len(matches) else len(report_text)
            header = " ".join(match.group(1).split()).upper()
            key = cls._SECTION_KEYS.get(header)
            if key and key not in sections:
                sections[key] = report_text[match.end():end].strip()
        return sections

    def _structure_report(self, report_text: str, interpreted_findings: List[Dict], context: Optional[Dict]) -> Dict:
        """
        Wrap model-generated text in the same result shape as the template path.

        "sections" always contains "full_text". Any recognised headers add
        the template keys ("clinical_indication", "findings", "impression", ...),
        so consumers can read them regardless of backend.
        """
        sections = {"full_text": report_text}
        sections.update(self._parse_sections(report_text))
        return {
            "sections": sections,
            "full_report": report_text,
            "report_timestamp": datetime.now().isoformat(),
            "word_count": len(report_text.split()),
            "findings_count": len(interpreted_findings),
            "model_used": self.model_name
        }