Spaces:
Sleeping
Sleeping
| """ | |
| Medical Data Schemas - Phase 1 Implementation | |
| Canonical JSON schemas for medical data modalities with validation rules and confidence scoring. | |
| This module defines the structured data contracts that ensure proper input/output | |
| formats across the medical AI pipeline, replacing unstructured PDF processing. | |
| Author: MiniMax Agent | |
| Date: 2025-10-29 | |
| Version: 1.0.0 | |
| """ | |
| from typing import List, Optional, Dict, Any, Union, Literal | |
| from pydantic import BaseModel, Field, validator, confloat | |
| from datetime import datetime | |
| import uuid | |
| import numpy as np | |
| # ================================ | |
| # BASE TYPES AND ENUMS | |
| # ================================ | |
class ConfidenceScore(BaseModel):
    """Composite confidence scoring for medical data extraction and analysis.

    Holds three independent scores, each constrained to [0.0, 1.0], and
    combines them into a single weighted value that drives the human-review
    decision.
    """

    extraction_confidence: confloat(ge=0.0, le=1.0) = Field(
        description="Confidence in data extraction from source document (0.0-1.0)"
    )
    model_confidence: confloat(ge=0.0, le=1.0) = Field(
        description="Confidence in AI model analysis/output (0.0-1.0)"
    )
    data_quality: confloat(ge=0.0, le=1.0) = Field(
        description="Quality of source data (completeness, clarity, resolution) (0.0-1.0)"
    )

    # Exposed as a property because requires_review() below reads it as a
    # plain attribute; as an ordinary method, that comparison would be between
    # a bound method and a float and raise TypeError.
    # NOTE(review): the decorator looks to have been lost from this copy of the
    # file - confirm no caller invokes overall_confidence() with parentheses.
    @property
    def overall_confidence(self) -> float:
        """Return the weighted composite: 0.5 * extraction + 0.3 * model + 0.2 * quality."""
        return (
            0.5 * self.extraction_confidence
            + 0.3 * self.model_confidence
            + 0.2 * self.data_quality
        )

    def requires_review(self) -> bool:
        """Return True when the composite confidence is below the 0.85 review threshold."""
        return self.overall_confidence < 0.85  # Below 85% requires review
class MedicalDocumentMetadata(BaseModel):
    """Metadata shared by every medical document schema in this module.

    Only ``source_type`` and ``data_completeness`` have no default; every
    other field is either optional or generated when the model is created.
    """

    # Random UUID4 string generated per instance when no id is supplied.
    document_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    source_type: Literal["ECG", "radiology", "laboratory", "clinical_notes", "unknown"]
    document_date: Optional[datetime] = None
    patient_id_hash: Optional[str] = None  # Anonymized identifier
    facility: Optional[str] = None
    provider: Optional[str] = None
    # datetime.now is called without a tz argument, so the stamp is naive.
    extraction_timestamp: datetime = Field(default_factory=datetime.now)
    data_completeness: confloat(ge=0.0, le=1.0) = Field(
        description="Overall completeness of extracted data (0.0-1.0)"
    )
