|
|
"""
Document Classifier - Layer 1: Medical Document Classification with Real AI Models
Routes documents to appropriate specialized models using Bio_ClinicalBERT
"""
|
|
|
|
|
import logging |
|
|
from typing import Dict, List, Any, Optional |
|
|
import re |
|
|
from model_loader import get_model_loader |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class DocumentClassifier:
    """
    Classifies medical documents into types for intelligent routing.

    Two signals are combined: an AI model, invoked through the shared model
    loader under the name ``"document_classifier"``, and a keyword scorer.
    The AI result is used when it is confident *and* maps to a known
    document type; otherwise the keyword result is used.

    Supported document types:
    - Radiology Report
    - Pathology Report
    - Laboratory Results
    - Clinical Notes
    - Discharge Summary
    - ECG/Cardiology Report
    - Operative Note
    - Medication List
    - Consultation Note
    """

    # The AI result is preferred only when its confidence exceeds this value.
    AI_CONFIDENCE_THRESHOLD = 0.6
    # Only this many leading characters of the document are sent to the model.
    AI_TEXT_CHARS = 1000
    # A keyword that occurs entirely within this many leading characters has
    # all of its occurrences weighted double.
    HEADER_CHARS = 500
    # Keyword score at (or above) which keyword confidence saturates at 1.0.
    FULL_CONFIDENCE_SCORE = 10.0
    # A runner-up type must score strictly above this to be listed as secondary.
    SECONDARY_MIN_SCORE = 3
    # Upper bound on the number of secondary types reported by classify().
    MAX_SECONDARY_TYPES = 3

    # Word found in a model label -> document type. Order matters: the first
    # key that matches wins.
    _LABEL_MAPPING = {
        "radiology": "radiology",
        "pathology": "pathology",
        "laboratory": "laboratory",
        "lab": "laboratory",
        "cardiology": "cardiology",
        "clinical": "clinical_notes",
        "discharge": "discharge_summary",
        "operative": "operative_note",
        "surgery": "operative_note",
        "medication": "medication_list",
        "consultation": "consultation",
    }

    # Document type -> names of the downstream models suggested for it.
    _TYPE_TO_MODELS = {
        "radiology": ["radiology_vqa", "report_generation", "segmentation"],
        "pathology": ["pathology_classification", "slide_analysis"],
        "laboratory": ["lab_normalization", "result_interpretation"],
        "cardiology": ["ecg_analysis", "cardiac_imaging"],
        "discharge_summary": ["clinical_summarization", "coding_extraction"],
        "operative_note": ["procedure_extraction", "coding"],
        "clinical_notes": ["clinical_ner", "summarization"],
        "consultation": ["clinical_ner", "diagnosis_extraction"],
        "medication_list": ["medication_extraction", "drug_interaction"],
    }

    def __init__(self):
        """Fetch the shared model loader and set up the keyword tables."""
        self.model_loader = get_model_loader()
        self.document_types = [
            "radiology",
            "pathology",
            "laboratory",
            "clinical_notes",
            "discharge_summary",
            "cardiology",
            "operative_note",
            "medication_list",
            "consultation",
            "unknown",
        ]

        # Lower-case keywords per document type. "clinical_notes" has no
        # entry, so the keyword scorer can never select it.
        self.classification_keywords = {
            "radiology": [
                "ct scan", "mri", "x-ray", "radiograph", "ultrasound",
                "imaging", "radiology", "chest xray", "chest x-ray",
                "ct", "pet scan", "mammogram", "fluoroscopy",
            ],
            "pathology": [
                "pathology", "biopsy", "histopathology", "cytology",
                "tissue", "slide", "specimen", "microscopic",
                "immunohistochemistry", "tumor grade", "malignant",
            ],
            "laboratory": [
                "lab results", "laboratory", "complete blood count", "cbc",
                "chemistry panel", "metabolic panel", "lipid panel",
                "glucose", "hemoglobin", "platelet", "wbc", "rbc",
                "test results", "reference range",
            ],
            "cardiology": [
                "ecg", "ekg", "electrocardiogram", "echo", "echocardiogram",
                "stress test", "cardiac", "heart", "arrhythmia",
                "ejection fraction", "coronary", "myocardial",
            ],
            "discharge_summary": [
                "discharge summary", "discharge diagnosis", "hospital course",
                "admission date", "discharge date", "discharge medications",
                "discharge instructions", "follow-up",
            ],
            "operative_note": [
                "operative note", "operation", "surgery", "surgical procedure",
                "procedure performed", "anesthesia", "incision", "operative findings",
                "post-operative", "surgeon",
            ],
            "medication_list": [
                "medication list", "current medications", "prescriptions",
                "drug list", "rx", "dosage", "frequency",
            ],
            "consultation": [
                "consultation", "consulted", "specialist", "referred",
                "opinion", "evaluation", "assessment and plan",
            ],
        }

        logger.info("Document Classifier initialized")

    async def classify(self, pdf_content: Dict[str, Any]) -> Dict[str, Any]:
        """
        Classify medical document using AI model + keyword fallback

        Args:
            pdf_content: Parsed document. The keys read here are "text",
                "images" and "tables"; a missing or None "text" is treated
                as an empty document.

        Returns:
            Classification result with:
            - document_type: primary classification
            - confidence: confidence score
            - secondary_types: other possible classifications
            - routing_hints: suggestions for model routing
            - classification_method: "ai_model" or "keyword_based"
            - ai_confidence / keyword_confidence: the two raw scores

            This method does not raise. On failure it returns an "unknown"
            result with confidence 0.0 and an "error" message.
        """
        try:
            text = pdf_content.get("text") or ""

            ai_result = await self._ai_classification(text[:self.AI_TEXT_CHARS])
            keyword_result = self._keyword_classification(text.lower())

            primary_type, confidence, method = self._select_primary(
                ai_result, keyword_result
            )
            secondary_types = self._merge_secondary_types(
                primary_type,
                ai_result.get("secondary_types", []),
                keyword_result.get("secondary_types", []),
            )
            routing_hints = self._generate_routing_hints(
                primary_type,
                secondary_types,
                pdf_content,
            )

            result = {
                "document_type": primary_type,
                "confidence": confidence,
                "secondary_types": secondary_types,
                "routing_hints": routing_hints,
                "classification_method": method,
                "ai_confidence": ai_result.get("confidence", 0),
                "keyword_confidence": keyword_result.get("confidence", 0),
            }

            logger.info(f"Document classified as: {primary_type} (confidence: {confidence:.2f}, method: {method})")

            return result

        except Exception as e:
            # Deliberate catch-all: classification is best-effort and callers
            # receive an "unknown" result instead of an exception.
            logger.exception("Classification failed: %s", e)
            # Same shape as the success path, plus the legacy "models" key
            # that earlier versions returned here.
            fallback_hints = self._generate_routing_hints("unknown", [], {})
            fallback_hints["models"] = ["general"]
            return {
                "document_type": "unknown",
                "confidence": 0.0,
                "secondary_types": [],
                "routing_hints": fallback_hints,
                "error": str(e),
            }

    def _select_primary(self, ai_result: Dict[str, Any], keyword_result: Dict[str, Any]):
        """Pick (document_type, confidence, method) from the two classifier results.

        The AI result wins only when it is both confident and mapped to a
        known type; a confident "unknown" must not hide a usable keyword hit.
        """
        ai_type = ai_result.get("document_type", "unknown")
        ai_confidence = ai_result.get("confidence", 0)
        if ai_confidence > self.AI_CONFIDENCE_THRESHOLD and ai_type != "unknown":
            return ai_type, ai_confidence, "ai_model"
        return (
            keyword_result["document_type"],
            keyword_result["confidence"],
            "keyword_based",
        )

    def _merge_secondary_types(
        self,
        primary_type: str,
        ai_secondary: List[str],
        keyword_secondary: List[str],
    ) -> List[str]:
        """Merge secondary types, AI suggestions first, in a stable order.

        Duplicates, "unknown" and the primary type itself are dropped, and
        the list is capped at MAX_SECONDARY_TYPES entries.
        """
        merged: List[str] = []
        for candidate in list(ai_secondary) + list(keyword_secondary):
            if candidate in (primary_type, "unknown") or candidate in merged:
                continue
            merged.append(candidate)
        return merged[:self.MAX_SECONDARY_TYPES]

    @classmethod
    def _map_label(cls, label: Any) -> str:
        """Translate a model label into one of the supported document types.

        The label is split into words (on non-alphanumerics and camelCase
        boundaries) and compared word-by-word against the mapping keys. Whole
        word matching matters: with substring matching the key "lab" matches
        placeholder labels such as "LABEL_0" and turns them into "laboratory".
        A trailing "s" is tolerated so plural labels still map.
        """
        spaced = re.sub(r"(?<=[a-z])(?=[A-Z])", " ", str(label))
        words = {w for w in re.split(r"[^a-z0-9]+", spaced.lower()) if w}
        words |= {w[:-1] for w in words if w.endswith("s")}
        for key, doc_type in cls._LABEL_MAPPING.items():
            if key in words:
                return doc_type
        return "unknown"

    async def _ai_classification(self, text: str) -> Dict[str, Any]:
        """Use Bio_ClinicalBERT for document classification

        Runs ``model_loader.run_inference("document_classifier", text, {})``
        in the default executor so the event loop is not blocked.

        NOTE(review): judging by how the result is read below, run_inference
        is expected to return a dict with "success" and "result", where
        "result" is a list of {"label", "score"} dicts, best first - confirm
        against model_loader.

        Returns:
            Dict with "document_type", "confidence" and "secondary_types".
            Any failure or unusable output yields "unknown" with confidence
            0.0; nothing is raised.
        """
        try:
            import asyncio

            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None,
                self.model_loader.run_inference,
                "document_classifier",
                text,
                {},
            )

            if result.get("success") and result.get("result"):
                model_output = result["result"]

                if isinstance(model_output, list) and len(model_output) > 0:
                    top_prediction = model_output[0]
                    doc_type = self._map_label(top_prediction.get("label", ""))
                    score = top_prediction.get("score", 0.5)

                    # Up to three runner-up predictions become secondary types.
                    secondary_types: List[str] = []
                    for pred in model_output[1:4]:
                        mapped = self._map_label(pred.get("label", ""))
                        if mapped in ("unknown", doc_type) or mapped in secondary_types:
                            continue
                        secondary_types.append(mapped)

                    return {
                        "document_type": doc_type,
                        "confidence": score,
                        "secondary_types": secondary_types,
                    }

            return {"document_type": "unknown", "confidence": 0.0, "secondary_types": []}

        except Exception as e:
            # Best-effort: the caller falls back to keyword classification.
            logger.warning(f"AI classification failed: {str(e)}, falling back to keywords")
            return {"document_type": "unknown", "confidence": 0.0, "secondary_types": []}

    def _keyword_classification(self, text: str) -> Dict[str, Any]:
        """Keyword-based classification as fallback

        Args:
            text: Document text, already lower-cased by the caller.

        Returns:
            Dict with "document_type", "confidence" (top score divided by
            FULL_CONFIDENCE_SCORE, capped at 1.0) and "secondary_types".
            When no keyword matches at all the type is "unknown"; without
            this guard the first table entry would win a tie at zero.
        """
        scores = {
            doc_type: self._calculate_type_score(text, keywords)
            for doc_type, keywords in self.classification_keywords.items()
        }
        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)

        if not ranked or ranked[0][1] <= 0:
            return {"document_type": "unknown", "confidence": 0.0, "secondary_types": []}

        primary_type, primary_score = ranked[0]
        confidence = min(primary_score / self.FULL_CONFIDENCE_SCORE, 1.0)

        secondary_types = [
            doc_type
            for doc_type, score in ranked[1:4]
            if score > self.SECONDARY_MIN_SCORE
        ]

        return {
            "document_type": primary_type,
            "confidence": confidence,
            "secondary_types": secondary_types,
        }

    def _calculate_type_score(self, text: str, keywords: List[str]) -> float:
        """Calculate relevance score for a document type

        Each keyword contributes its number of whole-word occurrences in
        ``text`` (an optional plural "s"/"es" is accepted). Occurrences count
        double when the keyword appears entirely within the first
        HEADER_CHARS characters.

        Whole-word matching is deliberate: plain substring counting scores
        "ct" inside "ejection fraction" and "referred" inside "preferred".

        Args:
            text: Lower-cased document text.
            keywords: Keywords for one document type; empty ones are ignored.
        """
        score = 0.0

        for keyword in keywords:
            if not keyword:
                continue
            # re.compile is cached by the re module, so rebuilding the same
            # pattern on every call is cheap.
            pattern = re.compile(
                r"(?<!\w)" + re.escape(keyword.lower()) + r"(?:e?s)?(?!\w)"
            )
            matches = list(pattern.finditer(text))
            if not matches:
                continue
            # Matches come back left to right, so the first one decides
            # whether the keyword appears in the header region.
            in_header = matches[0].end() <= self.HEADER_CHARS
            score += len(matches) * (2 if in_header else 1)

        return score

    def _generate_routing_hints(
        self,
        primary_type: str,
        secondary_types: List[str],
        pdf_content: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Generate hints for intelligent model routing

        Args:
            primary_type: Primary document type; types without an entry in
                the model table are routed to ["general"].
            secondary_types: Further candidate types.
            pdf_content: Parsed document; only "images" and "tables" are read.

        Returns:
            Dict with "primary_models", "secondary_models", "extract_images",
            "extract_tables" and "priority". "has_images" / "has_tables" are
            added (as True) only when the document carries images / tables.
        """
        # Copy the list so callers can mutate the result without touching
        # the class-level table.
        primary_models = list(self._TYPE_TO_MODELS.get(primary_type, ["general"]))

        # Each secondary model is listed once and never repeats a primary one.
        secondary_models: List[str] = []
        for sec_type in secondary_types:
            for model_name in self._TYPE_TO_MODELS.get(sec_type, []):
                if model_name not in primary_models and model_name not in secondary_models:
                    secondary_models.append(model_name)

        hints = {
            "primary_models": primary_models,
            "secondary_models": secondary_models,
            "extract_images": primary_type in ("radiology", "pathology"),
            "extract_tables": primary_type == "laboratory",
            "priority": "high" if primary_type == "radiology" else "standard",
        }

        if pdf_content.get("images"):
            hints["has_images"] = True

        if pdf_content.get("tables"):
            hints["has_tables"] = True

        return hints
|
|
|