|
|
""" |
|
|
Specialized Medical AI Model Router - Phase 3 |
|
|
Routes structured medical data to appropriate specialized AI models. |
|
|
|
|
|
This module integrates with the preprocessing pipeline to provide model-specific |
|
|
preprocessing, inference, and confidence scoring for medical AI analysis. |
|
|
|
|
|
Author: MiniMax Agent |
|
|
Date: 2025-10-29 |
|
|
Version: 1.0.0 |
|
|
""" |
|
|
|
|
|
import os |
|
|
import logging |
|
|
import asyncio |
|
|
import time |
|
|
from typing import Dict, List, Optional, Any, Tuple, Union |
|
|
from dataclasses import dataclass |
|
|
import numpy as np |
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline |
|
|
|
|
|
|
|
|
from model_loader import MedicalModelLoader |
|
|
|
|
|
|
|
|
from preprocessing_pipeline import ProcessingPipelineResult |
|
|
from medical_schemas import ( |
|
|
ValidationResult, ConfidenceScore, ECGAnalysis, RadiologyAnalysis, |
|
|
LaboratoryResults, ClinicalNotesAnalysis |
|
|
) |
|
|
|
|
|
# Module-level logger. SpecializedModelRouter uses it for its initialization,
# input-validation-warning and inference-error messages.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
@dataclass
class ModelInferenceResult:
    """Outcome of a single specialized-model inference attempt.

    Instances carry both the model output and the bookkeeping around it
    (timing, metadata, warnings and errors), so successful, fallback and
    failed runs share one shape.
    """

    # Label of the model that produced this result (may be a sentinel label
    # such as "error" when no model produced output).
    model_name: str
    # Model-specific input that was handed to the model.
    input_data: Dict[str, Any]
    # Post-processed model output.
    output_data: Dict[str, Any]
    # Heuristic confidence for the result; presumably in [0.0, 1.0] — the
    # dataclass itself does not enforce a range.
    confidence_score: float
    # Elapsed seconds spent producing the result.
    processing_time: float
    # Free-form diagnostic metadata (configuration, validation details, ...).
    model_metadata: Dict[str, Any]
    # Non-fatal issues encountered while producing the result.
    warnings: List[str]
    # Fatal issues; empty when the run produced usable output.
    errors: List[str]
|
|
|
|
|
|
|
|
@dataclass
class SpecializedModelConfig:
    """Static description of one specialized medical model.

    The router selects one of these per request and uses its fields to
    decide how to validate, preprocess, run and post-process the input.
    """

    # Model identifier; presumably a Hugging Face style "org/name" id. Also
    # used as the label stored in inference results and statistics.
    model_name: str
    # Inference family the router dispatches on: "classification",
    # "segmentation", "generation" or "extraction".
    model_type: str
    # Expected input kind: "ecg_signal", "dicom_image", "clinical_text" or
    # "lab_text".
    input_format: str
    # Name of the target output schema (e.g. "ECGAnalysis").
    output_schema: str
    # Whether model-specific preprocessing should run before inference.
    preprocessing_required: bool
    # Approximate GPU memory budget in MiB, or None when unspecified.
    # NOTE(review): not read by the router in this module.
    gpu_memory_mb: Optional[int]
    # Per-inference time budget in seconds.
    timeout_seconds: int
    # Alternative model keys to try on failure.
    # NOTE(review): some entries (e.g. "generic_segmentation", "scibert") do not
    # appear in the router's config table — confirm how they are resolved.
    fallback_models: List[str]