| # ================================ | |
| # ECG SCHEMA (PHASE 1 PRIORITY) | |
| # ================================ | |
class ECGSignalData(BaseModel):
    """ECG signal array data for rhythm analysis.

    NOTE(review): nothing here cross-checks ``lead_names``, ``num_samples`` or
    ``duration_seconds`` against the contents of ``signal_arrays``; only the
    arrays themselves are validated.
    """

    lead_names: List[str] = Field(
        description="List of ECG lead names (I, II, III, aVR, aVL, aVF, V1-V6)"
    )
    sampling_rate_hz: int = Field(ge=100, le=1000, description="Sampling rate in Hz")
    signal_arrays: Dict[str, List[float]] = Field(
        description="Dictionary mapping lead names to signal arrays (mV values)"
    )
    duration_seconds: float = Field(gt=0, description="Recording duration in seconds")
    num_samples: int = Field(gt=0, description="Number of samples per lead")

    # Registered as a field validator: without the decorator this would be an
    # ordinary method that is never invoked, so none of the checks would run.
    @validator("signal_arrays")
    def validate_signal_arrays(cls, v):
        """Ensure all lead arrays have consistent length and valid values.

        Args:
            v: Mapping of lead name to its list of samples.

        Returns:
            The mapping unchanged, once every check has passed.

        Raises:
            ValueError: If the mapping is empty, a lead is empty or not a
                list, any sample lies outside -5..+5 mV, or the leads differ
                in length.
        """
        if not v:
            raise ValueError("Signal arrays cannot be empty")

        expected_length = None
        for lead_name, signal in v.items():
            if not isinstance(signal, list) or not signal:
                raise ValueError(f"Lead {lead_name} must be non-empty list")

            # Check for valid mV range (-5 to +5 mV)
            if any(abs(val) > 5.0 for val in signal):
                raise ValueError(
                    f"Lead {lead_name} contains values outside valid ECG range (-5 to +5 mV)"
                )

            # The first lead fixes the length every later lead must match.
            if expected_length is None:
                expected_length = len(signal)
            elif len(signal) != expected_length:
                raise ValueError("All leads must have same array length")

        return v
class ECGIntervals(BaseModel):
    """ECG timing intervals for arrhythmia detection.

    All intervals are in milliseconds, optional, and range-checked by their
    field constraints.
    """

    pr_ms: Optional[float] = Field(
        default=None, ge=0, le=400, description="PR interval in milliseconds"
    )
    qrs_ms: Optional[float] = Field(
        default=None, ge=0, le=200, description="QRS duration in milliseconds"
    )
    qt_ms: Optional[float] = Field(
        default=None, ge=200, le=600, description="QT interval in milliseconds"
    )
    qtc_ms: Optional[float] = Field(
        default=None, ge=200, le=600, description="QTc interval in milliseconds"
    )
    rr_ms: Optional[float] = Field(
        default=None, ge=300, le=2000, description="RR interval in milliseconds"
    )

    def is_bradycardia(self) -> Optional[bool]:
        """Return True when the RR interval exceeds 1000 ms, None if RR is unavailable."""
        if not self.rr_ms:
            return None
        return self.rr_ms > 1000  # HR < 60 bpm

    def is_tachycardia(self) -> Optional[bool]:
        """Return True when the RR interval is under 600 ms, None if RR is unavailable."""
        if not self.rr_ms:
            return None
        return self.rr_ms < 600  # HR > 100 bpm
class ECGRhythmClassification(BaseModel):
    """ECG rhythm classification results.

    Every field is optional or defaults to an empty list, so the model can be
    created with no arguments.
    """

    primary_rhythm: Optional[str] = Field(
        default=None, description="Primary rhythm classification"
    )
    rhythm_confidence: Optional[confloat(ge=0.0, le=1.0)] = None
    arrhythmia_types: List[str] = Field(
        default_factory=list, description="Detected arrhythmia types"
    )
    heart_rate_bpm: Optional[int] = Field(
        default=None, ge=20, le=300, description="Heart rate in beats per minute"
    )
    heart_rate_regularity: Optional[Literal["regular", "irregular", "variable"]] = None
class ECGArrhythmiaProbabilities(BaseModel):
    """Probabilities for specific arrhythmia conditions.

    Each probability is optional and individually constrained to [0.0, 1.0];
    no validator in this class relates the values to one another.
    """

    normal_rhythm: Optional[confloat(ge=0.0, le=1.0)] = Field(
        default=None, description="Normal sinus rhythm probability"
    )
    atrial_fibrillation: Optional[confloat(ge=0.0, le=1.0)] = Field(
        default=None, description="Atrial fibrillation probability"
    )
    atrial_flutter: Optional[confloat(ge=0.0, le=1.0)] = Field(
        default=None, description="Atrial flutter probability"
    )
    ventricular_tachycardia: Optional[confloat(ge=0.0, le=1.0)] = Field(
        default=None, description="Ventricular tachycardia probability"
    )
    heart_block: Optional[confloat(ge=0.0, le=1.0)] = Field(
        default=None, description="Heart block probability"
    )
    premature_beats: Optional[confloat(ge=0.0, le=1.0)] = Field(
        default=None, description="Premature beat probability"
    )
