Spaces:
Sleeping
Sleeping
| """ | |
| PDF Medical Extractor - Phase 2 | |
| Structured PDF extraction using Donut/LayoutLMv3 for medical documents. | |
| This module provides specialized extraction for medical PDFs including | |
| radiology reports, laboratory results, clinical notes, and ECG reports. | |
| Author: MiniMax Agent | |
| Date: 2025-10-29 | |
| Version: 1.0.0 | |
| """ | |
| import os | |
| import json | |
| import io | |
| import logging | |
| from typing import Dict, List, Optional, Any, Tuple | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| import numpy as np | |
| from PIL import Image | |
| import fitz # PyMuPDF | |
| import pytesseract | |
| from transformers import DonutProcessor, VisionEncoderDecoderModel | |
| import torch | |
| from tqdm import tqdm | |
| from medical_schemas import ( | |
| MedicalDocumentMetadata, ConfidenceScore, RadiologyAnalysis, | |
| LaboratoryResults, ClinicalNotesAnalysis, ValidationResult, | |
| validate_document_schema | |
| ) | |
# Module-scoped logger; handlers and levels are left to the application.
logger: logging.Logger = logging.getLogger(name=__name__)
@dataclass
class ExtractionResult:
    """Result of PDF extraction with confidence scoring.

    Declared as a dataclass so that the keyword constructor
    ``ExtractionResult(raw_text=..., structured_data=..., ...)`` exists;
    plain class annotations alone generate no ``__init__`` parameters.
    """

    # Text extracted from the PDF ("" when processing failed).
    raw_text: str
    # Document-type-specific fields, or {"error": <message>} on failure.
    structured_data: Dict[str, Any]
    # Scores keyed by aspect; always contains an "overall" entry.
    confidence_scores: Dict[str, float]
    # Method that produced structured_data, e.g. "donut", "fallback" or "error".
    extraction_method: str
    # Elapsed wall-clock time of the extraction, in seconds.
    processing_time: float
    # One dict per detected table (row/column counts and cell data).
    tables_extracted: List[Dict[str, Any]]
    # Paths of image files written during extraction.
    images_extracted: List[str]
    # PDF-level information (page count, PDF metadata, file size) or an error.
    metadata: Dict[str, Any]
