# Source: medical-report-analyzer / document_classifier.py (author: snikhilesh)
# Commit 13d5ab4 (verified): "Deploy backend with monitoring infrastructure - Complete Medical AI Platform"
"""
Document Classifier - Layer 1: Medical Document Classification with Real AI Models
Routes documents to appropriate specialized models using Bio_ClinicalBERT
"""
import logging
from typing import Dict, List, Any, Optional
import re
from model_loader import get_model_loader
logger = logging.getLogger(__name__)


class DocumentClassifier:
    """
    Classifies medical documents into types for intelligent routing.

    Supported document types:
    - Radiology Report
    - Pathology Report
    - Laboratory Results
    - Clinical Notes
    - Discharge Summary
    - ECG/Cardiology Report
    - Operative Note
    - Medication List
    - Consultation Note

    Classification first asks the model loader's "document_classifier" model.
    It falls back to keyword scoring when the model fails, is not confident
    enough, or returns a label that does not map to a known document type.
    """

    # Minimum model score for the AI prediction to take precedence over
    # keyword scoring.
    AI_CONFIDENCE_THRESHOLD = 0.6

    # Ordered (substring, document_type) pairs used to map free-form model
    # labels onto our document types; the first matching substring wins.
    _LABEL_MAPPING = (
        ("radiology", "radiology"),
        ("pathology", "pathology"),
        ("laboratory", "laboratory"),
        ("lab", "laboratory"),
        ("cardiology", "cardiology"),
        ("clinical", "clinical_notes"),
        ("discharge", "discharge_summary"),
        ("operative", "operative_note"),
        ("surgery", "operative_note"),
        ("medication", "medication_list"),
        ("consultation", "consultation"),
    )

    # Generic labels such as "LABEL_0" (the default naming when a model has
    # no id2label mapping) carry no document-type meaning. "label" also
    # contains the substring "lab", so without this check every such
    # prediction would be mapped to "laboratory".
    _GENERIC_LABEL_RE = re.compile(r"label[_\s-]?\d+")

    # Downstream model domains for each document type; types not listed
    # here (e.g. "unknown") are routed to ["general"].
    _TYPE_TO_MODELS = {
        "radiology": ["radiology_vqa", "report_generation", "segmentation"],
        "pathology": ["pathology_classification", "slide_analysis"],
        "laboratory": ["lab_normalization", "result_interpretation"],
        "cardiology": ["ecg_analysis", "cardiac_imaging"],
        "discharge_summary": ["clinical_summarization", "coding_extraction"],
        "operative_note": ["procedure_extraction", "coding"],
        "clinical_notes": ["clinical_ner", "summarization"],
        "consultation": ["clinical_ner", "diagnosis_extraction"],
        "medication_list": ["medication_extraction", "drug_interaction"],
    }

    def __init__(self):
        self.model_loader = get_model_loader()
        self.document_types = [
            "radiology",
            "pathology",
            "laboratory",
            "clinical_notes",
            "discharge_summary",
            "cardiology",
            "operative_note",
            "medication_list",
            "consultation",
            "unknown",
        ]
        # Keywords for document type detection (fallback method).
        # NOTE(review): "clinical_notes" has no keyword list, so the keyword
        # fallback can never produce that type.
        self.classification_keywords = {
            "radiology": [
                "ct scan", "mri", "x-ray", "radiograph", "ultrasound",
                "imaging", "radiology", "chest xray", "chest x-ray",
                "ct", "pet scan", "mammogram", "fluoroscopy",
            ],
            "pathology": [
                "pathology", "biopsy", "histopathology", "cytology",
                "tissue", "slide", "specimen", "microscopic",
                "immunohistochemistry", "tumor grade", "malignant",
            ],
            "laboratory": [
                "lab results", "laboratory", "complete blood count", "cbc",
                "chemistry panel", "metabolic panel", "lipid panel",
                "glucose", "hemoglobin", "platelet", "wbc", "rbc",
                "test results", "reference range",
            ],
            "cardiology": [
                "ecg", "ekg", "electrocardiogram", "echo", "echocardiogram",
                "stress test", "cardiac", "heart", "arrhythmia",
                "ejection fraction", "coronary", "myocardial",
            ],
            "discharge_summary": [
                "discharge summary", "discharge diagnosis", "hospital course",
                "admission date", "discharge date", "discharge medications",
                "discharge instructions", "follow-up",
            ],
            "operative_note": [
                "operative note", "operation", "surgery", "surgical procedure",
                "procedure performed", "anesthesia", "incision", "operative findings",
                "post-operative", "surgeon",
            ],
            "medication_list": [
                "medication list", "current medications", "prescriptions",
                "drug list", "rx", "dosage", "frequency",
            ],
            "consultation": [
                "consultation", "consulted", "specialist", "referred",
                "opinion", "evaluation", "assessment and plan",
            ],
        }
        logger.info("Document Classifier initialized")

    async def classify(self, pdf_content: Dict[str, Any]) -> Dict[str, Any]:
        """
        Classify a medical document using the AI model with a keyword fallback.

        Args:
            pdf_content: Extracted document. Reads "text" (a missing or None
                value is treated as empty) and, for routing hints, the optional
                "images" and "tables" entries.

        Returns:
            Classification result with:
            - document_type: primary classification
            - confidence: confidence of the method that won
            - secondary_types: up to three other likely types, never
              including the primary type
            - routing_hints: suggestions for model routing
            - classification_method: "ai_model", "keyword_based", or "error"
            - ai_confidence / keyword_confidence: raw scores of both methods
            On failure, "document_type" is "unknown" and an "error" key holds
            the exception message.
        """
        try:
            text = pdf_content.get("text") or ""

            # The model only sees the beginning of the document.
            ai_result = await self._ai_classification(text[:1000])
            keyword_result = self._keyword_classification(text.lower())

            ai_type = ai_result.get("document_type", "unknown")
            ai_confidence = ai_result.get("confidence", 0)
            # The AI wins only when it is confident AND its label maps to a
            # real type; a confident "unknown" must not override keywords.
            if ai_type != "unknown" and ai_confidence > self.AI_CONFIDENCE_THRESHOLD:
                primary_type = ai_type
                confidence = ai_confidence
                method = "ai_model"
                ranked_sources = (ai_result, keyword_result)
            else:
                primary_type = keyword_result["document_type"]
                confidence = keyword_result["confidence"]
                method = "keyword_based"
                ranked_sources = (keyword_result, ai_result)

            # Merge secondary types with an ordered de-duplication (the winning
            # method's suggestions first) so the top-3 cut is deterministic.
            secondary_types: List[str] = []
            for source in ranked_sources:
                for candidate in source.get("secondary_types", []):
                    if candidate != primary_type and candidate not in secondary_types:
                        secondary_types.append(candidate)
            secondary_types = secondary_types[:3]

            routing_hints = self._generate_routing_hints(
                primary_type,
                secondary_types,
                pdf_content,
            )

            result = {
                "document_type": primary_type,
                "confidence": confidence,
                "secondary_types": secondary_types,
                "routing_hints": routing_hints,
                "classification_method": method,
                "ai_confidence": ai_confidence,
                "keyword_confidence": keyword_result.get("confidence", 0),
            }
            logger.info(f"Document classified as: {primary_type} (confidence: {confidence:.2f}, method: {method})")
            return result

        except Exception as e:
            logger.error(f"Classification failed: {str(e)}", exc_info=True)
            return {
                "document_type": "unknown",
                "confidence": 0.0,
                "secondary_types": [],
                "routing_hints": {
                    # "models" is retained from the original error payload for
                    # backward compatibility; the other keys mirror the shape
                    # produced by _generate_routing_hints.
                    "models": ["general"],
                    "primary_models": ["general"],
                    "secondary_models": [],
                    "extract_images": False,
                    "extract_tables": False,
                    "priority": "standard",
                },
                "classification_method": "error",
                "ai_confidence": 0.0,
                "keyword_confidence": 0.0,
                "error": str(e),
            }

    @staticmethod
    def _unknown_result() -> Dict[str, Any]:
        """Return a fresh "no classification" result."""
        return {"document_type": "unknown", "confidence": 0.0, "secondary_types": []}

    @staticmethod
    def _map_label(label: Any) -> str:
        """
        Map a model label onto one of our document types.

        Matching is case-insensitive substring matching against
        ``_LABEL_MAPPING`` in order. Generic "LABEL_<n>" labels and labels
        with no matching substring return "unknown".
        """
        normalized = str(label or "").lower()
        if DocumentClassifier._GENERIC_LABEL_RE.fullmatch(normalized):
            return "unknown"
        for key, doc_type in DocumentClassifier._LABEL_MAPPING:
            if key in normalized:
                return doc_type
        return "unknown"

    async def _ai_classification(self, text: str) -> Dict[str, Any]:
        """
        Classify *text* with the model loader's "document_classifier" model.

        Args:
            text: Document text (``classify`` passes the first 1000 characters).

        Returns:
            Dict with "document_type", "confidence" (the top prediction's
            score, defaulting to 0.5 when absent) and up to three distinct
            "secondary_types" taken from the next predictions. An "unknown"
            result with confidence 0.0 is returned in these cases: blank text,
            an unsuccessful or empty inference result, an output that is not a
            list of predictions, or any exception raised while running the
            model (logged as a warning).
        """
        if not text or not text.strip():
            # Nothing to classify; skip the model call entirely.
            return self._unknown_result()

        try:
            import asyncio

            loop = asyncio.get_running_loop()
            # run_inference is a plain synchronous call, so run it in the
            # default executor to keep the event loop responsive.
            result = await loop.run_in_executor(
                None,
                lambda: self.model_loader.run_inference("document_classifier", text, {}),
            )
            if not (result.get("success") and result.get("result")):
                return self._unknown_result()

            predictions = result["result"]
            # Tolerate a nested list (one prediction list per input).
            if isinstance(predictions, list) and predictions and isinstance(predictions[0], list):
                predictions = predictions[0]
            if not isinstance(predictions, list) or not predictions:
                # Fallback if the model doesn't return the expected format.
                return self._unknown_result()

            top_prediction = predictions[0]
            doc_type = self._map_label(top_prediction.get("label", ""))
            score = float(top_prediction.get("score", 0.5))

            # Secondary types come from the next (up to three) predictions.
            secondary_types: List[str] = []
            for pred in predictions[1:4]:
                sec_type = self._map_label(pred.get("label", ""))
                if sec_type not in ("unknown", doc_type) and sec_type not in secondary_types:
                    secondary_types.append(sec_type)

            return {
                "document_type": doc_type,
                "confidence": score,
                "secondary_types": secondary_types,
            }

        except Exception as e:
            logger.warning(f"AI classification failed: {str(e)}, falling back to keywords")
            return self._unknown_result()

    def _keyword_classification(self, text: str) -> Dict[str, Any]:
        """
        Keyword-based classification used as a fallback.

        Args:
            text: Document text, expected to be lower-cased by the caller.

        Returns:
            Dict with "document_type" (the best-scoring type, or "unknown" when
            no keyword matches at all), "confidence" (score / 10, capped at
            1.0) and "secondary_types" (the next up-to-three types scoring > 3).
        """
        scores = {
            doc_type: self._calculate_type_score(text, keywords)
            for doc_type, keywords in self.classification_keywords.items()
        }
        # sorted() is stable, so ties keep the keyword table's order.
        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)

        # With no keyword hits, the first table entry would otherwise "win"
        # with a zero score; report "unknown" instead.
        if not ranked or ranked[0][1] <= 0:
            return self._unknown_result()

        primary_type, primary_score = ranked[0]
        confidence = min(primary_score / 10.0, 1.0)  # Normalize to 0-1
        secondary_types = [
            doc_type for doc_type, score in ranked[1:4]
            if score > 3
        ]
        return {
            "document_type": primary_type,
            "confidence": confidence,
            "secondary_types": secondary_types,
        }

    @staticmethod
    def _keyword_pattern(keyword: str) -> "re.Pattern":
        """Compile a whole-word pattern for *keyword* that also accepts an 's'/'es' plural."""
        return re.compile(r"\b" + re.escape(keyword.lower()) + r"(?:s|es)?\b")

    def _calculate_type_score(self, text: str, keywords: List[str]) -> float:
        """
        Calculate the relevance score of *text* for one document type.

        Each keyword contributes its number of whole-word occurrences. Short
        keywords such as "ct" or "rx" are therefore not counted inside words
        like "doctor" or "fact". If the keyword also occurs in the first 500
        characters, all of its occurrences count double.
        """
        score = 0.0
        head = text[:500]
        for keyword in keywords:
            pattern = self._keyword_pattern(keyword)
            count = len(pattern.findall(text))
            if count == 0:
                continue
            score += count * 2 if pattern.search(head) else count
        return score

    def _generate_routing_hints(
        self,
        primary_type: str,
        secondary_types: List[str],
        pdf_content: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Generate hints for intelligent model routing.

        Args:
            primary_type: The winning document type.
            secondary_types: Other likely document types.
            pdf_content: Extracted document; truthy "images"/"tables" entries
                add "has_images"/"has_tables" flags.

        Returns:
            Dict with "primary_models" (["general"] for unmapped types),
            "secondary_models" (de-duplicated, excluding primary models),
            "extract_images", "extract_tables", "priority", and the optional
            "has_images"/"has_tables" flags.
        """
        # Copy so callers mutating the hints cannot alter the class table.
        primary_models = list(self._TYPE_TO_MODELS.get(primary_type, ["general"]))

        secondary_models: List[str] = []
        for sec_type in secondary_types:
            for model in self._TYPE_TO_MODELS.get(sec_type, []):
                if model not in primary_models and model not in secondary_models:
                    secondary_models.append(model)

        hints = {
            "primary_models": primary_models,
            "secondary_models": secondary_models,
            # Special processing hints per primary type.
            "extract_images": primary_type in ("radiology", "pathology"),
            "extract_tables": primary_type == "laboratory",
            "priority": "high" if primary_type == "radiology" else "standard",
        }

        if pdf_content.get("images"):
            hints["has_images"] = True
        if pdf_content.get("tables"):
            hints["has_tables"] = True

        return hints