class ECGDerivedFeatures(BaseModel):
    """ECG-derived clinical features for downstream analysis.

    Per-lead measurements are dictionaries keyed by lead name; flag-style
    findings are lists of strings that default to empty.
    """

    st_elevation_mm: Optional[Dict[str, float]] = Field(
        default=None, description="ST elevation by lead (mm)"
    )
    st_depression_mm: Optional[Dict[str, float]] = Field(
        default=None, description="ST depression by lead (mm)"
    )
    t_wave_abnormalities: List[str] = Field(
        default_factory=list, description="T-wave abnormality flags"
    )
    q_wave_indicators: List[str] = Field(
        default_factory=list, description="Pathological Q-wave indicators"
    )
    voltage_criteria: Optional[Dict[str, Any]] = Field(
        default=None, description="Voltage criteria for hypertrophy"
    )
    axis_deviation: Optional[Literal["normal", "left", "right", "extreme"]] = None
class ECGAnalysis(BaseModel):
    """Complete ECG analysis results with structured output.

    Aggregates the metadata, raw signal, measured intervals, rhythm
    classification, arrhythmia probabilities, derived features and confidence
    for one ECG document.
    """

    # NOTE(review): source_type is handed to Field as an extra keyword; it
    # presumably only annotates the field and does not set or enforce
    # metadata.source_type - confirm against the pydantic version in use.
    metadata: MedicalDocumentMetadata = Field(source_type="ECG")
    signal_data: ECGSignalData
    intervals: ECGIntervals
    rhythm_classification: ECGRhythmClassification
    arrhythmia_probabilities: ECGArrhythmiaProbabilities
    derived_features: ECGDerivedFeatures
    confidence: ConfidenceScore
    clinical_summary: Optional[str] = Field(
        default=None, description="Human-readable clinical summary"
    )
    recommendations: List[str] = Field(
        default_factory=list, description="Clinical recommendations"
    )

    class Config:
        # Illustrative payload attached to the generated schema.
        schema_extra = {
            "example": {
                "metadata": {
                    "document_id": "ecg-12345",
                    "source_type": "ECG",
                    "document_date": "2025-10-29T10:38:55Z",
                    "facility": "General Hospital",
                    "extraction_timestamp": "2025-10-29T10:38:55Z",
                },
                "signal_data": {
                    "lead_names": [
                        "I", "II", "III", "aVR", "aVL", "aVF",
                        "V1", "V2", "V3", "V4", "V5", "V6",
                    ],
                    "sampling_rate_hz": 500,
                    "duration_seconds": 10.0,
                    "num_samples": 5000,
                },
                "intervals": {
                    "pr_ms": 160.0,
                    "qrs_ms": 88.0,
                    "qt_ms": 380.0,
                    "qtc_ms": 420.0,
                },
                "confidence": {
                    "extraction_confidence": 0.92,
                    "model_confidence": 0.89,
                    "data_quality": 0.95,
                    "overall_confidence": 0.917,
                },
            }
        }
| # ================================ | |
| # RADIOLOGY SCHEMA | |
| # ================================ | |
class RadiologyImageReference(BaseModel):
    """Reference to a radiology image together with its acquisition metadata.

    ``image_id``, ``modality`` and ``body_part`` are required; the remaining
    fields are optional.
    """

    image_id: str = Field(description="Unique image identifier")
    modality: Literal[
        "CT", "MRI", "XRAY", "ULTRASOUND", "MAMMOGRAPHY", "NUCLEAR"
    ] = Field(description="Imaging modality")
    body_part: str = Field(description="Anatomical region imaged")
    view_orientation: Optional[str] = Field(
        default=None, description="Image orientation/plane"
    )
    slice_thickness_mm: Optional[float] = Field(
        default=None, description="Slice thickness in mm"
    )
    resolution: Optional[Dict[str, int]] = Field(
        default=None, description="Image resolution (width, height)"
    )