class DonutMedicalExtractor:
    """Medical PDF extraction using a Donut vision encoder-decoder model.

    The processor and model are loaded eagerly in ``__init__`` and placed on
    CUDA when ``torch.cuda.is_available()``, otherwise on the CPU.
    """

    def __init__(self, model_name: str = "naver-clova-ix/donut-base-finetuned-rvlcdip"):
        """Create the extractor and load ``model_name`` immediately.

        Args:
            model_name: Hugging Face model id (or local path) of a Donut
                checkpoint.

        Raises:
            Exception: whatever ``from_pretrained`` raises when the model or
                processor cannot be loaded (logged, then re-raised).
        """
        self.model_name = model_name
        self.processor = None
        self.model = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._load_model()

    def _load_model(self):
        """Load the Donut processor and model, move the model to the device and
        switch it to inference (eval) mode."""
        try:
            logger.info(f"Loading Donut model: {self.model_name}")
            self.processor = DonutProcessor.from_pretrained(self.model_name)
            self.model = VisionEncoderDecoderModel.from_pretrained(self.model_name)
            self.model.to(self.device)
            self.model.eval()
            logger.info("Donut model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load Donut model: {str(e)}")
            raise

    @staticmethod
    def _parse_output(decoded_output: str) -> Dict[str, Any]:
        """Return the JSON object embedded in ``decoded_output``.

        The span from the first '{' to the last '}' is parsed. When there is no
        such span, or it is not valid JSON, ``{"raw_text": decoded_output}`` is
        returned instead of treating the generation as a failure.
        """
        json_start = decoded_output.find("{")
        json_end = decoded_output.rfind("}")
        if json_start != -1 and json_end > json_start:
            try:
                return json.loads(decoded_output[json_start:json_end + 1])
            except json.JSONDecodeError:
                pass
        return {"raw_text": decoded_output}

    def extract_from_image(self, image: Image.Image, task_prompt: Optional[str] = None,
                           max_length: int = 512) -> Dict[str, Any]:
        """Extract structured data from a page image using Donut.

        Args:
            image: page image; converted to RGB if it is in another mode.
            task_prompt: decoder start prompt; defaults to "<s_rvlcdip>".
                NOTE(review): prompts not present in the checkpoint's
                vocabulary are tokenized into sub-pieces and are unlikely to
                steer the model meaningfully.
            max_length: maximum length of the generated token sequence.

        Returns:
            The parsed JSON dict, ``{"raw_text": <decoded text>}`` when the
            output holds no parseable JSON object, or
            ``{"raw_text": "", "error": <message>}`` if any step raises.
        """
        if task_prompt is None:
            task_prompt = "<s_rvlcdip>"

        try:
            if image.mode != "RGB":
                image = image.convert("RGB")

            pixel_values = self.processor(images=image, return_tensors="pt").pixel_values
            pixel_values = pixel_values.to(self.device)

            task_prompt_ids = self.processor.tokenizer(
                task_prompt, add_special_tokens=False, return_tensors="pt"
            ).input_ids
            task_prompt_ids = task_prompt_ids.to(self.device)

            with torch.no_grad():
                # The image is the encoder input; the task prompt seeds the
                # decoder. Passing both positionally would bind the pixel
                # tensor to generate()'s second parameter (generation_config).
                outputs = self.model.generate(
                    pixel_values,
                    decoder_input_ids=task_prompt_ids,
                    max_length=max_length,
                    early_stopping=False,
                    pad_token_id=self.processor.tokenizer.pad_token_id,
                    eos_token_id=self.processor.tokenizer.eos_token_id,
                    use_cache=True,
                )

            output_sequence = outputs.cpu().numpy()[0]
            decoded_output = self.processor.tokenizer.decode(output_sequence, skip_special_tokens=True)
            return self._parse_output(decoded_output)

        except Exception as e:
            logger.error(f"Donut extraction error: {str(e)}")
            return {"raw_text": "", "error": str(e)}