|
|
|
|
|
|
|
|
class SpecializedModelRouter:
    """Route preprocessed medical data to a specialized AI model and run inference.

    For each ``ProcessingPipelineResult`` the router:

    1. selects a model configuration from the detected file type and the
       pipeline's compliance score (``_select_optimal_model``);
    2. validates that the extraction result carries the data that model needs;
    3. reshapes that data into a model-specific input dict;
    4. runs inference through ``self.model_loader.run_inference``, bounded by
       the configuration's ``timeout_seconds``;
    5. tags the output with its schema name and attaches a heuristic
       confidence score.

    Any exception in steps 1-5 triggers a generic Bio_ClinicalBERT fallback;
    if that also fails an ``"error"`` result is returned. Running counters are
    kept in ``self.inference_stats`` (see ``get_inference_statistics``).
    """

    # Window level/width presets per imaging modality. "MR", "CR" and "DX" are
    # DICOM Modality (0008,0060) codes; "MRI" and "XRAY" are kept for callers
    # that pass free-form names.
    _WINDOW_SETTINGS: Dict[str, Dict[str, float]] = {
        "CT": {"level": 40, "width": 400},
        "MRI": {"level": 0, "width": 500},
        "MR": {"level": 0, "width": 500},
        "XRAY": {"level": 0, "width": 1000},
        "CR": {"level": 0, "width": 1000},
        "DX": {"level": 0, "width": 1000},
    }
    _DEFAULT_WINDOW: Dict[str, float] = {"level": 0, "width": 500}

    # Minimum input quality enforced by the validators.
    _MIN_ECG_SAMPLING_RATE_HZ = 250
    _MIN_ECG_DURATION_S = 5.0
    _MIN_IMAGE_SIDE_PX = 64
    _MIN_TEXT_CHARS = 50

    # Input formats whose inference paths read preprocessed_input["text_data"].
    _TEXT_INPUT_FORMATS = ("clinical_text", "lab_text")

    def __init__(self, model_loader: Optional[MedicalModelLoader] = None):
        """Create a router.

        Args:
            model_loader: Loader used to run inference. A new
                ``MedicalModelLoader()`` is created when omitted (or falsy).
        """
        self.model_loader = model_loader or MedicalModelLoader()
        self.model_configs = self._initialize_model_configs()
        # NOTE(review): not read anywhere in this class; kept for external users.
        self.model_cache = {}
        self.inference_stats = {
            "total_inferences": 0,
            "successful_inferences": 0,
            "average_processing_time": 0.0,
            "model_usage_counts": {},
            "error_counts": {},
        }

        logger.info("Specialized Model Router initialized")

    def _initialize_model_configs(self) -> Dict[str, SpecializedModelConfig]:
        """Build the table of known specialized models, keyed by router key."""
        return {
            "hubert_ecg": SpecializedModelConfig(
                # The previous id (" superh transformercs/HubERT-ECG") contained
                # whitespace and was not a valid repo id.
                # NOTE(review): verify this id against the model loader's registry.
                model_name="Edoardo-BS/hubert-ecg-base",
                model_type="classification",
                input_format="ecg_signal",
                output_schema="ECGAnalysis",
                preprocessing_required=True,
                gpu_memory_mb=4096,
                timeout_seconds=30,
                fallback_models=["bio_clinicalbert"],
            ),
            "monai_unetr": SpecializedModelConfig(
                model_name="monai/UNet",
                model_type="segmentation",
                input_format="dicom_image",
                output_schema="RadiologyAnalysis",
                preprocessing_required=True,
                gpu_memory_mb=8192,
                timeout_seconds=60,
                fallback_models=["generic_segmentation"],
            ),
            "medgemma": SpecializedModelConfig(
                model_name="google/medgemma-4b",
                model_type="generation",
                input_format="clinical_text",
                output_schema="ClinicalNotesAnalysis",
                preprocessing_required=True,
                gpu_memory_mb=16384,
                timeout_seconds=45,
                fallback_models=["bio_clinicalbert", "pubmedbert"],
            ),
            "biomedical_ner": SpecializedModelConfig(
                model_name="Clinical-AI-Apollo/BiomedNLP-PubMedBERT-base-uncased-abstract",
                model_type="extraction",
                input_format="lab_text",
                output_schema="LaboratoryResults",
                preprocessing_required=False,
                gpu_memory_mb=2048,
                timeout_seconds=20,
                fallback_models=["scibert"],
            ),
            "bio_clinicalbert": SpecializedModelConfig(
                model_name="emilyalsentzer/Bio_ClinicalBERT",
                model_type="classification",
                input_format="clinical_text",
                output_schema="ClinicalNotesAnalysis",
                preprocessing_required=False,
                gpu_memory_mb=1024,
                timeout_seconds=15,
                fallback_models=[],
            ),
            "pubmedbert": SpecializedModelConfig(
                model_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
                model_type="classification",
                input_format="clinical_text",
                output_schema="ClinicalNotesAnalysis",
                preprocessing_required=False,
                gpu_memory_mb=1024,
                timeout_seconds=15,
                fallback_models=[],
            ),
        }

    async def route_and_infer(self, pipeline_result: ProcessingPipelineResult) -> ModelInferenceResult:
        """
        Route structured data to the appropriate specialized model and run inference.

        Args:
            pipeline_result: Result from the preprocessing pipeline.

        Returns:
            ModelInferenceResult with model output and confidence. Input
            validation failures return an error result without inference; any
            other exception triggers the Bio_ClinicalBERT fallback, and if that
            fails too an ``"error"`` result with ``confidence_score=0.0``.
        """
        # Local import: the module only imports ``dataclass`` from dataclasses.
        from dataclasses import asdict

        start_time = time.time()

        try:
            model_config = self._select_optimal_model(pipeline_result)

            input_validation = self._validate_input_format(pipeline_result, model_config)
            if not input_validation["is_valid"]:
                logger.warning(f"Input validation failed: {input_validation['errors']}")
                error_result = self._create_error_result(model_config.model_name, input_validation["errors"])
                # Report real elapsed time and count the rejected attempt in the stats.
                error_result.processing_time = time.time() - start_time
                self._update_inference_stats(model_config.model_name, False, error_result.processing_time)
                return error_result

            preprocessed_input = await self._preprocess_for_model(pipeline_result, model_config)
            inference_result = await self._perform_model_inference(preprocessed_input, model_config)
            final_output = self._postprocess_model_output(inference_result, model_config)
            confidence_score = self._calculate_model_confidence(
                pipeline_result, model_config, final_output
            )

            processing_time = time.time() - start_time
            self._update_inference_stats(model_config.model_name, True, processing_time)

            return ModelInferenceResult(
                model_name=model_config.model_name,
                input_data=preprocessed_input,
                output_data=final_output,
                confidence_score=confidence_score,
                processing_time=processing_time,
                model_metadata={
                    # asdict copies the config so callers cannot mutate the shared
                    # router configuration through the result (``__dict__`` aliased it).
                    "model_config": asdict(model_config),
                    "input_validation": input_validation,
                    "pipeline_confidence": pipeline_result.validation_result.compliance_score,
                },
                warnings=[],
                errors=[],
            )

        except Exception as e:
            logger.exception(f"Model routing/inference error: {str(e)}")

            fallback_result = await self._try_fallback_model(pipeline_result)
            if fallback_result is not None:
                # Report the wall time of the whole attempt (primary + fallback) and
                # record the primary failure so health statistics reflect it.
                fallback_result.processing_time = time.time() - start_time
                self._update_inference_stats(
                    fallback_result.model_name, False, fallback_result.processing_time
                )
                return fallback_result

            elapsed = time.time() - start_time
            error_result = ModelInferenceResult(
                model_name="error",
                input_data={},
                output_data={"error": str(e)},
                confidence_score=0.0,
                processing_time=elapsed,
                model_metadata={"error": str(e)},
                warnings=[],
                errors=[str(e)],
            )
            self._update_inference_stats("error", False, elapsed)
            return error_result

    def _select_optimal_model(self, pipeline_result: ProcessingPipelineResult) -> SpecializedModelConfig:
        """Pick a model config from the detected file type and pipeline compliance score.

        File-type matching is case-insensitive (consistent with
        ``_extract_medical_domain``). Specialized models are used only above
        a per-type compliance threshold; everything else goes to
        Bio_ClinicalBERT.
        """
        file_type = str(pipeline_result.file_detection.file_type.value).lower()
        confidence = pipeline_result.validation_result.compliance_score

        if "ecg" in file_type:
            doc_type = "ecg"
        elif "radiology" in file_type:
            doc_type = "radiology"
        elif "laboratory" in file_type:
            doc_type = "laboratory"
        elif "clinical" in file_type:
            doc_type = "clinical"
        else:
            doc_type = "unknown"

        if doc_type == "ecg" and confidence > 0.8:
            return self.model_configs["hubert_ecg"]
        if doc_type == "radiology" and confidence > 0.7:
            return self.model_configs["monai_unetr"]
        if doc_type == "clinical" and confidence > 0.6:
            return self.model_configs["medgemma"]
        if doc_type == "laboratory":
            return self.model_configs["biomedical_ner"]
        # Low-confidence or unrecognized documents use the general clinical model.
        return self.model_configs["bio_clinicalbert"]

    def _validate_input_format(self, pipeline_result: ProcessingPipelineResult,
                               model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Run the validator matching ``model_config.input_format``.

        Returns:
            ``{"is_valid", "errors", "warnings", "input_checks"}``. Any failed
            check (or an exception raised by a validator) marks the input
            invalid and appends a message to ``errors``. Unknown input formats
            run no checks and are considered valid.
        """
        validation_result = {
            "is_valid": True,
            "errors": [],
            "warnings": [],
            "input_checks": {},
        }

        try:
            if model_config.input_format == "ecg_signal":
                validation_result["input_checks"] = self._validate_ecg_input(pipeline_result)
            elif model_config.input_format == "dicom_image":
                validation_result["input_checks"] = self._validate_dicom_input(pipeline_result)
            elif model_config.input_format in self._TEXT_INPUT_FORMATS:
                validation_result["input_checks"] = self._validate_text_input(pipeline_result)

            for check_name, check_result in validation_result["input_checks"].items():
                if not check_result["passed"]:
                    validation_result["is_valid"] = False
                    validation_result["errors"].append(f"{check_name}: {check_result['error']}")

        except Exception as e:
            validation_result["is_valid"] = False
            validation_result["errors"].append(f"Validation error: {str(e)}")

        return validation_result

    @staticmethod
    def _has_content(value: Any) -> bool:
        """Return True when *value* holds data, without relying on array truthiness.

        ``bool(ndarray)`` raises for arrays with more than one element, so
        objects exposing an integer ``size`` (e.g. NumPy arrays) are checked
        by element count; other containers fall back to ``len`` and then
        ``bool``.
        """
        if value is None:
            return False
        size = getattr(value, "size", None)
        if isinstance(size, (int, np.integer)):
            return size > 0
        try:
            return len(value) > 0
        except TypeError:
            return bool(value)

    def _validate_ecg_input(self, pipeline_result: ProcessingPipelineResult) -> Dict[str, Any]:
        """Check that the extraction result carries a usable ECG signal.

        Returns:
            Dict of named checks, each ``{"passed": bool, "error": str | None}``.
            Sampling-rate and duration checks are only added when the
            extraction result exposes those attributes.
        """
        extraction_result = pipeline_result.extraction_result
        checks: Dict[str, Dict[str, Any]] = {}

        if not hasattr(extraction_result, 'signal_data'):
            checks["has_signal_data"] = {
                "passed": False,
                "error": "Extraction result does not contain ECG signal data",
            }
            return checks

        has_signal = self._has_content(extraction_result.signal_data)
        checks["has_signal_data"] = {
            "passed": has_signal,
            "error": None if has_signal else "No ECG signal data found",
        }

        if hasattr(extraction_result, 'sampling_rate'):
            sampling_rate = extraction_result.sampling_rate
            if sampling_rate is None:
                checks["adequate_sampling_rate"] = {
                    "passed": False,
                    "error": "ECG sampling rate is unknown",
                }
            else:
                rate_ok = bool(sampling_rate >= self._MIN_ECG_SAMPLING_RATE_HZ)
                checks["adequate_sampling_rate"] = {
                    "passed": rate_ok,
                    "error": None if rate_ok else f"Sampling rate {sampling_rate} Hz too low for ECG analysis",
                }

        if hasattr(extraction_result, 'duration'):
            duration = extraction_result.duration
            if duration is None:
                checks["adequate_duration"] = {
                    "passed": False,
                    "error": "ECG signal duration is unknown",
                }
            else:
                duration_ok = bool(duration >= self._MIN_ECG_DURATION_S)
                checks["adequate_duration"] = {
                    "passed": duration_ok,
                    "error": None if duration_ok else f"Signal duration {duration:.1f}s too short for analysis",
                }

        return checks

    def _validate_dicom_input(self, pipeline_result: ProcessingPipelineResult) -> Dict[str, Any]:
        """Check that the extraction result carries non-empty image data of sufficient size.

        Returns:
            Dict of named checks, each ``{"passed": bool, "error": str | None}``.
        """
        extraction_result = pipeline_result.extraction_result

        if not hasattr(extraction_result, 'image_data'):
            return {
                "has_image_data": {
                    "passed": False,
                    "error": "Extraction result does not contain DICOM image data",
                }
            }

        image_data = extraction_result.image_data
        has_pixels = self._has_content(image_data)
        checks: Dict[str, Dict[str, Any]] = {
            "has_image_data": {
                "passed": has_pixels,
                "error": None if has_pixels else "No image data found",
            }
        }

        if has_pixels:
            shape = tuple(np.shape(image_data))
            # NOTE(review): min() runs over every axis, so a small channel or
            # slice axis (e.g. an RGB (H, W, 3) array) fails this check —
            # confirm the expected array layout.
            resolution_ok = bool(shape) and min(shape) >= self._MIN_IMAGE_SIDE_PX
            checks["adequate_resolution"] = {
                "passed": resolution_ok,
                "error": None if resolution_ok else f"Image resolution too low: {shape}",
            }

        return checks

    def _validate_text_input(self, pipeline_result: ProcessingPipelineResult) -> Dict[str, Any]:
        """Check that the extraction result has more than 50 non-whitespace-trimmed characters of text.

        Returns:
            ``{"has_text_content": {"passed": bool, "error": str | None}}``.
        """
        extraction_result = pipeline_result.extraction_result

        if not hasattr(extraction_result, 'raw_text'):
            return {
                "has_text_content": {
                    "passed": False,
                    "error": "No text content found in extraction result",
                }
            }

        text = extraction_result.raw_text
        enough_text = bool(text and len(text.strip()) > self._MIN_TEXT_CHARS)
        return {
            "has_text_content": {
                "passed": enough_text,
                "error": None if enough_text else "Insufficient text content for analysis",
            }
        }

    async def _preprocess_for_model(self, pipeline_result: ProcessingPipelineResult,
                                    model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Build the model-specific input dict for *model_config*.

        Models without ``preprocessing_required`` get the pipeline's raw data,
        plus a ``text_data`` entry for text formats (the text inference paths
        read ``preprocessed_input["text_data"]``; without it extraction raised
        ``KeyError`` and classification ran on an empty string).

        Raises:
            Exception: preprocessing errors are logged and re-raised so the
                caller's fallback handling runs instead of inferring on an
                incomplete input.
        """
        if not model_config.preprocessing_required:
            passthrough = {
                "raw_data": pipeline_result.structured_data,
                "metadata": pipeline_result.pipeline_metadata,
                "validation_result": pipeline_result.validation_result,
            }
            if model_config.input_format in self._TEXT_INPUT_FORMATS:
                passthrough["text_data"] = self._build_text_input(pipeline_result)
            return passthrough

        try:
            if model_config.input_format == "ecg_signal":
                return await self._preprocess_ecg_signal(pipeline_result, model_config)
            if model_config.input_format == "dicom_image":
                return await self._preprocess_dicom_image(pipeline_result, model_config)
            if model_config.input_format in self._TEXT_INPUT_FORMATS:
                return await self._preprocess_clinical_text(pipeline_result, model_config)
            return {"raw_data": pipeline_result.structured_data}

        except Exception as e:
            logger.error(f"Preprocessing error: {str(e)}")
            raise

    async def _preprocess_ecg_signal(self, pipeline_result: ProcessingPipelineResult,
                                     model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Package the extracted ECG signal and its acquisition parameters for the ECG model."""
        extraction_result = pipeline_result.extraction_result

        ecg_input = {
            "signals": extraction_result.signal_data,
            "sampling_rate": extraction_result.sampling_rate,
            "duration": extraction_result.duration,
            # Lead names are not validated upstream; tolerate their absence.
            "leads": getattr(extraction_result, "lead_names", None),
        }

        preprocessing_metadata = {
            "original_sampling_rate": extraction_result.sampling_rate,
            "resampled": False,
            # NOTE(review): no filtering happens in this method; confirm the flag
            # reflects upstream processing.
            "filtered": True,
            "segment_length_seconds": min(10.0, extraction_result.duration),
        }

        return {
            "ecg_data": ecg_input,
            "preprocessing_metadata": preprocessing_metadata,
            "model_ready": True,
        }

    async def _preprocess_dicom_image(self, pipeline_result: ProcessingPipelineResult,
                                      model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Package the extracted image array and its acquisition attributes for segmentation."""
        extraction_result = pipeline_result.extraction_result

        image_input = {
            "image_array": extraction_result.image_data,
            "spacing": extraction_result.pixel_spacing,
            "modality": extraction_result.modality,
            "body_part": extraction_result.body_part,
        }

        # NOTE(review): normalization/channel flags are reported but not applied
        # here; confirm they reflect upstream processing.
        preprocessing_metadata = {
            "window_level": self._get_window_settings(extraction_result.modality),
            "normalized": True,
            "resized": False,
            "channels_added": True,
        }

        return {
            "dicom_data": image_input,
            "preprocessing_metadata": preprocessing_metadata,
            "model_ready": True,
        }

    def _build_text_input(self, pipeline_result: ProcessingPipelineResult) -> Dict[str, Any]:
        """Assemble the ``text_data`` dict read by the text inference paths.

        Text is taken from ``extraction_result.raw_text`` when present, else
        the stringified extraction ``structured_data``, else the pipeline's
        stringified ``structured_data``.
        """
        extraction_result = pipeline_result.extraction_result
        if hasattr(extraction_result, 'raw_text'):
            text_content = extraction_result.raw_text
        elif hasattr(extraction_result, 'structured_data'):
            text_content = str(extraction_result.structured_data)
        else:
            text_content = str(pipeline_result.structured_data)

        return {
            "raw_text": text_content,
            "document_type": pipeline_result.file_detection.file_type.value,
            "deidentified": pipeline_result.deidentification_result is not None,
        }

    async def _preprocess_clinical_text(self, pipeline_result: ProcessingPipelineResult,
                                        model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Package clinical/lab text plus descriptive metadata for the text models."""
        preprocessing_metadata = {
            "tokenized": False,
            "max_length": 512,
            "language": "en",
            "medical_domain": self._extract_medical_domain(pipeline_result),
        }

        return {
            "text_data": self._build_text_input(pipeline_result),
            "preprocessing_metadata": preprocessing_metadata,
            "model_ready": True,
        }

    def _get_window_settings(self, modality: str) -> Dict[str, float]:
        """Return a copy of the window level/width preset for *modality* (default level 0, width 500)."""
        return dict(self._WINDOW_SETTINGS.get(modality, self._DEFAULT_WINDOW))

    def _extract_medical_domain(self, pipeline_result: ProcessingPipelineResult) -> str:
        """Map the detected file type to a coarse medical domain label (case-insensitive)."""
        file_type = str(pipeline_result.file_detection.file_type.value).lower()

        for needle, domain in (
            ("ecg", "cardiology"),
            ("radiology", "radiology"),
            ("laboratory", "laboratory"),
            ("clinical", "clinical"),
        ):
            if needle in file_type:
                return domain
        return "general"

    async def _perform_model_inference(self, preprocessed_input: Dict[str, Any],
                                       model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Dispatch to the inference routine for ``model_config.model_type``.

        The call is bounded by ``model_config.timeout_seconds`` via
        ``asyncio.wait_for``. Cancellation only takes effect at await points,
        so a loader that blocks the event loop cannot be interrupted.

        Raises:
            ValueError: unsupported model type.
            asyncio.TimeoutError: inference exceeded the configured timeout.
        """
        handlers = {
            "classification": self._perform_classification_inference,
            "segmentation": self._perform_segmentation_inference,
            "generation": self._perform_generation_inference,
            "extraction": self._perform_extraction_inference,
        }
        handler = handlers.get(model_config.model_type)

        try:
            if handler is None:
                raise ValueError(f"Unsupported model type: {model_config.model_type}")
            return await asyncio.wait_for(
                handler(preprocessed_input, model_config),
                timeout=model_config.timeout_seconds,
            )
        except asyncio.TimeoutError:
            logger.error(
                f"Model inference timed out after {model_config.timeout_seconds}s "
                f"({model_config.model_name})"
            )
            raise
        except Exception as e:
            logger.error(f"Model inference error: {str(e)}")
            raise

    async def _perform_classification_inference(self, preprocessed_input: Dict[str, Any],
                                                model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Run text classification (ECG inputs are summarized into a short text prompt)."""
        # NOTE(review): always uses the "bio_clinicalbert" loader key regardless of
        # model_config (HuBERT-ECG and PubMedBERT configs land here too) — confirm.
        model_key = "bio_clinicalbert"

        try:
            if "ecg_data" in preprocessed_input:
                ecg_data = preprocessed_input["ecg_data"]
                # NOTE(review): len(signals) is reported as the lead count; this
                # assumes the first axis indexes leads.
                text_input = f"ECG Analysis: {len(ecg_data['signals'])} leads, {ecg_data['duration']:.1f}s duration"
            else:
                text_input = preprocessed_input.get("text_data", {}).get("raw_text", "")

            result = await self.model_loader.run_inference(
                model_key,
                text_input,
                {"max_new_tokens": 200, "task": "classification"},
            )

            return {
                "model_output": result,
                "classification_type": "medical_document_classification",
                "confidence": 0.8,
            }

        except Exception as e:
            logger.error(f"Classification inference error: {str(e)}")
            raise

    async def _perform_segmentation_inference(self, preprocessed_input: Dict[str, Any],
                                              model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Placeholder segmentation: returns a RANDOM mask, not a model prediction.

        No segmentation model is invoked. The output carries
        ``"placeholder": True`` so downstream consumers can tell it apart from
        real results.
        """
        try:
            dicom_data = preprocessed_input["dicom_data"]
            image_array = dicom_data["image_array"]
            modality = dicom_data["modality"]

            logger.warning("Segmentation inference is a placeholder; returning a random mask")
            segmentation_result = {
                "segmentation_mask": np.random.rand(*image_array.shape) > 0.7,
                "organ_detected": f"{modality.lower()}_tissue",
                # NOTE(review): ignores pixel spacing and the mask itself.
                "volume_estimate_ml": np.prod(image_array.shape) * 0.001,
                "confidence": 0.75,
                "placeholder": True,
            }

            return {
                "model_output": segmentation_result,
                "segmentation_type": f"{modality}_segmentation",
            }

        except Exception as e:
            logger.error(f"Segmentation inference error: {str(e)}")
            raise

    async def _perform_generation_inference(self, preprocessed_input: Dict[str, Any],
                                            model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Ask the model for a structured summary of the clinical text."""
        try:
            raw_text = preprocessed_input["text_data"]["raw_text"]

            # NOTE(review): model_config (MedGemma) is not used; the prompt goes to
            # the "bio_clinicalbert" loader key — confirm the intended model.
            model_key = "bio_clinicalbert"
            prompt = f"Analyze the following medical text and provide a structured summary:\n\n{raw_text}"

            result = await self.model_loader.run_inference(
                model_key,
                prompt,
                {"max_new_tokens": 300, "task": "generation"},
            )

            return {
                "model_output": result,
                "generation_type": "clinical_summary",
                "original_length": len(raw_text),
                "generated_length": len(str(result)),
            }

        except Exception as e:
            logger.error(f"Generation inference error: {str(e)}")
            raise

    async def _perform_extraction_inference(self, preprocessed_input: Dict[str, Any],
                                            model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Run named-entity extraction over the input text."""
        try:
            raw_text = preprocessed_input["text_data"]["raw_text"]
            model_key = "biomedical_ner_all"

            result = await self.model_loader.run_inference(
                model_key,
                raw_text,
                {"task": "ner", "aggregation_strategy": "simple"},
            )

            return {
                "model_output": result,
                "extraction_type": "medical_entities",
                "entities_found": len(result) if isinstance(result, list) else 0,
            }

        except Exception as e:
            logger.error(f"Extraction inference error: {str(e)}")
            raise

    def _postprocess_model_output(self, inference_result: Dict[str, Any],
                                  model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Tag the raw model output with ``model_config.output_schema``.

        Unknown schemas are tagged ``"generic"``. Errors are logged and
        returned as ``{"model_output": ..., "error": ...}``.
        """
        converters = {
            "ECGAnalysis": self._convert_to_ecg_schema,
            "RadiologyAnalysis": self._convert_to_radiology_schema,
            "LaboratoryResults": self._convert_to_laboratory_schema,
            "ClinicalNotesAnalysis": self._convert_to_clinical_notes_schema,
        }

        try:
            model_output = inference_result["model_output"]
            converter = converters.get(model_config.output_schema)
            if converter is None:
                return {"model_output": model_output, "schema": "generic"}
            return converter(model_output, inference_result)

        except Exception as e:
            logger.error(f"Post-processing error: {str(e)}")
            return {"model_output": inference_result.get("model_output", {}), "error": str(e)}

    @staticmethod
    def _tag_schema(model_output: Any, schema: str) -> Dict[str, Any]:
        """Wrap *model_output* with its schema name (no field mapping is performed)."""
        return {
            "model_output": model_output,
            "schema": schema,
            "postprocessed": True,
        }

    def _convert_to_ecg_schema(self, model_output: Any, inference_result: Dict[str, Any]) -> Dict[str, Any]:
        """Tag model output as ECGAnalysis."""
        return self._tag_schema(model_output, "ECGAnalysis")

    def _convert_to_radiology_schema(self, model_output: Any, inference_result: Dict[str, Any]) -> Dict[str, Any]:
        """Tag model output as RadiologyAnalysis."""
        return self._tag_schema(model_output, "RadiologyAnalysis")

    def _convert_to_laboratory_schema(self, model_output: Any, inference_result: Dict[str, Any]) -> Dict[str, Any]:
        """Tag model output as LaboratoryResults."""
        return self._tag_schema(model_output, "LaboratoryResults")

    def _convert_to_clinical_notes_schema(self, model_output: Any, inference_result: Dict[str, Any]) -> Dict[str, Any]:
        """Tag model output as ClinicalNotesAnalysis."""
        return self._tag_schema(model_output, "ClinicalNotesAnalysis")

    def _calculate_model_confidence(self, pipeline_result: ProcessingPipelineResult,
                                    model_config: SpecializedModelConfig,
                                    model_output: Dict[str, Any]) -> float:
        """Blend pipeline compliance (40%) with a per-model-type prior (60%).

        The prior is cut to 30% when post-processing reported an error. The
        result is clamped to [0, 1]; any exception yields 0.5.
        """
        type_priors = {
            "classification": 0.85,
            "segmentation": 0.80,
            "generation": 0.75,
            "extraction": 0.90,
        }

        try:
            pipeline_confidence = pipeline_result.validation_result.compliance_score
            model_confidence = type_priors.get(model_config.model_type, 0.8)

            if "error" in model_output:
                model_confidence *= 0.3

            overall_confidence = 0.4 * pipeline_confidence + 0.6 * model_confidence
            return min(1.0, max(0.0, overall_confidence))

        except Exception as e:
            logger.error(f"Confidence calculation error: {str(e)}")
            return 0.5

    async def _try_fallback_model(self, pipeline_result: ProcessingPipelineResult) -> Optional[ModelInferenceResult]:
        """Run Bio_ClinicalBERT on the first 1000 characters of the structured data.

        Bounded by the Bio_ClinicalBERT config's ``timeout_seconds``.

        Returns:
            A low-confidence (0.4) result with a warning, or None when the
            fallback also fails.
        """
        start_time = time.time()
        try:
            fallback_config = self.model_configs["bio_clinicalbert"]
            fallback_text = str(pipeline_result.structured_data)[:1000]

            result = await asyncio.wait_for(
                self.model_loader.run_inference(
                    "bio_clinicalbert",
                    fallback_text,
                    {"max_new_tokens": 150, "task": "general"},
                ),
                timeout=fallback_config.timeout_seconds,
            )

            return ModelInferenceResult(
                model_name="fallback_bio_clinicalbert",
                input_data={"fallback_text": fallback_text},
                output_data={"model_output": result, "fallback_used": True},
                confidence_score=0.4,
                processing_time=time.time() - start_time,
                model_metadata={"fallback_reason": "primary_model_failed"},
                warnings=["Used fallback model due to primary model failure"],
                errors=[],
            )

        except Exception as e:
            logger.error(f"Fallback model error: {str(e)}")
            return None

    def _create_error_result(self, model_name: str, errors: List[str]) -> ModelInferenceResult:
        """Build a zero-confidence result for an input that failed validation."""
        return ModelInferenceResult(
            model_name=model_name,
            input_data={},
            output_data={"error": "Input validation failed"},
            confidence_score=0.0,
            processing_time=0.0,
            model_metadata={"validation_errors": errors},
            warnings=[],
            errors=errors,
        )

    def _update_inference_stats(self, model_name: str, success: bool, processing_time: float):
        """Record one attempt: totals, running mean time, per-model usage and failure count."""
        stats = self.inference_stats
        stats["total_inferences"] += 1
        if success:
            stats["successful_inferences"] += 1

        # Incremental mean over all recorded attempts, successful or not.
        previous_total_time = stats["average_processing_time"] * (stats["total_inferences"] - 1)
        stats["average_processing_time"] = (previous_total_time + processing_time) / stats["total_inferences"]

        usage = stats["model_usage_counts"]
        usage[model_name] = usage.get(model_name, 0) + 1

        if not success:
            error_counts = stats["error_counts"]
            error_counts["inference_failure"] = error_counts.get("inference_failure", 0) + 1

    def get_inference_statistics(self) -> Dict[str, Any]:
        """Return a snapshot of the inference statistics.

        Breakdown dicts are copies, so callers cannot mutate the router's
        counters. ``router_health`` is "healthy" before any inference and
        while more than 80% of attempts succeed, otherwise "degraded".
        """
        stats = self.inference_stats
        total = stats["total_inferences"]
        successes = stats["successful_inferences"]
        healthy = total == 0 or successes > total * 0.8

        return {
            "total_inferences": total,
            "success_rate": successes / max(total, 1),
            "average_processing_time": stats["average_processing_time"],
            "model_usage_breakdown": dict(stats["model_usage_counts"]),
            "error_breakdown": dict(stats["error_counts"]),
            "router_health": "healthy" if healthy else "degraded",
        }
|
|
|
|
|
|
|
|
|
|
|
# Public API of this module: the router and the two dataclasses it exchanges.
__all__ = [
    "SpecializedModelRouter",
    "ModelInferenceResult",
    "SpecializedModelConfig",
]