class RadiologySegmentation(BaseModel):
    """Medical image segmentation results for a single organ or structure.

    Volume and surface area must be non-negative when given; intensities are
    unconstrained.
    """

    organ_name: str = Field(description="Name of segmented organ/structure")
    volume_ml: Optional[float] = Field(
        default=None, ge=0, description="Volume in milliliters"
    )
    surface_area_cm2: Optional[float] = Field(
        default=None, ge=0, description="Surface area in cm²"
    )
    mean_intensity: Optional[float] = Field(
        default=None, description="Mean pixel intensity"
    )
    max_intensity: Optional[float] = Field(
        default=None, description="Maximum pixel intensity"
    )
    # Free-form dictionaries; no lesion schema is imposed here.
    lesions: List[Dict[str, Any]] = Field(
        default_factory=list, description="Detected lesions"
    )
class RadiologyFindings(BaseModel):
    """Structured radiology findings extraction.

    The findings and impression texts are required; the finding lists default
    to empty and the remaining text fields are optional.
    """

    findings_text: str = Field(description="Raw findings text from report")
    impression_text: str = Field(description="Impression/conclusion section")
    critical_findings: List[str] = Field(
        default_factory=list, description="Urgent/critical findings"
    )
    incidental_findings: List[str] = Field(
        default_factory=list, description="Incidental findings"
    )
    comparison_prior: Optional[str] = Field(
        default=None, description="Comparison with prior studies"
    )
    technique_description: Optional[str] = Field(
        default=None, description="Imaging technique details"
    )
class RadiologyMetrics(BaseModel):
    """Quantitative metrics from imaging analysis.

    Every container defaults to empty, so the model can be created with no
    arguments.
    """

    organ_volumes: Dict[str, float] = Field(
        default_factory=dict, description="Organ volumes in ml"
    )
    lesion_measurements: List[Dict[str, float]] = Field(
        default_factory=list, description="Lesion size measurements"
    )
    enhancement_patterns: List[str] = Field(
        default_factory=list, description="Contrast enhancement patterns"
    )
    calcification_scores: Dict[str, float] = Field(
        default_factory=dict, description="Calcification severity scores"
    )
    tissue_density: Optional[Dict[str, float]] = Field(
        default=None, description="Tissue density measurements"
    )
class RadiologyAnalysis(BaseModel):
    """Complete radiology analysis results.

    Bundles image references, report findings, optional segmentations,
    quantitative metrics and confidence for one radiology document.
    """

    # NOTE(review): source_type is handed to Field as an extra keyword; it
    # presumably only annotates the field and does not set or enforce
    # metadata.source_type - confirm against the pydantic version in use.
    metadata: MedicalDocumentMetadata = Field(source_type="radiology")
    image_references: List[RadiologyImageReference]
    findings: RadiologyFindings
    segmentations: List[RadiologySegmentation] = Field(default_factory=list)
    metrics: RadiologyMetrics
    confidence: ConfidenceScore
    criticality_level: Literal["routine", "urgent", "stat"] = Field(default="routine")
    follow_up_recommendations: List[str] = Field(default_factory=list)

    class Config:
        # Illustrative payload attached to the generated schema.
        schema_extra = {
            "example": {
                "metadata": {
                    "document_id": "rad-67890",
                    "source_type": "radiology",
                    "document_date": "2025-10-29T10:38:55Z",
                    "facility": "Imaging Center",
                },
                "findings": {
                    "findings_text": "Chest CT shows bilateral pulmonary nodules...",
                    "impression_text": "Bilateral pulmonary nodules, likely benign",
                    "critical_findings": [],
                    "incidental_findings": ["Thyroid nodule", "Hepatic cyst"],
                },
                "confidence": {
                    "extraction_confidence": 0.88,
                    "model_confidence": 0.91,
                    "data_quality": 0.94,
                },
            }
        }