class MedicalPDFProcessor:
    """Medical PDF processing with Donut-based and text-based extraction.

    ``process_pdf`` reads each page's embedded text, tables and images with
    PyMuPDF, chooses an extraction method and returns an ``ExtractionResult``.
    Donut is used only when its model loaded and the content looks structured;
    otherwise line-oriented pattern matching over the text layer is used.
    No OCR is performed by this class.
    """

    # Section headers (lower-case, whitespace-normalised, without the trailing
    # colon) mapped to the section they open. A line is a header only when the
    # text before its first ':' (or the whole line if it has no ':') equals a
    # key, so sentences such as "No acute findings." remain content.
    _RADIOLOGY_HEADERS: Dict[str, str] = {
        "findings": "findings",
        "impression": "impression",
        "conclusion": "impression",
        "technique": "technique",
        "protocol": "technique",
    }
    _CLINICAL_HEADERS: Dict[str, str] = {
        "chief complaint": "chief_complaint",
        "cc": "chief_complaint",
        "history of present illness": "history_present_illness",
        "hpi": "history_present_illness",
        "history": "history_present_illness",
        "assessment": "assessment",
        "diagnosis": "assessment",
        "impression": "assessment",
        "assessment and plan": "plan",
        "plan": "plan",
        "treatment": "plan",
        "recommendations": "plan",
    }
    # Donut task prompt per document type; unlisted types use the generic one.
    _DONUT_TASK_PROMPTS: Dict[str, str] = {
        "radiology": "<s_radiology_report>",
        "laboratory": "<s_laboratory_report>",
        "clinical_notes": "<s_clinical_note>",
        "ecg_report": "<s_ecg_report>",
        "unknown": "<s_medical_document>",
    }
    _DEFAULT_TASK_PROMPT = "<s_medical_document>"
    # Phrases whose presence suggests a sectioned, form-like document.
    _STRUCTURED_INDICATORS = (
        "findings:", "impression:", "technique:", "results:",
        "normal ranges:", "reference values:", "patient information:",
    )

    def __init__(self):
        """Create the processor, loading the Donut extractor when possible.

        A Donut loading failure is logged as a warning and leaves
        ``donut_extractor`` as None, so text-based extraction is used instead.
        """
        self.donut_extractor = None
        # NOTE(review): not consulted anywhere in this class; kept for callers.
        self.ocr_enabled = True

        try:
            self.donut_extractor = DonutMedicalExtractor()
        except Exception as e:
            logger.warning(f"Donut extractor not available: {str(e)}")
            self.donut_extractor = None

    # ------------------------------------------------------------------ #
    # Line helpers shared by the text-based extractors
    # ------------------------------------------------------------------ #

    @staticmethod
    def _is_page_marker(line: str) -> bool:
        """Return True for the "--- Page N ---" separators added by process_pdf."""
        return line.startswith("--- Page ") and line.endswith(" ---")

    @classmethod
    def _content_lines(cls, text: str) -> List[str]:
        """Return the stripped, non-empty lines of ``text`` without page separators."""
        stripped = (line.strip() for line in text.split("\n"))
        return [line for line in stripped if line and not cls._is_page_marker(line)]

    @staticmethod
    def _split_header(line: str, headers: Dict[str, str]) -> Tuple[Optional[str], str]:
        """Classify ``line`` as a section header or as content.

        Returns ``(section, inline_text)`` when the part of the line before the
        first ':' (the whole line if there is none) matches a key of
        ``headers`` case-insensitively; ``inline_text`` is whatever followed the
        colon (possibly ""). Returns ``(None, line)`` for content lines.
        """
        head, _, rest = line.partition(":")
        section = headers.get(" ".join(head.lower().split()))
        if section is None:
            return None, line
        return section, rest.strip()

    @staticmethod
    def _is_numeric_token(token: str) -> bool:
        """Return True if ``token`` looks like a measured value ("5.7", "<0.5", "1,200")."""
        candidate = token.lstrip("<>=").replace(",", "")
        if not any(ch.isdigit() for ch in candidate):
            return False
        try:
            float(candidate)
        except ValueError:
            return False
        return True

    # ------------------------------------------------------------------ #
    # Main entry point
    # ------------------------------------------------------------------ #

    def process_pdf(self, pdf_path: str, document_type: str = "unknown") -> ExtractionResult:
        """
        Process medical PDF with multiple extraction methods

        Args:
            pdf_path: Path to PDF file
            document_type: Type of medical document ("radiology",
                "laboratory", "clinical_notes", "ecg_report" or other)

        Returns:
            ExtractionResult with structured data. Processing errors are not
            raised; they yield a result whose extraction_method is "error".
        """
        import time

        start_time = time.perf_counter()

        try:
            # The context manager closes the document even if a page fails.
            with fitz.open(pdf_path) as doc:
                page_count = len(doc)
                metadata = {
                    "page_count": page_count,
                    "pdf_metadata": doc.metadata,
                    "file_size": os.path.getsize(pdf_path),
                }

                page_texts = []
                tables = []
                images = []
                for page_num in range(page_count):
                    page = doc.load_page(page_num)
                    page_texts.append(f"\n--- Page {page_num + 1} ---\n{page.get_text()}")
                    tables.extend(self._extract_tables(page))
                    images.extend(self._extract_images(page, pdf_path, page_num))

            raw_text = "".join(page_texts)

            extraction_method = self._determine_extraction_method(raw_text, document_type)
            if extraction_method == "donut" and self.donut_extractor is None:
                # Report the method that actually runs.
                extraction_method = "fallback"

            if extraction_method == "donut":
                structured_data = self._extract_with_donut(pdf_path, document_type, raw_text)
            else:
                structured_data = self._extract_with_fallback(raw_text, document_type)

            confidence_scores = self._calculate_extraction_confidence(
                raw_text, structured_data, tables, images
            )

            return ExtractionResult(
                raw_text=raw_text,
                structured_data=structured_data,
                confidence_scores=confidence_scores,
                extraction_method=extraction_method,
                processing_time=time.perf_counter() - start_time,
                tables_extracted=tables,
                images_extracted=images,
                metadata=metadata,
            )

        except Exception as e:
            logger.error(f"PDF processing error: {str(e)}")
            return ExtractionResult(
                raw_text="",
                structured_data={"error": str(e)},
                confidence_scores={"overall": 0.0},
                extraction_method="error",
                processing_time=time.perf_counter() - start_time,
                tables_extracted=[],
                images_extracted=[],
                metadata={"error": str(e)},
            )

    def _determine_extraction_method(self, text: str, document_type: str) -> str:
        """Return "donut" for long radiology/ECG text or long text with at least
        three structured-section indicators, otherwise "fallback"."""
        if document_type in ("radiology", "ecg_report") and len(text) > 500:
            return "donut"

        text_lower = text.lower()
        indicator_count = sum(1 for indicator in self._STRUCTURED_INDICATORS if indicator in text_lower)
        if indicator_count >= 3 and len(text) > 1000:
            return "donut"

        return "fallback"

    def _extract_with_donut(self, pdf_path: str, document_type: str, text: str = "") -> Dict[str, Any]:
        """Extract structured data from the first page image using Donut.

        Args:
            pdf_path: PDF to render.
            document_type: selects the Donut task prompt and post-processing.
            text: already-extracted text layer, used by the text-based
                fallback when Donut is unavailable, renders nothing, or fails.
        """
        if self.donut_extractor is None:
            return self._extract_with_fallback(text, document_type)

        try:
            # Only the first page is sent to Donut, so render only that page.
            images = self._pdf_to_images(pdf_path, max_pages=1)
            if not images:
                return self._extract_with_fallback(text, document_type)

            task_prompt = self._DONUT_TASK_PROMPTS.get(document_type, self._DEFAULT_TASK_PROMPT)
            structured_data = self.donut_extractor.extract_from_image(images[0], task_prompt)

            if "error" in structured_data:
                # extract_from_image reports failures in-band instead of raising.
                logger.warning(f"Donut extraction error: {structured_data['error']}")
                return self._extract_with_fallback(text, document_type)

            postprocessors = {
                "radiology": self._postprocess_radiology,
                "laboratory": self._postprocess_laboratory,
                "clinical_notes": self._postprocess_clinical_notes,
                "ecg_report": self._postprocess_ecg,
            }
            postprocess = postprocessors.get(document_type)
            if postprocess is not None:
                structured_data = postprocess(structured_data)

            return structured_data

        except Exception as e:
            logger.error(f"Donut extraction error: {str(e)}")
            return self._extract_with_fallback(text, document_type)

    def _extract_with_fallback(self, text: str, document_type: str) -> Dict[str, Any]:
        """Text-pattern extraction over ``text`` (no OCR), dispatched on
        ``document_type``; unknown types return the cleaned text unchanged."""
        try:
            cleaned_text = text.strip()

            if document_type == "radiology":
                return self._extract_radiology_from_text(cleaned_text)
            elif document_type == "laboratory":
                return self._extract_laboratory_from_text(cleaned_text)
            elif document_type == "clinical_notes":
                return self._extract_clinical_notes_from_text(cleaned_text)
            elif document_type == "ecg_report":
                return self._extract_ecg_from_text(cleaned_text)
            else:
                return {
                    "raw_text": cleaned_text,
                    "document_type": document_type,
                    "extraction_method": "fallback_text"
                }

        except Exception as e:
            logger.error(f"Fallback extraction error: {str(e)}")
            return {"raw_text": text, "error": str(e), "extraction_method": "fallback"}

    def _extract_radiology_from_text(self, text: str) -> Dict[str, Any]:
        """Split radiology text into findings / impression / technique sections.

        Lines before the first recognised header are ignored. Text written on
        the header line itself ("FINDINGS: lungs clear") belongs to that section.
        """
        sections: Dict[str, List[str]] = {"findings": [], "impression": [], "technique": []}
        current_section: Optional[str] = None

        for line in self._content_lines(text):
            header, content = self._split_header(line, self._RADIOLOGY_HEADERS)
            if header is not None:
                current_section = header
                if not content:
                    continue
            if current_section is not None:
                sections[current_section].append(content)

        return {
            "findings": " ".join(sections["findings"]),
            "impression": " ".join(sections["impression"]),
            "technique": " ".join(sections["technique"]),
            "document_type": "radiology",
            "extraction_method": "text_pattern_matching"
        }

    def _extract_laboratory_from_text(self, text: str) -> Dict[str, Any]:
        """Extract candidate lab-test rows from lines with at least three tokens.

        The first numeric-looking token after the first token is taken as the
        value, the tokens before it as the (possibly multi-word) test name and
        the token after it as the unit. Without a numeric token, the second
        token is used as the value.
        """
        tests = []

        for line in self._content_lines(text):
            parts = line.split()
            if len(parts) < 3:
                continue

            value_index = next(
                (i for i in range(1, len(parts)) if self._is_numeric_token(parts[i])),
                1,
            )
            tests.append({
                "raw_line": line,
                "potential_test": " ".join(parts[:value_index]),
                "potential_value": parts[value_index],
                "potential_unit": parts[value_index + 1] if value_index + 1 < len(parts) else "",
            })

        return {
            "tests": tests,
            "document_type": "laboratory",
            "extraction_method": "text_pattern_matching"
        }

    def _extract_clinical_notes_from_text(self, text: str) -> Dict[str, Any]:
        """Group clinical-note lines into named sections.

        Lines before any recognised header go to "general"; only sections that
        received at least one line appear in the result.
        """
        sections: Dict[str, List[str]] = {}
        current_section = "general"

        for line in self._content_lines(text):
            header, content = self._split_header(line, self._CLINICAL_HEADERS)
            if header is not None:
                current_section = header
                if not content:
                    continue
            sections.setdefault(current_section, []).append(content)

        return {
            "sections": {name: " ".join(lines) for name, lines in sections.items()},
            "document_type": "clinical_notes",
            "extraction_method": "text_pattern_matching"
        }

    def _extract_ecg_from_text(self, text: str) -> Dict[str, Any]:
        """Extract heart rate, rhythm and interval lines from ECG text (lower-cased)."""
        import re

        # The rate is the first integer after a whole-word "heart rate"/"hr", so
        # words like "rhythm" or "three" no longer trigger a match.
        hr_pattern = re.compile(r"\b(?:heart\s+rate|hr)\b\D*?(\d+)")
        ecg_data: Dict[str, Any] = {}

        for line in self._content_lines(text):
            line = line.lower()

            hr_match = hr_pattern.search(line)
            if hr_match:
                ecg_data["heart_rate"] = int(hr_match.group(1))

            if "rhythm" in line:
                ecg_data["rhythm"] = line

            if any(interval in line for interval in ("pr interval", "qrs", "qt")):
                ecg_data[line.split(":")[0].strip()] = line

        return {
            "ecg_data": ecg_data,
            "document_type": "ecg_report",
            "extraction_method": "text_pattern_matching"
        }

    def _postprocess_radiology(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Ensure "findings"/"impression" keys exist and tag the document type (mutates ``data``)."""
        data.setdefault("findings", "")
        data.setdefault("impression", "")
        data["document_type"] = "radiology"
        return data

    def _postprocess_laboratory(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Ensure a "tests" list exists and tag the document type (mutates ``data``)."""
        data.setdefault("tests", [])
        data["document_type"] = "laboratory"
        return data

    def _postprocess_clinical_notes(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Ensure a "sections" dict exists and tag the document type (mutates ``data``)."""
        data.setdefault("sections", {})
        data["document_type"] = "clinical_notes"
        return data

    def _postprocess_ecg(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Ensure an "ecg_data" dict exists and tag the document type (mutates ``data``)."""
        data.setdefault("ecg_data", {})
        data["document_type"] = "ecg_report"
        return data

    def _pdf_to_images(self, pdf_path: str, max_pages: int = 3) -> List[Image.Image]:
        """Render up to ``max_pages`` leading pages at 2x zoom as PIL images.

        Errors are logged and the pages rendered so far are returned.
        """
        images = []
        try:
            with fitz.open(pdf_path) as doc:
                zoom = fitz.Matrix(2.0, 2.0)  # 2x resolution for the vision model
                for page_num in range(min(max_pages, len(doc))):
                    pix = doc.load_page(page_num).get_pixmap(matrix=zoom)
                    images.append(Image.open(io.BytesIO(pix.tobytes("png"))))
        except Exception as e:
            logger.error(f"PDF to image conversion error: {str(e)}")
        return images

    def _extract_tables(self, page) -> List[Dict[str, Any]]:
        """Return row/column counts and cell data for each table PyMuPDF finds
        on ``page``; failures are logged at debug level and yield []."""
        tables = []
        try:
            for table in page.find_tables():
                table_dict = table.extract()
                tables.append({
                    "rows": len(table_dict),
                    "columns": len(table_dict[0]) if table_dict else 0,
                    "data": table_dict
                })
        except Exception as e:
            logger.debug(f"Table extraction failed: {str(e)}")
        return tables

    def _extract_images(self, page, pdf_path: str, page_num: int) -> List[str]:
        """Save each image on ``page`` as a PNG and return the written paths.

        Files are named "<pdf stem>_page<N>_img<M>.png" and written relative to
        the current working directory. Images with more than three colour
        channels (e.g. CMYK) are converted to RGB first. A failing image is
        logged at debug level and skipped without affecting the others.
        """
        images = []
        try:
            image_list = page.get_images()
        except Exception as e:
            logger.debug(f"Image extraction failed: {str(e)}")
            return images

        stem = Path(pdf_path).stem
        for img_index, img in enumerate(image_list):
            try:
                pix = fitz.Pixmap(page.parent, img[0])  # img[0] is the image xref
                if pix.n - pix.alpha > 3:
                    pix = fitz.Pixmap(fitz.csRGB, pix)
                img_path = f"{stem}_page{page_num+1}_img{img_index+1}.png"
                pix.save(img_path)
                images.append(img_path)
            except Exception as e:
                logger.debug(f"Image extraction failed: {str(e)}")
            finally:
                pix = None  # release the pixmap before the next image
        return images

    def _calculate_extraction_confidence(self, raw_text: str, structured_data: Dict[str, Any],
                                         tables: List[Dict], images: List[str]) -> Dict[str, float]:
        """Score extraction quality; every component and "overall" lie in [0, 1].

        - text_extraction: text length / 1000, capped at 1; page separators
          are excluded so image-only PDFs score 0.
        - structural_completeness: share of the present field groups
          (findings/impression, tests, sections, ecg_data) that are non-empty.
        - table_extraction / image_extraction: 0.3 / 0.2 per item, capped at 1.
        - overall: 0.4/0.4/0.1/0.1 weighted average of the above.
        """
        confidence_scores = {}

        text_only = "\n".join(
            line for line in raw_text.split("\n") if not self._is_page_marker(line.strip())
        )
        text_length = len(text_only.strip())
        confidence_scores["text_extraction"] = min(1.0, text_length / 1000) if text_length > 0 else 0.0

        field_groups = (("findings", "impression"), ("tests",), ("sections",), ("ecg_data",))
        required = [group for group in field_groups if any(key in structured_data for key in group)]
        present = [group for group in required if any(structured_data.get(key) for key in group)]
        confidence_scores["structural_completeness"] = len(present) / max(len(required), 1)

        confidence_scores["table_extraction"] = min(1.0, len(tables) * 0.3)
        confidence_scores["image_extraction"] = min(1.0, len(images) * 0.2)

        confidence_scores["overall"] = (
            0.4 * confidence_scores["text_extraction"] +
            0.4 * confidence_scores["structural_completeness"] +
            0.1 * confidence_scores["table_extraction"] +
            0.1 * confidence_scores["image_extraction"]
        )
        return confidence_scores

    def convert_to_schema_format(self, extraction_result: ExtractionResult,
                                 document_type: str) -> Optional[Dict[str, Any]]:
        """Convert an extraction result to the canonical schema dict.

        Returns None for document types without a converter (anything other
        than "radiology", "laboratory" or "clinical_notes") or when conversion
        raises (the error is logged).
        """
        try:
            metadata = MedicalDocumentMetadata(
                source_type=document_type,
                data_completeness=extraction_result.confidence_scores.get("overall", 0.0)
            )

            confidence = ConfidenceScore(
                extraction_confidence=extraction_result.confidence_scores.get("overall", 0.0),
                model_confidence=0.8,  # fixed placeholder, not derived from the model
                data_quality=extraction_result.confidence_scores.get("text_extraction", 0.0)
            )

            if document_type == "radiology":
                return self._convert_to_radiology_schema(extraction_result, metadata, confidence)
            elif document_type == "laboratory":
                return self._convert_to_laboratory_schema(extraction_result, metadata, confidence)
            elif document_type == "clinical_notes":
                return self._convert_to_clinical_notes_schema(extraction_result, metadata, confidence)
            else:
                return None

        except Exception as e:
            logger.error(f"Schema conversion error: {str(e)}")
            return None

    def _convert_to_radiology_schema(self, result: ExtractionResult, metadata: MedicalDocumentMetadata,
                                     confidence: ConfidenceScore) -> Dict[str, Any]:
        """Build the radiology schema dict; fields not extracted get empty defaults."""
        data = result.structured_data

        return {
            "metadata": metadata.dict(),
            "image_references": [],
            "findings": {
                "findings_text": data.get("findings", ""),
                "impression_text": data.get("impression", ""),
                "technique_description": data.get("technique", "")
            },
            "segmentations": [],
            "metrics": {},
            "confidence": confidence.dict(),
            "criticality_level": "routine",
            "follow_up_recommendations": []
        }

    def _convert_to_laboratory_schema(self, result: ExtractionResult, metadata: MedicalDocumentMetadata,
                                      confidence: ConfidenceScore) -> Dict[str, Any]:
        """Build the laboratory schema dict; critical/abnormal counts are not computed (0)."""
        data = result.structured_data

        return {
            "metadata": metadata.dict(),
            "tests": data.get("tests", []),
            "confidence": confidence.dict(),
            "critical_values": [],
            "abnormal_count": 0,
            "critical_count": 0
        }

    def _convert_to_clinical_notes_schema(self, result: ExtractionResult, metadata: MedicalDocumentMetadata,
                                          confidence: ConfidenceScore) -> Dict[str, Any]:
        """Build the clinical-notes schema dict, one entry per extracted section."""
        data = result.structured_data
        sections = data.get("sections", {})

        return {
            "metadata": metadata.dict(),
            "sections": [{"section_type": k, "content": v, "confidence": 0.8} for k, v in sections.items()],
            "entities": [],
            "confidence": confidence.dict()
        }
# Public API of this module: the names exported by `from <module> import *`.
__all__ = ["MedicalPDFProcessor", "DonutMedicalExtractor", "ExtractionResult"]