Spaces:
Sleeping
Sleeping
| """ | |
| Document Classifier - Layer 1: Medical Document Classification with Real AI Models | |
| Routes documents to appropriate specialized models using Bio_ClinicalBERT | |
| """ | |
| import logging | |
| from typing import Dict, List, Any, Optional | |
| import re | |
| from model_loader import get_model_loader | |
| logger = logging.getLogger(__name__) | |
class DocumentClassifier:
    """
    Classifies medical documents into types for intelligent routing.

    Classification combines two signals: an AI model reached through the
    model loader (task name "document_classifier") and a keyword scorer that
    acts as the fallback whenever the model is unavailable, unsure, or
    returns a label that cannot be mapped to a known document type.

    Supported document types:
    - Radiology Report
    - Pathology Report
    - Laboratory Results
    - Clinical Notes
    - Discharge Summary
    - ECG/Cardiology Report
    - Operative Note
    - Medication List
    - Consultation Note
    """

    # The AI verdict overrides the keyword verdict only above this confidence.
    AI_CONFIDENCE_THRESHOLD = 0.6
    # Number of leading characters of the document handed to the AI model.
    AI_TEXT_LIMIT = 1000
    # Leading characters treated as the document "header" by the keyword scorer.
    HEADER_WINDOW = 500
    # Raw keyword score that corresponds to a confidence of 1.0.
    KEYWORD_SCORE_FOR_FULL_CONFIDENCE = 10.0
    # A runner-up type is reported as secondary only above this raw score.
    SECONDARY_SCORE_THRESHOLD = 3
    # Upper bound on the number of secondary types reported by classify().
    MAX_SECONDARY_TYPES = 3

    # Substrings of model labels mapped to our document types. Order matters:
    # the first key contained in the label wins.
    _LABEL_MAPPING = {
        "radiology": "radiology",
        "pathology": "pathology",
        "laboratory": "laboratory",
        "lab": "laboratory",
        "cardiology": "cardiology",
        "clinical": "clinical_notes",
        "discharge": "discharge_summary",
        "operative": "operative_note",
        "surgery": "operative_note",
        "medication": "medication_list",
        "consultation": "consultation",
    }

    # Placeholder labels ("LABEL_0", "label-1", ...) carry no meaning, yet
    # contain the substring "lab"; they must never be mapped to "laboratory".
    _PLACEHOLDER_LABEL = re.compile(r"label[\s_-]*\d+")

    # Document type -> downstream model domains.
    _TYPE_TO_MODELS = {
        "radiology": ["radiology_vqa", "report_generation", "segmentation"],
        "pathology": ["pathology_classification", "slide_analysis"],
        "laboratory": ["lab_normalization", "result_interpretation"],
        "cardiology": ["ecg_analysis", "cardiac_imaging"],
        "discharge_summary": ["clinical_summarization", "coding_extraction"],
        "operative_note": ["procedure_extraction", "coding"],
        "clinical_notes": ["clinical_ner", "summarization"],
        "consultation": ["clinical_ner", "diagnosis_extraction"],
        "medication_list": ["medication_extraction", "drug_interaction"],
    }

    def __init__(self):
        """Obtain the shared model loader and set up the keyword tables."""
        self.model_loader = get_model_loader()
        self.document_types = [
            "radiology",
            "pathology",
            "laboratory",
            "clinical_notes",
            "discharge_summary",
            "cardiology",
            "operative_note",
            "medication_list",
            "consultation",
            "unknown",
        ]
        # Keywords for document type detection (fallback method).
        # All keywords are lowercase; they are matched as whole words.
        self.classification_keywords = {
            "radiology": [
                "ct scan", "mri", "x-ray", "radiograph", "ultrasound",
                "imaging", "radiology", "chest xray", "chest x-ray",
                "ct", "pet scan", "mammogram", "fluoroscopy",
            ],
            "pathology": [
                "pathology", "biopsy", "histopathology", "cytology",
                "tissue", "slide", "specimen", "microscopic",
                "immunohistochemistry", "tumor grade", "malignant",
            ],
            "laboratory": [
                "lab results", "laboratory", "complete blood count", "cbc",
                "chemistry panel", "metabolic panel", "lipid panel",
                "glucose", "hemoglobin", "platelet", "wbc", "rbc",
                "test results", "reference range",
            ],
            "cardiology": [
                "ecg", "ekg", "electrocardiogram", "echo", "echocardiogram",
                "stress test", "cardiac", "heart", "arrhythmia",
                "ejection fraction", "coronary", "myocardial",
            ],
            "discharge_summary": [
                "discharge summary", "discharge diagnosis", "hospital course",
                "admission date", "discharge date", "discharge medications",
                "discharge instructions", "follow-up",
            ],
            "operative_note": [
                "operative note", "operation", "surgery", "surgical procedure",
                "procedure performed", "anesthesia", "incision", "operative findings",
                "post-operative", "surgeon",
            ],
            "medication_list": [
                "medication list", "current medications", "prescriptions",
                "drug list", "rx", "dosage", "frequency",
            ],
            "consultation": [
                "consultation", "consulted", "specialist", "referred",
                "opinion", "evaluation", "assessment and plan",
            ],
        }
        logger.info("Document Classifier initialized")

    async def classify(self, pdf_content: Dict[str, Any]) -> Dict[str, Any]:
        """
        Classify a medical document using the AI model with keyword fallback.

        Args:
            pdf_content: Parsed document. The "text" entry is classified;
                the whole dict is also inspected for "images" and "tables"
                when building routing hints.

        Returns:
            Classification result with:
            - document_type: primary classification
            - confidence: confidence score of the method that was used
            - secondary_types: other possible classifications (max 3,
              never containing the primary type)
            - routing_hints: suggestions for model routing
            - classification_method: "ai_model", "keyword_based" or "error"
            - ai_confidence / keyword_confidence: raw score of each method
            On failure the result has document_type "unknown", confidence
            0.0 and an additional "error" entry; no exception is raised.
        """
        try:
            # A present-but-None "text" entry is treated like an empty document.
            text = pdf_content.get("text") or ""

            # Try AI-based classification first (leading characters only).
            ai_result = await self._ai_classification(text[:self.AI_TEXT_LIMIT])
            # Also run keyword-based classification as backup.
            keyword_result = self._keyword_classification(text.lower())

            primary_type, confidence, method = self._select_primary(
                ai_result, keyword_result
            )
            secondary_types = self._merge_secondary_types(
                primary_type,
                ai_result.get("secondary_types", []),
                keyword_result.get("secondary_types", []),
            )
            routing_hints = self._generate_routing_hints(
                primary_type,
                secondary_types,
                pdf_content,
            )

            result = {
                "document_type": primary_type,
                "confidence": confidence,
                "secondary_types": secondary_types,
                "routing_hints": routing_hints,
                "classification_method": method,
                "ai_confidence": ai_result.get("confidence", 0),
                "keyword_confidence": keyword_result.get("confidence", 0),
            }
            logger.info(
                "Document classified as: %s (confidence: %.2f, method: %s)",
                primary_type, confidence, method,
            )
            return result
        except Exception as e:
            # Top-level boundary: classification must never crash the pipeline.
            logger.exception("Classification failed: %s", e)
            return {
                "document_type": "unknown",
                "confidence": 0.0,
                "secondary_types": [],
                "routing_hints": {
                    # "models" is the legacy key; the remaining keys mirror
                    # the shape produced by _generate_routing_hints().
                    "models": ["general"],
                    "primary_models": ["general"],
                    "secondary_models": [],
                    "extract_images": False,
                    "extract_tables": False,
                    "priority": "standard",
                },
                "classification_method": "error",
                "ai_confidence": 0.0,
                "keyword_confidence": 0.0,
                "error": str(e),
            }

    @classmethod
    def _select_primary(
        cls,
        ai_result: Dict[str, Any],
        keyword_result: Dict[str, Any]
    ) -> tuple:
        """
        Pick the primary verdict as (document_type, confidence, method).

        The AI verdict wins only when it is confident AND its label was
        mapped to a real document type; a confident "unknown" carries no
        routing information, so the keyword verdict is used instead.
        """
        ai_type = ai_result.get("document_type", "unknown")
        ai_confidence = ai_result.get("confidence", 0)
        if ai_confidence > cls.AI_CONFIDENCE_THRESHOLD and ai_type != "unknown":
            return ai_type, ai_confidence, "ai_model"
        return (
            keyword_result["document_type"],
            keyword_result["confidence"],
            "keyword_based",
        )

    @classmethod
    def _merge_secondary_types(
        cls,
        primary_type: str,
        ai_types: List[str],
        keyword_types: List[str]
    ) -> List[str]:
        """
        Merge secondary types from both methods, AI suggestions first.

        Order is deterministic, duplicates are dropped, and neither the
        primary type nor "unknown" is ever reported as secondary.
        """
        merged: List[str] = []
        for doc_type in list(ai_types) + list(keyword_types):
            if doc_type in (primary_type, "unknown") or doc_type in merged:
                continue
            merged.append(doc_type)
        return merged[:cls.MAX_SECONDARY_TYPES]

    async def _ai_classification(self, text: str) -> Dict[str, Any]:
        """
        Classify text with the AI model exposed by the model loader.

        The blocking inference call runs in the default executor so the
        event loop stays responsive. Any failure, or an output that cannot
        be interpreted, yields document_type "unknown" with confidence 0.0.

        Returns:
            Dict with document_type, confidence and secondary_types.
        """
        # Nothing to classify: skip the (expensive) model call entirely.
        if not text or not text.strip():
            return {"document_type": "unknown", "confidence": 0.0, "secondary_types": []}

        try:
            import asyncio

            # We are inside a coroutine, so a running loop always exists.
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None,
                lambda: self.model_loader.run_inference(
                    "document_classifier",
                    text,
                    {}
                )
            )

            if result.get("success") and result.get("result"):
                predictions = self._normalize_predictions(result["result"])
                if predictions:
                    # NOTE(review): predictions are assumed to be ordered
                    # best-first, as the original code assumed - confirm
                    # against the model loader.
                    top_prediction = predictions[0]
                    doc_type = self._map_label(top_prediction.get("label", ""))
                    score = top_prediction.get("score", 0.5)

                    # Secondary types come from the next three predictions.
                    secondary_types: List[str] = []
                    for prediction in predictions[1:4]:
                        mapped = self._map_label(prediction.get("label", ""))
                        if (
                            mapped != "unknown"
                            and mapped != doc_type
                            and mapped not in secondary_types
                        ):
                            secondary_types.append(mapped)

                    return {
                        "document_type": doc_type,
                        "confidence": score,
                        "secondary_types": secondary_types,
                    }

            # Fallback if model doesn't return expected format
            return {"document_type": "unknown", "confidence": 0.0, "secondary_types": []}
        except Exception as e:
            logger.warning(f"AI classification failed: {str(e)}, falling back to keywords")
            return {"document_type": "unknown", "confidence": 0.0, "secondary_types": []}

    @staticmethod
    def _normalize_predictions(model_output: Any) -> List[Any]:
        """
        Coerce the model output into a flat list of prediction dicts.

        Accepted shapes: a list of predictions, a list wrapping such a list
        (batched output), or a single prediction dict holding a "label".
        Anything else yields an empty list.
        """
        if isinstance(model_output, dict):
            return [model_output] if "label" in model_output else []
        if isinstance(model_output, list):
            if model_output and isinstance(model_output[0], list):
                return list(model_output[0])
            return model_output
        return []

    @classmethod
    def _map_label(cls, label: Optional[str]) -> str:
        """
        Map a raw model label to one of our document types.

        Returns "unknown" for empty labels, for placeholder labels such as
        "LABEL_0", and for labels containing none of the known substrings.
        """
        normalized = str(label or "").strip().lower()
        if not normalized or cls._PLACEHOLDER_LABEL.fullmatch(normalized):
            return "unknown"
        for key, doc_type in cls._LABEL_MAPPING.items():
            if key in normalized:
                return doc_type
        return "unknown"

    def _keyword_classification(self, text: str) -> Dict[str, Any]:
        """
        Keyword-based classification used as fallback.

        Args:
            text: Lowercased document text.

        Returns:
            Dict with document_type, confidence (0-1) and secondary_types.
            When no keyword matches at all the type is "unknown" with
            confidence 0.0 rather than an arbitrary first entry.
        """
        # Score each document type
        scores = {
            doc_type: self._calculate_type_score(text, keywords)
            for doc_type, keywords in self.classification_keywords.items()
        }
        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)

        # No evidence for any type: do not guess.
        if not ranked or ranked[0][1] <= 0:
            return {"document_type": "unknown", "confidence": 0.0, "secondary_types": []}

        primary_type, primary_score = ranked[0]
        # Normalize to 0-1
        confidence = min(primary_score / self.KEYWORD_SCORE_FOR_FULL_CONFIDENCE, 1.0)
        # Runners-up with enough evidence of their own
        secondary_types = [
            doc_type for doc_type, score in ranked[1:4]
            if score > self.SECONDARY_SCORE_THRESHOLD
        ]
        return {
            "document_type": primary_type,
            "confidence": confidence,
            "secondary_types": secondary_types,
        }

    def _calculate_type_score(self, text: str, keywords: List[str]) -> float:
        """
        Calculate the relevance score of one document type.

        Each keyword contributes its number of whole-word occurrences
        (an optional plural "s"/"es" is tolerated). Whole-word matching
        keeps short keywords such as "ct" or "rx" from matching inside
        unrelated words like "doctor" or "section". When the keyword also
        occurs within the first HEADER_WINDOW characters, all of its
        occurrences count double.

        Args:
            text: Lowercased document text.
            keywords: Keywords of the document type.

        Returns:
            Non-negative score; 0.0 for empty text.
        """
        score = 0.0
        if not text:
            return score

        header = text[:self.HEADER_WINDOW]
        for keyword in keywords:
            # re caches compiled patterns, so compiling per call is cheap.
            pattern = re.compile(
                r"(?<!\w)" + re.escape(keyword.lower()) + r"(?:s|es)?(?!\w)"
            )
            count = len(pattern.findall(text))
            if not count:
                continue
            # Keyword at beginning of document = higher weight
            weight = 2 if pattern.search(header) else 1
            score += count * weight
        return score

    def _generate_routing_hints(
        self,
        primary_type: str,
        secondary_types: List[str],
        pdf_content: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Generate hints for intelligent model routing.

        Args:
            primary_type: Primary document type.
            secondary_types: Other plausible document types.
            pdf_content: Parsed document; only "images" and "tables" are read.

        Returns:
            Dict with primary_models, secondary_models (no duplicates and
            no model already listed as primary), extract_images,
            extract_tables and priority. "has_images" / "has_tables" are
            added (as True) only when the document contains them.
        """
        # Copy so callers can mutate the result without touching the table.
        primary_models = list(self._TYPE_TO_MODELS.get(primary_type, ["general"]))

        secondary_models: List[str] = []
        for sec_type in secondary_types:
            for model in self._TYPE_TO_MODELS.get(sec_type, []):
                if model not in primary_models and model not in secondary_models:
                    secondary_models.append(model)

        hints = {
            "primary_models": primary_models,
            "secondary_models": secondary_models,
            # Special processing hints
            "extract_images": primary_type in ("radiology", "pathology"),
            "extract_tables": primary_type == "laboratory",
            "priority": "high" if primary_type == "radiology" else "standard",
        }

        # Check if document has images
        if pdf_content.get("images"):
            hints["has_images"] = True
        # Check if document has tables
        if pdf_content.get("tables"):
            hints["has_tables"] = True
        return hints