| # ================================ | |
| # LABORATORY SCHEMA | |
| # ================================ | |
class LabTestResult(BaseModel):
    """Individual laboratory test result.

    The value and both reference limits may each be numeric or text; only
    ``test_name`` is required.
    """

    test_name: str = Field(description="Full name of the laboratory test")
    test_code: Optional[str] = Field(
        default=None, description="Standard test code (LOINC, etc.)"
    )
    value: Optional[Union[float, str]] = Field(
        default=None, description="Test result value"
    )
    unit: Optional[str] = Field(default=None, description="Units of measurement")
    reference_range_low: Optional[Union[float, str]] = Field(
        default=None, description="Lower reference limit"
    )
    reference_range_high: Optional[Union[float, str]] = Field(
        default=None, description="Upper reference limit"
    )
    flags: List[str] = Field(
        default_factory=list, description="Abnormal value flags (H, L, HH, LL)"
    )
    test_date: Optional[datetime] = Field(
        default=None, description="Date/time test was performed"
    )

    def is_abnormal(self) -> Optional[bool]:
        """Report whether the value falls outside the reference range.

        Returns:
            True or False when the value is numeric and both limits are
            usable; None when the value is missing or non-numeric, either
            limit is missing, or a limit cannot be converted to a number.
        """
        measured = self.value
        # None is not an int/float, so this single check also covers "missing".
        if not isinstance(measured, (int, float)):
            return None

        lower = self.reference_range_low
        upper = self.reference_range_high
        if lower is None or upper is None:
            return None

        try:
            # Text limits are parsed; numeric limits are used as-is.
            lower_bound = float(lower) if isinstance(lower, str) else lower
            upper_bound = float(upper) if isinstance(upper, str) else upper
            numeric = float(measured)
            return numeric < lower_bound or numeric > upper_bound
        except (ValueError, TypeError):
            return None
class LaboratoryResults(BaseModel):
    """Complete laboratory results analysis.

    ``abnormal_count`` and ``critical_count`` are plain stored integers that
    default to 0; nothing in this class derives them from ``tests``.
    """

    # NOTE(review): source_type is handed to Field as an extra keyword; it
    # presumably only annotates the field and does not set or enforce
    # metadata.source_type - confirm against the pydantic version in use.
    metadata: MedicalDocumentMetadata = Field(source_type="laboratory")
    tests: List[LabTestResult] = Field(description="List of all test results")
    critical_values: List[str] = Field(
        default_factory=list,
        description="Critical values requiring immediate attention",
    )
    panel_name: Optional[str] = Field(
        default=None, description="Name of test panel (CMP, CBC, etc.)"
    )
    fasting_status: Optional[Literal["fasting", "non_fasting", "unknown"]] = None
    collection_date: Optional[datetime] = Field(
        default=None, description="Specimen collection date"
    )
    confidence: ConfidenceScore
    abnormal_count: int = Field(default=0, description="Number of abnormal results")
    critical_count: int = Field(default=0, description="Number of critical results")

    class Config:
        # Illustrative payload attached to the generated schema.
        schema_extra = {
            "example": {
                "metadata": {
                    "document_id": "lab-11111",
                    "source_type": "laboratory",
                    "document_date": "2025-10-29T10:38:55Z",
                },
                "tests": [
                    {
                        "test_name": "Glucose",
                        "test_code": "2345-7",
                        "value": 110.0,
                        "unit": "mg/dL",
                        "reference_range_low": 70.0,
                        "reference_range_high": 99.0,
                        "flags": ["H"],
                    }
                ],
                "confidence": {
                    "extraction_confidence": 0.95,
                    "model_confidence": 0.92,
                    "data_quality": 0.97,
                },
            }
        }
| # ================================ | |
| # CLINICAL NOTES SCHEMA | |
| # ================================ | |
class ClinicalSection(BaseModel):
    """One structured section of a clinical note.

    All three fields are required; ``section_type`` is limited to the listed
    section names and ``confidence`` to [0.0, 1.0].
    """

    section_type: Literal[
        "chief_complaint",
        "history_present_illness",
        "past_medical_history",
        "medications",
        "allergies",
        "review_of_systems",
        "physical_exam",
        "assessment",
        "plan",
        "discharge_summary",
    ] = Field(description="Type of clinical section")
    content: str = Field(description="Section content text")
    confidence: confloat(ge=0.0, le=1.0) = Field(
        description="Confidence in section extraction"
    )
class ClinicalEntity(BaseModel):
    """A medical entity extracted from clinical notes.

    ``entity_type``, ``text`` and ``confidence`` are required; value, unit and
    context are optional.
    """

    entity_type: Literal[
        "diagnosis",
        "medication",
        "procedure",
        "symptom",
        "anatomy",
        "date",
        "lab_value",
    ] = Field(description="Type of medical entity")
    text: str = Field(description="Entity text")
    value: Optional[Union[str, float]] = Field(
        default=None, description="Entity value if applicable"
    )
    unit: Optional[str] = Field(default=None, description="Unit if applicable")
    confidence: confloat(ge=0.0, le=1.0) = Field(
        description="Confidence in entity extraction"
    )
    context: Optional[str] = Field(
        default=None, description="Surrounding context for entity"
    )
class ClinicalNotesAnalysis(BaseModel):
    """Complete clinical notes analysis.

    Combines the extracted sections and entities with summary lists of
    diagnoses, medications and procedures, plus the confidence for the note.
    """

    # NOTE(review): source_type is handed to Field as an extra keyword; it
    # presumably only annotates the field and does not set or enforce
    # metadata.source_type - confirm against the pydantic version in use.
    metadata: MedicalDocumentMetadata = Field(source_type="clinical_notes")
    sections: List[ClinicalSection] = Field(description="Extracted clinical sections")
    entities: List[ClinicalEntity] = Field(
        default_factory=list, description="Extracted medical entities"
    )
    diagnoses: List[str] = Field(
        default_factory=list, description="Primary diagnoses"
    )
    medications: List[str] = Field(
        default_factory=list, description="Current medications"
    )
    procedures: List[str] = Field(
        default_factory=list, description="Recent procedures"
    )
    confidence: ConfidenceScore
    note_type: Optional[
        Literal["progress_note", "consultation", "discharge_summary", "history_physical"]
    ] = None

    class Config:
        # Illustrative payload attached to the generated schema.
        schema_extra = {
            "example": {
                "metadata": {
                    "document_id": "note-22222",
                    "source_type": "clinical_notes",
                    "document_date": "2025-10-29T10:38:55Z",
                },
                "sections": [
                    {
                        "section_type": "chief_complaint",
                        "content": "Patient presents with chest pain",
                        "confidence": 0.98,
                    }
                ],
                "entities": [
                    {
                        "entity_type": "symptom",
                        "text": "chest pain",
                        "confidence": 0.95,
                    }
                ],
                "confidence": {
                    "extraction_confidence": 0.90,
                    "model_confidence": 0.87,
                    "data_quality": 0.93,
                },
            }
        }
| # ================================ | |
| # PIPELINE VALIDATION AND ROUTING | |
| # ================================ | |
class DocumentClassification(BaseModel):
    """Document type classification with confidence.

    ``predicted_type`` accepts the same literals as
    ``MedicalDocumentMetadata.source_type``; ``requires_human_review`` must be
    supplied by the caller and is not derived from ``confidence`` here.
    """

    predicted_type: Literal["ECG", "radiology", "laboratory", "clinical_notes", "unknown"]
    confidence: confloat(ge=0.0, le=1.0)
    alternative_types: List[Dict[str, float]] = Field(
        default_factory=list, description="Alternative classifications"
    )
    requires_human_review: bool = Field(
        description="Whether human review is recommended"
    )
class ValidationResult(BaseModel):
    """Validation result for schema compliance.

    ``is_valid`` and ``compliance_score`` have no defaults and must always be
    supplied; the error and warning lists default to empty.
    """

    is_valid: bool
    validation_errors: List[str] = Field(default_factory=list)
    warnings: List[str] = Field(default_factory=list)
    compliance_score: confloat(ge=0.0, le=1.0) = Field(
        description="Overall compliance score"
    )
def validate_document_schema(data: Dict[str, Any]) -> ValidationResult:
    """
    Validate document against appropriate schema based on document type

    The schema is chosen from ``data["metadata"]["source_type"]``. Any failure,
    including malformed input, is reported through the returned result rather
    than raised.

    Args:
        data: Document data dictionary

    Returns:
        ValidationResult with validation status and any errors. The
        compliance score is 1.0 on success and 0.0 otherwise.
    """
    try:
        doc_type = data.get("metadata", {}).get("source_type", "unknown")

        schema_by_type = {
            "ECG": ECGAnalysis,
            "radiology": RadiologyAnalysis,
            "laboratory": LaboratoryResults,
            "clinical_notes": ClinicalNotesAnalysis,
        }
        # Only strings can name a schema; anything else is treated as unknown
        # instead of risking an unhashable-key lookup.
        schema = schema_by_type.get(doc_type) if isinstance(doc_type, str) else None

        if schema is None:
            # compliance_score is a required field of ValidationResult; leaving
            # it out would raise here and the generic handler below would
            # replace this message with a pydantic error.
            return ValidationResult(
                is_valid=False,
                validation_errors=[f"Unknown document type: {doc_type}"],
                warnings=["Document type not recognized"],
                compliance_score=0.0,
            )

        # Constructing the model performs the validation; the instance itself
        # is not needed.
        schema(**data)
        return ValidationResult(is_valid=True, compliance_score=1.0)
    except Exception as e:
        # Deliberately broad: this is the boundary that turns every failure
        # into a result object for the caller.
        return ValidationResult(
            is_valid=False,
            validation_errors=[str(e)],
            compliance_score=0.0,
        )
def _resolve_overall_confidence(confidence: Dict[str, Any]) -> float:
    """Return the overall confidence from a confidence dict, deriving it if absent."""
    overall = confidence.get("overall_confidence")
    if overall is not None:
        return overall

    # "overall_confidence" is not a declared field of ConfidenceScore, so a
    # payload may carry only the three component scores. Combine them with the
    # weights documented on ConfidenceScore.overall_confidence.
    extraction = confidence.get("extraction_confidence")
    model = confidence.get("model_confidence")
    quality = confidence.get("data_quality")
    if extraction is None or model is None or quality is None:
        return 0
    return 0.5 * extraction + 0.3 * model + 0.2 * quality


def route_to_specialized_model(document_data: Dict[str, Any]) -> str:
    """
    Route document to appropriate specialized model based on validated schema

    Args:
        document_data: Validated document data. The document type is read
            from ``metadata.source_type``; for ECG documents the ``confidence``
            entry is also consulted.

    Returns:
        Model name for specialized processing
    """
    # "or {}" tolerates a key that is present but set to None.
    metadata = document_data.get("metadata") or {}
    doc_type = metadata.get("source_type", "unknown")

    # Route based on document type and confidence
    if doc_type == "ECG":
        confidence = document_data.get("confidence") or {}
        if _resolve_overall_confidence(confidence) >= 0.85:
            return "hubert-ecg"  # HuBERT-ECG for high-confidence ECG
        return "bio-clinicalbert"  # Fallback for lower confidence
    if doc_type == "radiology":
        return "monai-unetr"  # MONAI UNETR for radiology segmentation
    if doc_type == "laboratory":
        return "biomedical-ner"  # Biomedical NER for lab value extraction
    if doc_type == "clinical_notes":
        return "medgemma"  # MedGemma for clinical text generation
    return "scibert"  # Default fallback model
| # ================================ | |
| # EXPORT SCHEMAS FOR PIPELINE | |
| # ================================ | |
# Names exported by a star-import of this module.
__all__ = [
    # Shared building blocks
    "ConfidenceScore",
    "MedicalDocumentMetadata",
    # Per-modality top-level schemas
    "ECGAnalysis",
    "RadiologyAnalysis",
    "LaboratoryResults",
    "ClinicalNotesAnalysis",
    # Pipeline classification / validation models
    "DocumentClassification",
    "ValidationResult",
    # Pipeline entry points
    "validate_document_schema",
    "route_to_specialized_model",
]