| | from dotenv import load_dotenv |
| | load_dotenv() |
| |
|
| | from fastapi import FastAPI, UploadFile, File, Form |
| | from fastapi.responses import JSONResponse |
| | from fastapi.middleware.cors import CORSMiddleware |
| | from transformers import pipeline as hf_pipeline, AutoTokenizer, AutoModelForTokenClassification |
| | from doctr.io import DocumentFile |
| | from doctr.models import ocr_predictor |
| | from img2table.document import Image as Img2TableImage |
| | from img2table.ocr import DocTR |
| | import cv2 |
| | import numpy as np |
| | from PIL import Image |
| | import io |
| | import json |
| | import os |
| | import tempfile |
| | import base64 |
| | from typing import Dict, Any, Optional, List |
| | import difflib |
| | import re |
| | import httpx |
| | from bs4 import BeautifulSoup |
| |
|
| | |
| | from docling.document_converter import DocumentConverter, InputFormat, ImageFormatOption |
| | from docling.datamodel.pipeline_options import PdfPipelineOptions |
| | from docling_ocr_onnxtr import OnnxtrOcrOptions |
| |
|
| | |
| | from router_chat import router as chat_router |
| | from faq_store import initialize_faq_store |
| |
|
# FastAPI application instance; chat endpoints are mounted below from router_chat.
app = FastAPI(title="ScanAssured OCR & NER API")


# NOTE(review): @app.on_event is deprecated in recent FastAPI/Starlette in
# favor of lifespan handlers — consider migrating on the next FastAPI upgrade.
@app.on_event("startup")
async def startup_event():
    # Build the FAQ store once at process start so chat requests served by
    # chat_router do not pay the initialization cost per request.
    initialize_faq_store()


# Mount the chat endpoints defined in router_chat.
app.include_router(chat_router)
| |
|
| | |
| | DRUG_INTERACTIONS = {} |
| | interactions_path = os.path.join(os.path.dirname(__file__), 'interactions_data.json') |
| | if os.path.exists(interactions_path): |
| | with open(interactions_path, 'r') as f: |
| | DRUG_INTERACTIONS = json.load(f) |
| | print(f"Loaded {len(DRUG_INTERACTIONS)} drug interaction entries") |
| |
|
| | |
| | MEDLINEPLUS_MAP = {} |
| | medlineplus_map_path = os.path.join(os.path.dirname(__file__), 'medlineplus_map.json') |
| | if os.path.exists(medlineplus_map_path): |
| | with open(medlineplus_map_path, 'r') as f: |
| | MEDLINEPLUS_MAP = json.load(f) |
| | print(f"Loaded {len(MEDLINEPLUS_MAP)} MedlinePlus test mappings") |
| |
|
| | MEDLINEPLUS_CACHE = {} |
| | medlineplus_cache_path = os.path.join(os.path.dirname(__file__), 'medlineplus_cache.json') |
| | if os.path.exists(medlineplus_cache_path): |
| | with open(medlineplus_cache_path, 'r') as f: |
| | MEDLINEPLUS_CACHE = json.load(f) |
| | print(f"Loaded {len(MEDLINEPLUS_CACHE)} MedlinePlus cached entries") |
| |
|
| | |
# Allow cross-origin requests from any origin (development-friendly setting).
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# contradictory under the CORS spec — Starlette will not send credentials with
# a wildcard origin. Lock origins down for production; TODO confirm deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
| |
|
| | |
# docTR model presets exposed to clients: each pairs a text-detection
# architecture ("det") with a text-recognition architecture ("reco").
# "name"/"description" are display strings for the frontend.
OCR_PRESETS = {
    "high_accuracy": {
        "det": "db_resnet50",
        "reco": "crnn_vgg16_bn",
        "name": "High Accuracy",
        "description": "Best quality, slower processing"
    },
    "balanced": {
        "det": "db_resnet50",
        "reco": "crnn_mobilenet_v3_small",
        "name": "Balanced (Recommended)",
        "description": "Good quality and speed"
    },
    "fast": {
        "det": "db_mobilenet_v3_large",
        "reco": "crnn_mobilenet_v3_small",
        "name": "Fast",
        "description": "Fastest processing, slightly lower quality"
    },
}

# Individual architectures selectable when not using a preset.
OCR_DETECTION_MODELS = ["db_resnet50", "db_mobilenet_v3_large", "linknet_resnet18"]
OCR_RECOGNITION_MODELS = ["crnn_vgg16_bn", "crnn_mobilenet_v3_small", "parseq"]
| |
|
| | |
# Hugging Face NER models offered by the API, keyed by model id.
# "entities" lists the entity groups each model emits — used for display and
# (via correct_with_ner_entities) to pick a spell-check dictionary per group.
NER_MODELS = {
    "Clinical-AI-Apollo/Medical-NER": {
        "name": "Medical NER (Recommended)",
        "description": "Medications, diseases, lab values, procedures, dosages",
        "entities": ["MEDICATION", "DOSAGE", "FREQUENCY", "DURATION",
                     "DISEASE_DISORDER", "SIGN_SYMPTOM", "DIAGNOSTIC_PROCEDURE",
                     "THERAPEUTIC_PROCEDURE", "LAB_VALUE", "SEVERITY"]
    },
    "samrawal/bert-base-uncased_clinical-ner": {
        "name": "Clinical Notes",
        "description": "Optimized for clinical/medical notes",
        "entities": ["PROBLEM", "TREATMENT", "TEST"]
    },
}


# Process-lifetime caches so each model/converter is loaded at most once.
ner_model_cache: Dict[str, Any] = {}
ocr_model_cache: Dict[str, Any] = {}

docling_converter_cache: Dict[str, Any] = {}
| |
|
def get_docling_converter(det_arch: str = "db_mobilenet_v3_large", reco_arch: str = "crnn_vgg16_bn"):
    """Return a Docling DocumentConverter configured with OnnxTR OCR.

    Converters are cached per (det_arch, reco_arch) pair; returns None when
    initialization fails.
    """
    cache_key = f"docling_{det_arch}_{reco_arch}"

    cached = docling_converter_cache.get(cache_key)
    if cached is not None:
        print(f"Using cached Docling converter: {cache_key}")
        return cached

    try:
        print(f"Initializing Docling converter: det={det_arch}, reco={reco_arch}...")

        # PDF pipeline options drive image conversion too; enable OCR and
        # TableFormer table-structure recovery, and allow the OnnxTR plugin.
        pdf_options = PdfPipelineOptions(
            ocr_options=OnnxtrOcrOptions(det_arch=det_arch, reco_arch=reco_arch)
        )
        pdf_options.do_table_structure = True
        pdf_options.do_ocr = True
        pdf_options.allow_external_plugins = True

        new_converter = DocumentConverter(
            format_options={
                InputFormat.IMAGE: ImageFormatOption(pipeline_options=pdf_options)
            }
        )

        docling_converter_cache[cache_key] = new_converter
        print(f"Docling converter {cache_key} initialized successfully!")
        return new_converter
    except Exception as exc:
        print(f"ERROR: Failed to initialize Docling converter: {exc}")
        import traceback
        traceback.print_exc()
        return None
| |
|
| |
|
def run_docling_pipeline(file_content: bytes) -> Dict[str, Any]:
    """
    Run the Docling pipeline on raw image bytes.

    Writes the bytes to a temp PNG (Docling's convert() takes a file path),
    converts, and returns markdown/plain text plus any parsed tables.

    Returns:
        On success: {"success": True, "markdown_text", "plain_text",
        "tables", "primary_table"}; on failure: {"error", "success": False}.
    """
    try:
        converter = get_docling_converter()
        if converter is None:
            return {"error": "Docling converter not available", "success": False}

        # Docling consumes a file path, so persist the bytes to a temp file.
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
            tmp_file.write(file_content)
            tmp_path = tmp_file.name

        try:
            print("Running Docling pipeline...")
            result = converter.convert(source=tmp_path)

            markdown_text = result.document.export_to_markdown()

            # Older docling versions may lack export_to_text; fall back to markdown.
            if hasattr(result.document, 'export_to_text'):
                plain_text = result.document.export_to_text()
            else:
                plain_text = markdown_text

            docling_tables = []
            if hasattr(result.document, 'tables') and result.document.tables:
                for table in result.document.tables:
                    table_data = _parse_docling_table(table)
                    if table_data:
                        docling_tables.append(table_data)

            print(f"Docling: {len(markdown_text)} chars markdown, {len(docling_tables)} tables")

            return {
                "success": True,
                "markdown_text": markdown_text,
                "plain_text": plain_text,
                "tables": docling_tables,
                "primary_table": docling_tables[0] if docling_tables else None,
            }
        finally:
            # Best-effort cleanup. Fix: the previous bare `except:` also
            # swallowed KeyboardInterrupt/SystemExit — narrowed to OSError.
            try:
                os.unlink(tmp_path)
            except OSError:
                pass

    except Exception as e:
        print(f"Docling pipeline error: {e}")
        import traceback
        traceback.print_exc()
        return {"error": str(e), "success": False}
| |
|
| |
|
| | def _parse_docling_table(table) -> Optional[Dict]: |
| | """Parse a Docling table into {cells, num_rows, num_columns} format.""" |
| | try: |
| | if hasattr(table, 'export_to_dataframe'): |
| | df = table.export_to_dataframe() |
| | if df is not None and not df.empty: |
| | cells = [] |
| | header = [str(col) if col is not None else '' for col in df.columns.tolist()] |
| | cells.append(header) |
| | for _, row in df.iterrows(): |
| | row_cells = [str(val).strip() if val is not None else '' for val in row.tolist()] |
| | cells.append(row_cells) |
| |
|
| | return { |
| | "cells": cells, |
| | "num_rows": len(cells), |
| | "num_columns": len(header), |
| | "method": "docling_tableformer" |
| | } |
| |
|
| | if hasattr(table, 'export_to_markdown'): |
| | md = table.export_to_markdown() |
| | if md: |
| | return { |
| | "cells": [], |
| | "num_rows": 0, |
| | "num_columns": 0, |
| | "method": "docling_tableformer", |
| | "markdown": md |
| | } |
| |
|
| | return None |
| | except Exception as e: |
| | print(f"Docling table parse error: {e}") |
| | return None |
| |
|
| |
|
| | |
def get_ocr_predictor(det_arch: str, reco_arch: str):
    """Return a docTR OCR predictor for the given architectures, caching loads.

    Returns None when the model fails to load.
    """
    cache_key = f"{det_arch}_{reco_arch}"

    cached = ocr_model_cache.get(cache_key)
    if cached is not None:
        print(f"Using cached OCR model: {cache_key}")
        return cached

    try:
        print(f"Loading OCR model: det={det_arch}, reco={reco_arch}...")
        loaded = ocr_predictor(
            det_arch=det_arch,
            reco_arch=reco_arch,
            pretrained=True,
            assume_straight_pages=True,
            straighten_pages=False,
            detect_orientation=False,
            preserve_aspect_ratio=True
        )
        ocr_model_cache[cache_key] = loaded
        print(f"OCR model {cache_key} loaded successfully!")
        return loaded
    except Exception as exc:
        print(f"ERROR: Failed to load OCR model {cache_key}: {exc}")
        return None
| |
|
| | |
def get_ner_pipeline(model_id: str):
    """Return a cached HF token-classification pipeline for a registered model id.

    Raises:
        ValueError: if model_id is not a key in NER_MODELS.
    Returns None when loading fails.
    """
    if model_id not in NER_MODELS:
        raise ValueError(f"Unknown NER model ID: {model_id}")

    cached = ner_model_cache.get(model_id)
    if cached is not None:
        print(f"Using cached NER model: {model_id}")
        return cached

    try:
        print(f"Loading NER model: {model_id}...")
        tok = AutoTokenizer.from_pretrained(model_id)
        mdl = AutoModelForTokenClassification.from_pretrained(model_id)

        # aggregation_strategy="simple" merges sub-word pieces into whole entities.
        pipeline_obj = hf_pipeline(
            "ner",
            model=mdl,
            tokenizer=tok,
            aggregation_strategy="simple"
        )
        ner_model_cache[model_id] = pipeline_obj
        print(f"NER model {model_id} loaded successfully!")
        return pipeline_obj
    except Exception as exc:
        print(f"ERROR: Failed to load NER model {model_id}: {exc}")
        return None
| |
|
| |
|
| | def _edit_distance(s1: str, s2: str) -> int: |
| | """Compute Levenshtein edit distance between two strings.""" |
| | if len(s1) < len(s2): |
| | return _edit_distance(s2, s1) |
| | if len(s2) == 0: |
| | return len(s1) |
| |
|
| | prev_row = range(len(s2) + 1) |
| | for i, c1 in enumerate(s1): |
| | curr_row = [i + 1] |
| | for j, c2 in enumerate(s2): |
| | insertions = prev_row[j + 1] + 1 |
| | deletions = curr_row[j] + 1 |
| | substitutions = prev_row[j] + (c1 != c2) |
| | curr_row.append(min(insertions, deletions, substitutions)) |
| | prev_row = curr_row |
| | return prev_row[-1] |
| |
|
| |
|
| | |
| |
|
# Per-entity-type spell-check vocabularies, populated lazily by _build_entity_dicts().
_entity_dicts: dict[str, set] = {}


def _build_entity_dicts():
    """Build per-entity-type dictionaries from already-loaded DRUG_INTERACTIONS and MEDLINEPLUS_MAP."""
    global _entity_dicts

    # Medication vocabulary: split comma-joined drug names, keep terms >= 4 chars.
    med_terms: set[str] = {
        piece
        for drug_name in DRUG_INTERACTIONS
        for piece in (p.strip().lower() for p in str(drug_name).split(','))
        if len(piece) >= 4
    }

    # Lab-test vocabulary: test names plus their aliases, >= 4 chars.
    lab_terms: set[str] = set()
    for test_name, data in MEDLINEPLUS_MAP.items():
        if len(test_name) >= 4:
            lab_terms.add(test_name.lower())
        lab_terms.update(alias.lower() for alias in data.get('aliases', []) if len(alias) >= 4)

    # Several entity labels share the same vocabulary on purpose.
    _entity_dicts = {
        'MEDICATION': med_terms,
        'LAB_VALUE': lab_terms,
        'DIAGNOSTIC_PROCEDURE': lab_terms,
        'TREATMENT': med_terms,
        'CHEM': med_terms,
        'CHEMICAL': med_terms,
    }
    print(f"Entity dicts built: {len(med_terms)} medication terms, {len(lab_terms)} lab terms")
| |
|
| |
|
def _find_closest(word: str, dictionary: set) -> tuple:
    """Return (closest_term, edit_distance) for `word` against `dictionary`.

    Returns (None, 999) when no candidate passes the length pre-filter.
    """
    target = word.lower()
    closest, smallest = None, 999
    for candidate in dictionary:
        # Cheap pre-filter: length difference alone bounds the edit distance.
        if abs(len(candidate) - len(target)) > 3:
            continue
        d = _edit_distance(target, candidate)
        if d < smallest:
            smallest, closest = d, candidate
    return closest, smallest
| |
|
| |
|
| | def _match_case(original: str, replacement: str) -> str: |
| | if original.isupper(): |
| | return replacement.upper() |
| | if original[0].isupper(): |
| | return replacement.capitalize() |
| | return replacement.lower() |
| |
|
| |
|
def correct_with_ner_entities(
    words_with_boxes: list,
    ner_entities: list,
    text: str,
    confidence_threshold: float = 0.75,
) -> dict:
    """Second-pass correction using NER entity labels as context.

    For each NER entity token whose OCR confidence is below
    `confidence_threshold`, look up the closest term (edit distance <= 2)
    in the vocabulary matching the entity type and substitute it into
    `text`, preserving the original token's casing.

    Args:
        words_with_boxes: dicts with at least 'word' and optionally 'confidence'.
        ner_entities: HF pipeline output dicts ('entity_group', 'word').
        text: the OCR text to correct.
        confidence_threshold: tokens at/above this OCR confidence are left alone.

    Returns:
        {'corrected_text': str, 'corrections': [dicts describing each change]}.
    """
    # Vocabularies are built lazily from the module-level data maps.
    if not _entity_dicts:
        _build_entity_dicts()

    # Lowest OCR confidence observed for each (lowercased) word.
    word_conf: dict[str, float] = {}
    for w in words_with_boxes:
        key = w['word'].lower()
        word_conf[key] = min(word_conf.get(key, 1.0), w.get('confidence', 1.0))

    corrections = []
    corrected_text = text

    for entity in ner_entities:
        entity_type = entity.get('entity_group', '')
        entity_word = entity.get('word', '').strip()
        lookup_dict = _entity_dicts.get(entity_type)
        # Only entity types with a vocabulary (see _build_entity_dicts) are corrected.
        if not lookup_dict or not entity_word:
            continue

        for token in entity_word.split():
            # Strip non-letters; short tokens are too risky to auto-correct.
            clean_token = re.sub(r'[^a-zA-Z]', '', token)
            if not clean_token.isalpha() or len(clean_token) < 4:
                continue

            # Unknown words default to confidence 1.0 and are skipped.
            ocr_conf = word_conf.get(clean_token.lower(), 1.0)
            if ocr_conf >= confidence_threshold:
                continue

            best_match, best_dist = _find_closest(clean_token, lookup_dict)
            if best_match is None or best_dist > 2:
                continue
            if best_match.lower() == clean_token.lower():
                continue

            replacement = _match_case(clean_token, best_match)
            # Replace only the FIRST whole-word occurrence in the running text.
            # NOTE(review): with repeated words this may rewrite a different
            # occurrence than the one the entity refers to — TODO confirm.
            match = re.search(r'\b' + re.escape(clean_token) + r'\b',
                              corrected_text, re.IGNORECASE)
            if not match:
                continue

            start, end = match.start(), match.end()
            corrected_text = corrected_text[:start] + replacement + corrected_text[end:]
            corrections.append({
                'original': clean_token,
                'corrected': replacement,
                'confidence': round(1.0 - best_dist / max(len(clean_token), len(best_match)), 4),
                'ocr_confidence': round(ocr_conf, 4),
                'edit_distance': best_dist,
                'source': 'ner',
                'entity_type': entity_type,
            })
            # Mark the replacement as trusted so it is not corrected again.
            word_conf[replacement.lower()] = 1.0

    return {'corrected_text': corrected_text, 'corrections': corrections}
| |
|
| |
|
| | |
def deskew_image(image: np.ndarray) -> np.ndarray:
    """Deskew an image by rotating it against the median detected line angle.

    Uses Canny edges + probabilistic Hough lines; returns the input unchanged
    when no skew is detected or anything fails.
    """
    try:
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image
        edges = cv2.Canny(gray, 50, 150, apertureSize=3)
        lines = cv2.HoughLinesP(edges, 1, np.pi / 180, threshold=100, minLineLength=100, maxLineGap=10)

        if lines is None or len(lines) == 0:
            return image

        # Keep only near-horizontal segments (|angle| < 45 degrees).
        angles = []
        for detected in lines:
            x1, y1, x2, y2 = detected[0]
            theta = np.arctan2(y2 - y1, x2 - x1) * 180 / np.pi
            if abs(theta) < 45:
                angles.append(theta)

        if not angles:
            return image

        median_angle = np.median(angles)
        # Sub-half-degree skew is not worth the interpolation blur.
        if abs(median_angle) <= 0.5:
            return image

        (h, w) = image.shape[:2]
        center = (w // 2, h // 2)
        M = cv2.getRotationMatrix2D(center, median_angle, 1.0)
        return cv2.warpAffine(image, M, (w, h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE)
    except Exception as e:
        print(f"Deskew warning: {e}")
        return image
| |
|
def preprocess_for_doctr(file_content: bytes) -> np.ndarray:
    """Automatic preprocessing pipeline optimized for docTR.

    Steps (order matters): decode -> deskew -> CLAHE contrast boost on the
    lightness channel -> non-local-means denoise -> BGR->RGB conversion.

    Raises:
        ValueError: if the bytes cannot be decoded as an image.
    """
    nparr = np.frombuffer(file_content, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

    if img is None:
        raise ValueError("Failed to decode image")

    img = deskew_image(img)

    # CLAHE on the L channel boosts local contrast without shifting colors.
    lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    lab[:, :, 0] = clahe.apply(lab[:, :, 0])
    img = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)

    # Mild denoise; docTR expects RGB channel order.
    img = cv2.fastNlMeansDenoisingColored(img, None, 6, 6, 7, 21)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    return img
| |
|
def basic_cleanup(text: str) -> str:
    """Collapse every run of whitespace (spaces, tabs, newlines) into one space."""
    return " ".join(text.split())
| |
|
| |
|
| | |
| |
|
| | |
# Lazily-created img2table OCR backends, keyed by engine name.
img2table_ocr_cache = {}


def get_img2table_ocr():
    """Return the shared img2table DocTR OCR instance, creating it on first use."""
    try:
        return img2table_ocr_cache['doctr']
    except KeyError:
        instance = DocTR()
        img2table_ocr_cache['doctr'] = instance
        return instance
| |
|
| |
|
def extract_tables_with_img2table(image_bytes: bytes, img_width: int, img_height: int) -> dict:
    """
    Use img2table to detect and extract table structure from image.
    Returns table data with properly structured cells.

    NOTE(review): img_width/img_height are currently unused here — kept for
    signature parity with the other extractors.

    Returns:
        {'is_table': True, 'cells', 'num_rows', 'num_columns', 'tables',
        'total_tables'} on success; {'is_table': False, ...} otherwise.
    """
    try:
        # img2table reads from a file path, so persist the bytes temporarily.
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
            tmp_file.write(image_bytes)
            tmp_path = tmp_file.name

        img2table_img = Img2TableImage(src=tmp_path)

        ocr = get_img2table_ocr()

        # Implicit rows/columns + borderless detection maximize recall on
        # lab-report-style layouts; min_confidence filters weak OCR words.
        tables = img2table_img.extract_tables(
            ocr=ocr,
            implicit_rows=True,
            implicit_columns=True,
            borderless_tables=True,
            min_confidence=50
        )

        # Best-effort temp-file cleanup.
        # NOTE(review): bare except also swallows KeyboardInterrupt — consider
        # narrowing to OSError.
        try:
            os.unlink(tmp_path)
        except:
            pass

        if not tables:
            return {'is_table': False, 'tables': []}

        all_tables = []
        for table in tables:
            cells = []

            # Preferred path: ExtractedTable exposes a pandas dataframe.
            if hasattr(table, 'df') and table.df is not None:
                df = table.df

                header = [str(col) if col is not None else '' for col in df.columns.tolist()]
                cells.append(header)

                for _, row in df.iterrows():
                    row_cells = [str(val).strip() if val is not None else '' for val in row.tolist()]
                    cells.append(row_cells)

            # Fallback: raw `content` — rows may be lists/tuples of cell
            # objects (with .value or .text) or dicts.
            elif hasattr(table, 'content') and table.content is not None:
                content = table.content
                if isinstance(content, list):
                    for row in content:
                        if isinstance(row, (list, tuple)):
                            row_cells = []
                            for cell in row:
                                if cell is None:
                                    row_cells.append('')
                                elif isinstance(cell, str):
                                    row_cells.append(cell.strip())
                                elif hasattr(cell, 'value'):
                                    row_cells.append(str(cell.value).strip() if cell.value else '')
                                elif hasattr(cell, 'text'):
                                    row_cells.append(str(cell.text).strip() if cell.text else '')
                                else:
                                    row_cells.append(str(cell).strip())
                            cells.append(row_cells)
                        elif isinstance(row, dict):
                            # Dict rows: take values in insertion order.
                            row_cells = [str(v).strip() if v else '' for v in row.values()]
                            cells.append(row_cells)

            # Diagnostic only — private attribute, not parsed.
            elif hasattr(table, '_content'):
                print(f"Table has _content: {type(table._content)}")

            # Require at least a header plus one data row, with some text.
            if cells and len(cells) > 1:
                has_content = any(any(c.strip() for c in row) for row in cells)
                if has_content:
                    num_cols = max(len(row) for row in cells) if cells else 0
                    all_tables.append({
                        'cells': cells,
                        'num_rows': len(cells),
                        'num_columns': num_cols
                    })
                    print(f"Extracted table with {len(cells)} rows and {num_cols} columns")

        if not all_tables:
            print("No valid tables extracted")
            return {'is_table': False, 'tables': []}

        # The "primary" table is the one covering the most cells.
        primary_table = max(all_tables, key=lambda t: t['num_rows'] * t['num_columns'])
        print(f"Primary table: {primary_table['num_rows']}x{primary_table['num_columns']}")

        return {
            'is_table': True,
            'cells': primary_table['cells'],
            'num_rows': primary_table['num_rows'],
            'num_columns': primary_table['num_columns'],
            'tables': all_tables,
            'total_tables': len(all_tables)
        }

    except Exception as e:
        print(f"img2table extraction error: {e}")
        import traceback
        traceback.print_exc()
        return {'is_table': False, 'error': str(e)}
| |
|
| |
|
def format_table_as_markdown(table_data: dict) -> str:
    """Render extracted table data as a column-aligned markdown table.

    Expects {'is_table': bool, 'cells': list[list[str]]}; the first row is
    treated as the header. Returns '' when there is nothing to render.
    """
    if not table_data.get('is_table') or not table_data.get('cells'):
        return ''

    rows = table_data['cells']
    if not rows:
        return ''

    width = max(len(r) for r in rows) if rows else 0
    if width == 0:
        return ''

    # Pad each row to the same column count while tracking per-column
    # display widths (minimum 3 characters).
    widths = [3] * width
    padded_rows = []
    for r in rows:
        padded = list(r) + [''] * (width - len(r))
        padded_rows.append(padded)
        for idx, value in enumerate(padded):
            widths[idx] = max(widths[idx], len(str(value)))

    out_lines = []
    for row_index, padded in enumerate(padded_rows):
        body = ' | '.join(str(value).ljust(widths[idx]) for idx, value in enumerate(padded))
        out_lines.append('| ' + body + ' |')
        # The markdown header separator goes right after the first row.
        if row_index == 0:
            out_lines.append('|' + '|'.join('-' * (w + 2) for w in widths) + '|')

    return '\n'.join(out_lines)
| |
|
| |
|
def extract_text_with_table_detection(image_bytes: bytes, img_width: int, img_height: int) -> tuple:
    """Run img2table extraction and return (markdown_text, table_data).

    When no table is found, returns ('', {'is_table': False}).
    """
    result = extract_tables_with_img2table(image_bytes, img_width, img_height)

    if not result.get('is_table'):
        return '', {'is_table': False}

    return format_table_as_markdown(result), result
| |
|
| |
|
| | |
| |
|
def extract_tables_two_stage(image_bytes: bytes, img_width: int, img_height: int, ocr_predictor) -> dict:
    """
    Two-stage table extraction:
    1. Detect table structure (cells/grid) WITHOUT OCR
    2. Crop each cell and run docTR OCR individually

    This keeps multi-line text together within cells.

    NOTE(review): img_width/img_height are unused (actual dimensions are read
    from the decoded image), and the parameter name `ocr_predictor` shadows
    the doctr import of the same name — confirm intended.
    """
    try:
        # Decode once up front; cell crops are sliced from this array.
        nparr = np.frombuffer(image_bytes, np.uint8)
        img_array = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if img_array is None:
            return {'is_table': False, 'error': 'Failed to decode image'}

        actual_height, actual_width = img_array.shape[:2]

        # img2table reads from a file path, so persist the bytes temporarily.
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
            tmp_file.write(image_bytes)
            tmp_path = tmp_file.name

        img2table_img = Img2TableImage(src=tmp_path)

        # Stage 1: structure only — ocr=None skips text recognition.
        tables = img2table_img.extract_tables(
            ocr=None,
            implicit_rows=True,
            implicit_columns=True,
            borderless_tables=True
        )

        # Best-effort temp-file cleanup.
        # NOTE(review): bare except — consider narrowing to OSError.
        try:
            os.unlink(tmp_path)
        except:
            pass

        if not tables:
            print("Two-stage: No tables detected")
            return {'is_table': False, 'tables': []}

        all_tables = []

        for table_idx, table in enumerate(tables):
            print(f"Two-stage: Processing table {table_idx + 1}")

            if not hasattr(table, 'bbox') or not hasattr(table, 'content'):
                continue

            # NOTE(review): table_bbox and cells_data are assigned but never
            # used below — candidates for removal.
            table_bbox = table.bbox

            cells_data = []

            # Private attribute of img2table's table object: the raw cell
            # list with pixel bboxes — TODO confirm stability across versions.
            if hasattr(table, '_items') and table._items:
                # Group cells into rows by their top-y, with a 10px tolerance.
                rows_dict = {}

                for cell in table._items:
                    if hasattr(cell, 'bbox'):
                        cell_bbox = cell.bbox
                        row_key = cell_bbox[1]

                        # Snap to an existing row if the top edges are close.
                        matched_row = None
                        for existing_row in rows_dict.keys():
                            if abs(existing_row - row_key) < 10:
                                matched_row = existing_row
                                break

                        if matched_row is None:
                            matched_row = row_key
                            rows_dict[matched_row] = []

                        rows_dict[matched_row].append({
                            'bbox': cell_bbox,
                            'x': cell_bbox[0]
                        })

                # Rows top-to-bottom.
                sorted_rows = sorted(rows_dict.items(), key=lambda x: x[0])

                table_cells = []

                for row_y, row_cells in sorted_rows:
                    # Cells left-to-right within the row.
                    row_cells.sort(key=lambda c: c['x'])
                    row_texts = []

                    for cell_info in row_cells:
                        bbox = cell_info['bbox']
                        x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])

                        # Pad the crop slightly, clamped to image bounds.
                        padding = 2
                        x1 = max(0, x1 - padding)
                        y1 = max(0, y1 - padding)
                        x2 = min(actual_width, x2 + padding)
                        y2 = min(actual_height, y2 + padding)

                        cell_img = img_array[y1:y2, x1:x2]

                        if cell_img.size == 0:
                            row_texts.append('')
                            continue

                        # docTR expects RGB channel order.
                        cell_img_rgb = cv2.cvtColor(cell_img, cv2.COLOR_BGR2RGB)

                        # Stage 2: OCR this single cell.
                        cell_text = ocr_single_cell(cell_img_rgb, ocr_predictor)
                        row_texts.append(cell_text)

                    if row_texts:
                        table_cells.append(row_texts)

                if table_cells:
                    # Pad ragged rows out to the widest row.
                    num_cols = max(len(row) for row in table_cells)

                    normalized_cells = []
                    for row in table_cells:
                        normalized_row = row + [''] * (num_cols - len(row))
                        normalized_cells.append(normalized_row)

                    all_tables.append({
                        'cells': normalized_cells,
                        'num_rows': len(normalized_cells),
                        'num_columns': num_cols,
                        'method': 'two_stage'
                    })
                    print(f"Two-stage: Extracted {len(normalized_cells)}x{num_cols} table")

        if not all_tables:
            return {'is_table': False, 'tables': []}

        # The "primary" table is the one covering the most cells.
        primary_table = max(all_tables, key=lambda t: t['num_rows'] * t['num_columns'])

        return {
            'is_table': True,
            'cells': primary_table['cells'],
            'num_rows': primary_table['num_rows'],
            'num_columns': primary_table['num_columns'],
            'tables': all_tables,
            'total_tables': len(all_tables),
            'method': 'two_stage'
        }

    except Exception as e:
        print(f"Two-stage extraction error: {e}")
        import traceback
        traceback.print_exc()
        return {'is_table': False, 'error': str(e), 'method': 'two_stage'}
| |
|
| |
|
def ocr_single_cell(cell_image: np.ndarray, ocr_predictor) -> str:
    """
    OCR one cropped cell image with docTR.

    Returns the cell's text with all detected lines joined by single spaces;
    '' for an empty crop or on any error.
    """
    try:
        if cell_image.size == 0:
            return ''

        # docTR's DocumentFile wants encoded image bytes, so round-trip
        # the array through an in-memory PNG.
        buffer = io.BytesIO()
        Image.fromarray(cell_image).save(buffer, format='PNG')

        document = DocumentFile.from_images([buffer.getvalue()])
        result = ocr_predictor(document)

        # Flatten the page -> block -> line hierarchy into text lines.
        collected = []
        for page in result.pages:
            for block in page.blocks:
                for line in block.lines:
                    text = ' '.join(w.value for w in line.words)
                    if text.strip():
                        collected.append(text.strip())

        return ' '.join(collected)

    except Exception as e:
        print(f"Cell OCR error: {e}")
        return ''
| |
|
| |
|
def extract_text_two_stage(image_bytes: bytes, img_width: int, img_height: int, ocr_predictor) -> tuple:
    """
    Wrapper around the two-stage table extractor.
    Returns (markdown_text, table_data); ('', {'is_table': False, 'method':
    'two_stage'}) when no table is found.
    """
    result = extract_tables_two_stage(image_bytes, img_width, img_height, ocr_predictor)

    if not result.get('is_table'):
        return '', {'is_table': False, 'method': 'two_stage'}

    return format_table_as_markdown(result), result
| |
|
| |
|
| | |
| |
|
def extract_tables_borderless(doctr_result, min_columns: int = 2, min_rows: int = 2) -> dict:
    """
    Detect borderless tables by analyzing text positions from docTR.
    Works when there are no visible grid lines - uses whitespace gaps to infer structure.

    Algorithm:
    1. Collect all words with positions
    2. Find column boundaries by detecting consistent vertical gaps
    3. Group words into rows by y-position clustering
    4. Handle multi-line cells by merging text within same cell bounds

    Returns {'is_table': True, 'cells', ...} or {'is_table': False, 'reason'/'error'}.
    """
    try:
        # 1. Flatten docTR's page/block/line hierarchy into one word list.
        # word.geometry is ((x_min, y_min), (x_max, y_max)) in normalized 0..1.
        all_words = []
        for page in doctr_result.pages:
            for block in page.blocks:
                for line in block.lines:
                    for word in line.words:
                        x_min, y_min = word.geometry[0]
                        x_max, y_max = word.geometry[1]
                        all_words.append({
                            'text': word.value,
                            'x_min': x_min,
                            'x_max': x_max,
                            'y_min': y_min,
                            'y_max': y_max,
                            'x_center': (x_min + x_max) / 2,
                            'y_center': (y_min + y_max) / 2,
                            'height': y_max - y_min
                        })

        # A table needs at least a 2x2-ish amount of text.
        if len(all_words) < 4:
            return {'is_table': False, 'reason': 'Too few words'}

        print(f"Borderless: Analyzing {len(all_words)} words")

        # 2. Infer column spans from consistent vertical whitespace gaps.
        columns = detect_column_boundaries(all_words)

        if len(columns) < min_columns:
            print(f"Borderless: Only {len(columns)} columns detected, need {min_columns}")
            return {'is_table': False, 'reason': f'Only {len(columns)} columns found'}

        print(f"Borderless: Detected {len(columns)} columns")

        # 3. Cluster words into rows by vertical position.
        rows = detect_row_boundaries(all_words)

        if len(rows) < min_rows:
            print(f"Borderless: Only {len(rows)} rows detected, need {min_rows}")
            return {'is_table': False, 'reason': f'Only {len(rows)} rows found'}

        print(f"Borderless: Detected {len(rows)} rows")

        # 4. Assign words to (row, column) cells.
        cells = build_table_cells(all_words, columns, rows)

        # Reject sparse grids: a real table should have most cells populated.
        non_empty_cells = sum(1 for row in cells for cell in row if cell.strip())
        total_cells = len(cells) * len(columns)
        fill_ratio = non_empty_cells / total_cells if total_cells > 0 else 0

        if fill_ratio < 0.3:
            print(f"Borderless: Low fill ratio {fill_ratio:.2f}, probably not a table")
            return {'is_table': False, 'reason': f'Low fill ratio: {fill_ratio:.2f}'}

        print(f"Borderless: Built {len(cells)}x{len(columns)} table with {fill_ratio:.2f} fill ratio")

        return {
            'is_table': True,
            'cells': cells,
            'num_rows': len(cells),
            'num_columns': len(columns),
            'method': 'borderless',
            'fill_ratio': fill_ratio
        }

    except Exception as e:
        print(f"Borderless extraction error: {e}")
        import traceback
        traceback.print_exc()
        return {'is_table': False, 'error': str(e), 'method': 'borderless'}
| |
|
| |
|
def detect_column_boundaries(words: list, min_gap: float = 0.03) -> list:
    """
    Detect column boundaries from consistent vertical whitespace gaps.

    Returns a list of (x_start, x_end) tuples in normalized 0..1 coordinates,
    one per column. Falls back to left-edge clustering when no gap is
    confirmed by at least two rows.
    """
    if not words:
        return []

    left_edges = sorted({w['x_min'] for w in words})

    if len(left_edges) < 2:
        return [(0, 1)]

    # A candidate separator is a gap >= min_gap between consecutive left
    # edges, confirmed by text on both sides in at least two rows.
    separators = []
    for prev_x, next_x in zip(left_edges, left_edges[1:]):
        width_between = next_x - prev_x
        if width_between < min_gap:
            continue
        midpoint = (prev_x + next_x) / 2
        if count_rows_with_gap(words, midpoint, width_between * 0.5) >= 2:
            separators.append(midpoint)

    if not separators:
        # No clear whitespace separators; try alignment clustering instead.
        return cluster_columns_by_alignment(words, min_gap)

    separators = sorted(set(separators))

    # Turn separator x-positions into contiguous (start, end) column spans.
    spans = []
    left = 0
    for sep in separators:
        spans.append((left, sep))
        left = sep
    spans.append((left, 1.0))

    return spans
| |
|
| |
|
def count_rows_with_gap(words: list, gap_x: float, tolerance: float) -> int:
    """Count rows that have text on both sides of the vertical line at gap_x.

    Rows are formed by snapping each word's y_center to a 1/20 grid.
    """
    rows_by_y = {}
    for entry in words:
        bucket = round(entry['y_center'] * 20) / 20
        rows_by_y.setdefault(bucket, []).append(entry)

    # A row "has the gap" when some word ends left of gap_x - tolerance AND
    # some word starts right of gap_x + tolerance.
    total = 0
    for members in rows_by_y.values():
        left_side = any(m['x_max'] < gap_x - tolerance for m in members)
        right_side = any(m['x_min'] > gap_x + tolerance for m in members)
        if left_side and right_side:
            total += 1

    return total
| |
|
| |
|
def cluster_columns_by_alignment(words: list, min_gap: float) -> list:
    """
    Derive column spans by clustering word left-edges that sit close together.
    Fallback used when whitespace-gap detection finds no clear separators.
    Returns a list of (x_start, x_end) spans.
    """
    # Greedy 1-D clustering over sorted left edges: start a new cluster
    # whenever the jump to the next edge exceeds min_gap.
    edges = sorted(w['x_min'] for w in words)

    clusters = []
    run = [edges[0]]
    for edge in edges[1:]:
        if edge - run[-1] <= min_gap:
            run.append(edge)
        else:
            clusters.append(run)
            run = [edge]
    clusters.append(run)

    if len(clusters) < 2:
        # Everything aligns to one column: treat the page as a single span.
        return [(0, 1)]

    # Each span starts slightly left of its cluster and ends midway to the
    # next cluster (the last one extends to the page edge).
    spans = []
    last = len(clusters) - 1
    for idx, members in enumerate(clusters):
        begin = max(0, min(members) - 0.01)
        if idx < last:
            end = (max(members) + min(clusters[idx + 1])) / 2
        else:
            end = 1.0
        spans.append((begin, min(1, end)))

    return spans
| |
|
| |
|
def detect_row_boundaries(words: list, y_tolerance: float = 0.02) -> list:
    """
    Group words into table rows by vertical position.

    A word joins the current row when it vertically overlaps the most
    recently placed word by more than 30% of the shorter word's height,
    or their y-centers differ by less than y_tolerance.

    Returns a list of (y_start, y_end, row_words) tuples.
    """
    if not words:
        return []

    def flush(group, out):
        # Record a finished row together with its vertical extent.
        top = min(w['y_min'] for w in group)
        bottom = max(w['y_max'] for w in group)
        out.append((top, bottom, group))

    ordered = sorted(words, key=lambda w: w['y_min'])

    rows = []
    group = [ordered[0]]

    for word in ordered[1:]:
        anchor = group[-1]

        # Overlap is measured against the last word placed in the row.
        overlap = min(word['y_max'], anchor['y_max']) - max(word['y_min'], anchor['y_min'])
        shorter = min(word['height'], anchor['height'])

        same_row = (overlap > shorter * 0.3
                    or abs(word['y_center'] - anchor['y_center']) < y_tolerance)

        if same_row:
            group.append(word)
        else:
            flush(group, rows)
            group = [word]

    if group:
        flush(group, rows)

    return rows
| |
|
| |
|
def build_table_cells(words: list, columns: list, rows: list) -> list:
    """
    Assign words to (row, column) cells and return the table as a list of
    rows, each a list of cell strings. Words sharing a cell are joined with
    spaces in left-to-right order, so multi-word cell text is preserved.

    Args:
        words: unused here (each row already carries its words); kept for
            interface compatibility with existing callers.
        columns: list of (x_start, x_end) spans, ordered left to right.
        rows: list of (y_min, y_max, row_words) tuples.
    """
    num_cols = len(columns)
    table = []

    for _row_y_min, _row_y_max, row_words in rows:
        row_cells = [''] * num_cols

        # Left-to-right order so concatenated cell text reads naturally.
        for word in sorted(row_words, key=lambda w: w['x_min']):
            word_x = word['x_min']

            # Find the column whose half-open span contains the word's left edge.
            col_idx = None
            for i, (col_start, col_end) in enumerate(columns):
                if col_start <= word_x < col_end:
                    col_idx = i
                    break

            if col_idx is None:
                # Fix: a word outside every span (e.g. x_min exactly 1.0, or a
                # coordinate landing in a gap between spans) was previously
                # dropped silently. Snap it to the nearest column instead.
                if num_cols == 0:
                    continue
                col_idx = min(
                    range(num_cols),
                    key=lambda i: min(abs(word_x - columns[i][0]),
                                      abs(word_x - columns[i][1])),
                )

            if row_cells[col_idx]:
                row_cells[col_idx] += ' ' + word['text']
            else:
                row_cells[col_idx] = word['text']

        table.append(row_cells)

    return table
| |
|
| |
|
def extract_text_borderless(doctr_result) -> tuple:
    """
    Run borderless table extraction and render the result as markdown.
    Returns (markdown_text, table_data); markdown_text is '' when no table
    was detected.
    """
    table_data = extract_tables_borderless(doctr_result)

    if not table_data.get('is_table'):
        return '', {'is_table': False, 'method': 'borderless'}

    return format_table_as_markdown(table_data), table_data
| |
|
| |
|
| | |
| |
|
def extract_tables_block_geometry(doctr_result, min_columns: int = 2, min_rows: int = 2) -> dict:
    """
    Detect tables using docTR's block-level grouping from .export().
    If multiple blocks exist at similar y-positions but different x-positions,
    they likely represent table columns.

    Only the first page of the export is analyzed.

    Args:
        doctr_result: docTR OCR result object exposing .export().
        min_columns: minimum blocks per row for the row to count as tabular.
        min_rows: minimum number of consistent tabular rows to accept a table.

    Returns:
        On success: {'is_table': True, 'cells', 'num_rows', 'num_columns',
        'method': 'block_geometry', 'fill_ratio'}.
        On failure: {'is_table': False, 'method': 'block_geometry'} plus a
        human-readable 'reason' (or 'error' when an exception occurred).
    """
    try:
        exported = doctr_result.export()

        if not exported or 'pages' not in exported or not exported['pages']:
            return {'is_table': False, 'reason': 'No pages in export', 'method': 'block_geometry'}

        # Only the first page is considered.
        page = exported['pages'][0]
        blocks = page.get('blocks', [])

        if len(blocks) < 2:
            return {'is_table': False, 'reason': f'Only {len(blocks)} blocks found', 'method': 'block_geometry'}

        print(f"Block-geometry: Analyzing {len(blocks)} blocks")

        # Collect each block's concatenated text and bounding geometry
        # (geometry[0] = top-left, geometry[1] = bottom-right).
        block_data = []
        for block in blocks:
            geometry = block.get('geometry', [])
            if len(geometry) < 2:
                continue

            x_min, y_min = geometry[0]
            x_max, y_max = geometry[1]

            # Join all words of all lines into one text string per block.
            block_text_parts = []
            for line in block.get('lines', []):
                line_words = []
                for word in line.get('words', []):
                    line_words.append(word.get('value', ''))
                if line_words:
                    block_text_parts.append(' '.join(line_words))

            block_text = ' '.join(block_text_parts).strip()

            if block_text:
                block_data.append({
                    'text': block_text,
                    'x_min': x_min,
                    'x_max': x_max,
                    'y_min': y_min,
                    'y_max': y_max,
                    'y_center': (y_min + y_max) / 2,
                    'x_center': (x_min + x_max) / 2,
                    'height': y_max - y_min,
                })

        if len(block_data) < min_columns:
            return {'is_table': False, 'reason': f'Only {len(block_data)} text blocks', 'method': 'block_geometry'}

        # Group blocks into rows: a block joins the current row when its
        # vertical span overlaps the previously placed block by more than
        # 30% of the shorter block's height.
        block_data.sort(key=lambda b: b['y_min'])

        rows = []
        current_row = [block_data[0]]

        for i in range(1, len(block_data)):
            block = block_data[i]
            prev_block = current_row[-1]

            y_overlap = min(block['y_max'], prev_block['y_max']) - max(block['y_min'], prev_block['y_min'])
            min_height = min(block['height'], prev_block['height'])

            if min_height > 0 and y_overlap / min_height > 0.3:
                current_row.append(block)
            else:
                rows.append(current_row)
                current_row = [block]

        if current_row:
            rows.append(current_row)

        print(f"Block-geometry: Found {len(rows)} potential rows")

        # Table evidence: enough rows that contain multiple side-by-side blocks.
        multi_block_rows = [row for row in rows if len(row) >= min_columns]

        if len(multi_block_rows) < min_rows:
            print(f"Block-geometry: Only {len(multi_block_rows)} multi-block rows, need {min_rows}")
            return {'is_table': False, 'reason': f'Only {len(multi_block_rows)} multi-block rows', 'method': 'block_geometry'}

        # Keep only rows with the dominant column count, for a consistent grid.
        col_counts = [len(row) for row in multi_block_rows]
        most_common_count = max(set(col_counts), key=col_counts.count)
        consistent_rows = [row for row in multi_block_rows if len(row) == most_common_count]

        if len(consistent_rows) < min_rows:
            print(f"Block-geometry: Only {len(consistent_rows)} rows with {most_common_count} columns")
            return {'is_table': False, 'reason': 'Inconsistent column counts', 'method': 'block_geometry'}

        print(f"Block-geometry: {len(consistent_rows)} rows with {most_common_count} columns")

        # Build cell texts, ordering each row's blocks left-to-right.
        table_cells = []
        for row in consistent_rows:
            row_sorted = sorted(row, key=lambda b: b['x_min'])
            row_texts = [b['text'] for b in row_sorted]
            table_cells.append(row_texts)

        # Pad ragged rows so every row has the same number of columns.
        max_cols = max(len(row) for row in table_cells) if table_cells else 0
        normalized_cells = []
        for row in table_cells:
            normalized_row = row + [''] * (max_cols - len(row))
            normalized_cells.append(normalized_row)

        # Reject sparse grids: require at least 30% of cells to be non-empty.
        non_empty = sum(1 for row in normalized_cells for cell in row if cell.strip())
        total = len(normalized_cells) * max_cols
        fill_ratio = non_empty / total if total > 0 else 0

        if fill_ratio < 0.3:
            print(f"Block-geometry: Low fill ratio {fill_ratio:.2f}")
            return {'is_table': False, 'reason': f'Low fill ratio: {fill_ratio:.2f}', 'method': 'block_geometry'}

        print(f"Block-geometry: Built {len(normalized_cells)}x{max_cols} table with {fill_ratio:.2f} fill ratio")

        return {
            'is_table': True,
            'cells': normalized_cells,
            'num_rows': len(normalized_cells),
            'num_columns': max_cols,
            'method': 'block_geometry',
            'fill_ratio': fill_ratio,
        }

    except Exception as e:
        print(f"Block-geometry extraction error: {e}")
        import traceback
        traceback.print_exc()
        return {'is_table': False, 'error': str(e), 'method': 'block_geometry'}
| |
|
| |
|
def extract_text_block_geometry(doctr_result) -> tuple:
    """
    Run block-geometry table extraction and render the result as markdown.
    Returns (markdown_text, table_data); markdown_text is '' when no table
    was detected.
    """
    table_data = extract_tables_block_geometry(doctr_result)

    if not table_data.get('is_table'):
        return '', {'is_table': False, 'method': 'block_geometry'}

    return format_table_as_markdown(table_data), table_data
| |
|
| |
|
def extract_text_structured(result) -> str:
    """
    Extract text from a docTR result while preserving reading order.

    Words are ordered left-to-right within each line, lines are ordered
    top-to-bottom, and lines whose quantized y-position matches are merged
    into one physical output line.
    """
    collected = []

    for page in result.pages:
        for block in page.blocks:
            for line in block.lines:
                # Capture each word's text and top-left corner position.
                entries = [
                    {'text': w.value, 'x': w.geometry[0][0], 'y': w.geometry[0][1]}
                    for w in line.words
                ]
                if not entries:
                    continue

                # Reading order within a line is left to right.
                entries.sort(key=lambda e: e['x'])

                text = " ".join(e['text'] for e in entries)
                if not text.strip():
                    continue

                collected.append({
                    'text': text.strip(),
                    # Average y anchors the line vertically; min x breaks ties.
                    'y': sum(e['y'] for e in entries) / len(entries),
                    'x': min(e['x'] for e in entries),
                })

    # Quantize y to 1/20 steps so fragments on the same visual line group.
    collected.sort(key=lambda l: (round(l['y'] * 20) / 20, l['x']))

    # Merge consecutive entries sharing a quantized y into one output line.
    output_rows = []
    parts = []
    active_key = None
    for item in collected:
        key = round(item['y'] * 20) / 20
        if parts and key != active_key:
            output_rows.append(" ".join(parts))
            parts = []
        parts.append(item['text'])
        active_key = key
    if parts:
        output_rows.append(" ".join(parts))

    return "\n".join(output_rows)
| |
|
def generate_synthesized_image(doctr_result) -> Optional[str]:
    """
    Render docTR's synthesized (reconstructed) document as a base64 PNG.
    Only the first synthesized page is used. Returns None on any failure.
    """
    try:
        pages = doctr_result.synthesize()

        if not pages or len(pages) == 0:
            print("Synthesize: No pages returned")
            return None

        # Convert the first page (array) to PNG bytes via PIL.
        buffer = io.BytesIO()
        Image.fromarray(pages[0]).save(buffer, format='PNG')

        encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
        print(f"Synthesize: Generated image ({len(encoded)} chars base64)")
        return encoded

    except Exception as e:
        print(f"Synthesize error: {e}")
        return None
| |
|
| |
|
def extract_words_with_boxes(result) -> list:
    """
    Flatten a docTR result into a list of word records.
    Each record is {'word', 'confidence', 'bbox'}, with bbox given as
    [[x0, y0], [x1, y1]] in normalized 0-1 page coordinates.
    """
    records = []

    for page in result.pages:
        for block in page.blocks:
            for line in block.lines:
                for word in line.words:
                    # geometry[0] = top-left corner, geometry[1] = bottom-right.
                    (x0, y0), (x1, y1) = word.geometry[0], word.geometry[1]
                    records.append({
                        'word': word.value,
                        'confidence': word.confidence,
                        'bbox': [[x0, y0], [x1, y1]],
                    })

    return records
| |
|
def check_drug_interactions(detected_drugs: List[str]) -> List[Dict]:
    """
    Check for known interactions between detected drugs.

    Compares every unordered pair of drugs (case-insensitively) against the
    module-level DRUG_INTERACTIONS table, which maps
    drug -> {other_drug: {'severity', 'description', 'recommendation'}}.
    The pair may be listed under either drug's key.

    Args:
        detected_drugs: drug names as they appeared in the document.

    Returns:
        List of warning dicts {'drug1', 'drug2', 'severity', 'description',
        'recommendation'}, reporting names in their original document casing
        (the table's key drug is always reported as 'drug1').
    """
    interactions = []
    drugs_lower = [d.lower().strip() for d in detected_drugs]

    # Iterate every unordered pair. Index j refers directly into
    # detected_drugs so the original casing is reported correctly even when
    # the same drug name appears more than once (the previous
    # drugs_lower.index(...) lookup always resolved to the first occurrence).
    for i, drug1 in enumerate(drugs_lower):
        for j in range(i + 1, len(drugs_lower)):
            drug2 = drugs_lower[j]

            if drug1 in DRUG_INTERACTIONS and drug2 in DRUG_INTERACTIONS[drug1]:
                interaction = DRUG_INTERACTIONS[drug1][drug2]
                first, second = detected_drugs[i], detected_drugs[j]
            elif drug2 in DRUG_INTERACTIONS and drug1 in DRUG_INTERACTIONS[drug2]:
                # Pair listed under the second drug's key; report it first.
                interaction = DRUG_INTERACTIONS[drug2][drug1]
                first, second = detected_drugs[j], detected_drugs[i]
            else:
                continue

            interactions.append({
                'drug1': first,
                'drug2': second,
                'severity': interaction.get('severity', 'info'),
                'description': interaction.get('description', ''),
                'recommendation': interaction.get('recommendation'),
            })

    return interactions
| |
|
| | |
| |
|
def parse_reference_range(range_str: str):
    """
    Parse a reference-range string from a lab document into numeric bounds.
    Accepted forms: "(13.5 - 18.0)", "(< 200)", "(> 60)", with or without
    parentheses; also the Unicode operators and en dash.
    Returns (low, high); either side is None when unbounded or unparseable.
    """
    if not range_str:
        return None, None

    text = range_str.strip().strip('()').strip()

    # Upper-bound only: "< 200" or Unicode "less-or-equal".
    upper = re.match(r'^[<\u2264]\s*(\d+\.?\d*)$', text)
    if upper:
        return None, float(upper.group(1))

    # Lower-bound only: "> 60" or Unicode "greater-or-equal".
    lower = re.match(r'^[>\u2265]\s*(\d+\.?\d*)$', text)
    if lower:
        return float(lower.group(1)), None

    # Closed interval: "13.5 - 18.0" (hyphen or en dash separator).
    span = re.match(r'(\d+\.?\d*)\s*[-\u2013]\s*(\d+\.?\d*)', text)
    if span:
        return float(span.group(1)), float(span.group(2))

    return None, None
| |
|
| |
|
def extract_lab_values_from_words(words_with_boxes: List[Dict]) -> List[Dict]:
    """
    Extract lab values using word positions from docTR.
    Groups words into rows by y-coordinate, then identifies columns
    (test name, value, unit, range) by x-position within each row.
    This is the most reliable method since it uses spatial layout.

    Args:
        words_with_boxes: [{'word', 'confidence', 'bbox'}] records (bbox
            normalized 0-1), as produced by extract_words_with_boxes.

    Returns:
        [{'test_name', 'value', 'unit', 'ref_low', 'ref_high',
          'ref_range_str', 'is_flagged_in_document'}] — only rows that have
        a name, a numeric value, and at least one reference bound.
    """
    extracted = []
    if not words_with_boxes:
        return extracted

    # --- Group words into rows by vertical proximity. ---
    # A word joins the current row when its y-center is within ROW_TOLERANCE
    # of the row's running-average y; each finished row is sorted by x.
    ROW_TOLERANCE = 0.015
    rows = []
    sorted_words = sorted(words_with_boxes, key=lambda w: (w['bbox'][0][1], w['bbox'][0][0]))

    current_row = []
    current_y = None

    for word_info in sorted_words:
        y_center = (word_info['bbox'][0][1] + word_info['bbox'][1][1]) / 2
        if current_y is None or abs(y_center - current_y) < ROW_TOLERANCE:
            current_row.append(word_info)
            if current_y is None:
                current_y = y_center
            else:
                # Running average keeps the row anchor stable across the row.
                current_y = (current_y + y_center) / 2
        else:
            if current_row:
                rows.append(sorted(current_row, key=lambda w: w['bbox'][0][0]))
            current_row = [word_info]
            current_y = y_center

    if current_row:
        rows.append(sorted(current_row, key=lambda w: w['bbox'][0][0]))

    # Recognized measurement units (lowercase).
    UNITS = {'mg/dl', 'mmol/l', 'g/dl', 'u/l', 'miu/l', 'ng/dl', 'pg/ml',
             'ug/dl', 'ng/ml', 'fl', 'pg', '%', 'mm/hr', 'mg/l', 'mg/mmol',
             'ug/l', 'ml/min/1.73m2'}

    # Header/label tokens that must not end up in a test name.
    SKIP_WORDS = {'result', 'unit', 'ref.range', 'ref', 'range', 'reference',
                  'date', 'request', 'no', 'no:'}

    # --- Classify each row's tokens into name / value / unit / range. ---
    for row in rows:
        words_text = [w['word'] for w in row]
        row_str = ' '.join(words_text).lower()

        # Skip header rows and short section headings ("... profile/function").
        if 'result' in row_str and ('unit' in row_str or 'ref' in row_str):
            continue
        if 'profile' in row_str and len(words_text) <= 3:
            continue
        if 'function' in row_str and len(words_text) <= 3:
            continue

        # Per-row accumulators for the parsed fields.
        name_parts = []
        value = None
        unit = ''
        range_parts = []
        is_flagged = False
        in_range = False  # True while inside a "(...)" reference range

        for w in row:
            word = w['word'].strip()
            word_lower = word.lower().strip('()')

            if not word:
                continue

            # Reference range: collect every token from '(' through ')'.
            if '(' in word or in_range:
                in_range = True
                range_parts.append(word)
                if ')' in word:
                    in_range = False
                continue

            # A lone '*' marks the value as flagged in the document.
            if word == '*':
                is_flagged = True
                continue

            # Unit token. NOTE(review): the replace-chain clause also
            # consumes tokens composed only of the characters of
            # 'ml/min/1.73m2' (presumably OCR fragments of that unit)
            # without recording them as the unit — confirm intent.
            if word_lower in UNITS or word_lower.replace('/', '').replace('.', '').replace('1', '').replace('3', '').replace('7', '').replace('m', '').replace('2', '') == '':
                cleaned_unit = word_lower
                if cleaned_unit in UNITS:
                    unit = word
                continue

            # Scientific-notation cell-count units, e.g. "x10^9/L".
            if 'x10' in word_lower or '10⁹' in word or '10¹²' in word:
                unit = word
                continue

            # Numeric result value; the first number in the row wins, and a
            # leading '*' on the number marks the result as flagged.
            cleaned_word = word.lstrip('*').strip()
            try:
                num = float(cleaned_word)
                if value is None:
                    value = num
                if '*' in word:
                    is_flagged = True
                continue
            except ValueError:
                pass

            # Drop known header/label words.
            if word_lower in SKIP_WORDS:
                continue

            # Drop tokens made entirely of CJK characters (bilingual reports).
            if all('\u4e00' <= c <= '\u9fff' or c in '()()' for c in word):
                continue

            # Anything alphabetic surviving the filters is part of the name.
            if any(c.isalpha() for c in word):
                name_parts.append(word)

        # Parse the collected "(low - high)" text into numeric bounds.
        range_str = ' '.join(range_parts).strip('() ')
        ref_low, ref_high = parse_reference_range(range_str)

        test_name = ' '.join(name_parts).strip()

        # Keep the row only if it has a name, a value, and a usable range.
        if test_name and value is not None and (ref_low is not None or ref_high is not None):
            # Long ALL-CAPS names are treated as section headings, not tests.
            if test_name.upper() == test_name and len(test_name.split()) > 2:
                continue

            extracted.append({
                'test_name': test_name,
                'value': value,
                'unit': unit,
                'ref_low': ref_low,
                'ref_high': ref_high,
                'ref_range_str': range_str,
                'is_flagged_in_document': is_flagged,
            })

    return extracted
| |
|
| |
|
def extract_lab_values_from_text(structured_text: str) -> List[Dict]:
    """
    Extract test name, value, unit, and reference range from OCR structured text.
    Handles the document format: TestName [ChineseName] Result Unit (Range)

    Regex-based fallback for when spatial/table extraction is unavailable.
    Only lines that end with a parseable "(range)" are considered.

    Args:
        structured_text: newline-separated text from extract_text_structured.

    Returns:
        [{'test_name', 'value', 'unit', 'ref_low', 'ref_high',
          'ref_range_str', 'is_flagged_in_document'}]
    """
    extracted = []
    if not structured_text:
        return extracted

    lines = structured_text.split('\n')
    for line in lines:
        line = line.strip()
        if not line or len(line) < 5:
            continue

        # Reference range "(...)" at the end of the line; strip it off the
        # line so later matches operate on the remaining text.
        range_match = re.search(r'\(([<>\u2264\u2265]?\s*\d+\.?\d*(?:\s*[-\u2013]\s*\d+\.?\d*)?)\)\s*$', line)
        ref_range_str = None
        ref_low, ref_high = None, None
        if range_match:
            ref_range_str = range_match.group(1)
            ref_low, ref_high = parse_reference_range(ref_range_str)
            line = line[:range_match.start()].strip()

        # Without a usable range there is nothing to classify against.
        if ref_low is None and ref_high is None:
            continue

        # Preferred pattern: "value unit" with a recognized unit; a leading
        # '*' (flag marker) is tolerated before the number.
        value_match = re.search(r'\*?\s*(\d+\.?\d*)\s+(mg/dL|mmol/L|g/dL|U/L|mIU/L|ng/dL|pg/mL|ug/dL|ng/mL|fL|pg|%|mm/hr|mg/L|mg/mmol|x10\^?\d+/L|mL/min/1\.73m2|ug/L)', line, re.IGNORECASE)

        if not value_match:
            # Fallback 1: a decimal number at the very end of the line.
            value_match = re.search(r'\*?\s*(\d+\.?\d+)\s*$', line)
            if not value_match:
                # Fallback 2: a number following CJK text or whitespace.
                value_match = re.search(r'(?:[\u4e00-\u9fff\s]+|\s+)\*?\s*(\d+\.?\d*)\s', line)

        if not value_match:
            continue

        try:
            value = float(value_match.group(1))
        except (ValueError, IndexError):
            continue

        # Only the unit-aware pattern has a second capture group.
        unit = ''
        if value_match.lastindex and value_match.lastindex >= 2:
            unit = value_match.group(2)

        # Test name: everything before the first CJK character, or before the
        # matched value when the line has no CJK text.
        chinese_start = re.search(r'[\u4e00-\u9fff]', line)
        if chinese_start:
            test_name = line[:chinese_start.start()].strip()
        else:
            test_name = line[:value_match.start()].strip()

        test_name = test_name.strip().rstrip(':').strip()

        # A '*' anywhere before the end of the value marks it as flagged.
        is_flagged = '*' in line[:value_match.end()]

        if test_name and len(test_name) >= 2:
            extracted.append({
                'test_name': test_name,
                'value': value,
                'unit': unit,
                'ref_low': ref_low,
                'ref_high': ref_high,
                'ref_range_str': ref_range_str or '',
                'is_flagged_in_document': is_flagged,
            })

    return extracted
| |
|
| |
|
def extract_lab_values_from_table(table_data: Dict) -> List[Dict]:
    """
    Extract lab values from structured table data.
    table_data has 'cells' (list of rows, each row is list of cell strings).

    Two strategies: when the header row reveals which columns hold the test
    name and the result, cells are read positionally; otherwise each cell of
    each row is classified by its content (number, unit, "(range)", name).

    Returns:
        [{'test_name', 'value', 'unit', 'ref_low', 'ref_high',
          'ref_range_str', 'is_flagged_in_document'}] — only rows with at
        least one parseable reference bound.
    """
    extracted = []
    cells = table_data.get('cells', [])
    if not cells or len(cells) < 2:
        return extracted

    # Locate the relevant columns from header-row keywords.
    header_row = cells[0] if cells else []
    name_col = -1
    value_col = -1
    unit_col = -1
    range_col = -1

    for i, cell in enumerate(header_row):
        cell_lower = cell.strip().lower()
        if any(kw in cell_lower for kw in ['test', 'name', 'parameter', 'investigation']):
            name_col = i
        elif 'result' in cell_lower:
            value_col = i
        elif 'unit' in cell_lower:
            unit_col = i
        elif any(kw in cell_lower for kw in ['ref', 'range', 'normal', 'reference']):
            range_col = i

    # Content-based fallback when the header did not reveal the layout.
    if name_col == -1 or value_col == -1:
        for row in cells[1:]:
            if len(row) < 2:
                continue

            test_name = None
            value = None
            unit = ''
            ref_range_str = ''
            is_flagged = False

            for cell in row:
                cell_stripped = cell.strip()
                if not cell_stripped:
                    continue

                # A cell that is exactly a "(range)" expression.
                range_m = re.match(r'^\(([<>\u2264\u2265]?\s*\d+\.?\d*(?:\s*[-\u2013]\s*\d+\.?\d*)?)\)$', cell_stripped)
                if range_m:
                    ref_range_str = range_m.group(1)
                    continue

                # Numeric result cell; only accepted once a name was seen,
                # and only the first number is kept.
                val_m = re.match(r'^\*?\s*(\d+\.?\d*)$', cell_stripped)
                if val_m and test_name is not None and value is None:
                    value = float(val_m.group(1))
                    is_flagged = '*' in cell_stripped
                    continue

                # Known measurement units.
                if cell_stripped.lower() in ['mg/dl', 'mmol/l', 'g/dl', 'u/l', 'miu/l', 'ng/dl',
                                             'pg/ml', 'ug/dl', 'ng/ml', 'fl', 'pg', '%', 'mm/hr',
                                             'mg/l', 'mg/mmol', 'ug/l', 'ml/min/1.73m2']:
                    unit = cell_stripped
                    continue

                # Scientific-notation cell-count units, e.g. "x10^9/L".
                if re.match(r'x10\^?\d+/L', cell_stripped, re.IGNORECASE):
                    unit = cell_stripped
                    continue

                # The first alphabetic, non-CJK-only cell becomes the name.
                if any(c.isalpha() for c in cell_stripped) and test_name is None:
                    if not all('\u4e00' <= c <= '\u9fff' or c.isspace() for c in cell_stripped):
                        test_name = cell_stripped

            if test_name and value is not None and ref_range_str:
                ref_low, ref_high = parse_reference_range(ref_range_str)
                if ref_low is not None or ref_high is not None:
                    extracted.append({
                        'test_name': test_name,
                        'value': value,
                        'unit': unit,
                        'ref_low': ref_low,
                        'ref_high': ref_high,
                        'ref_range_str': ref_range_str,
                        'is_flagged_in_document': is_flagged,
                    })
        return extracted

    # Positional extraction using the header-derived column indices.
    for row in cells[1:]:
        if len(row) <= max(name_col, value_col):
            continue

        test_name = row[name_col].strip() if name_col < len(row) else ''
        value_str = row[value_col].strip() if value_col < len(row) else ''
        unit = row[unit_col].strip() if unit_col >= 0 and unit_col < len(row) else ''
        ref_range_str = row[range_col].strip().strip('()') if range_col >= 0 and range_col < len(row) else ''

        if not test_name or not value_str:
            continue

        # A '*' in the result cell marks the value as flagged.
        is_flagged = '*' in value_str
        val_m = re.search(r'(\d+\.?\d*)', value_str)
        if not val_m:
            continue

        value = float(val_m.group(1))
        ref_low, ref_high = parse_reference_range(ref_range_str)

        if ref_low is not None or ref_high is not None:
            extracted.append({
                'test_name': test_name,
                'value': value,
                'unit': unit,
                'ref_low': ref_low,
                'ref_high': ref_high,
                'ref_range_str': ref_range_str,
                'is_flagged_in_document': is_flagged,
            })

    return extracted
| |
|
| |
|
def classify_lab_value(value: float, ref_low, ref_high) -> str:
    """
    Classify a lab value relative to its reference range.

    Values below 70% of the lower bound, or above 150% of the upper bound,
    are treated as critical. A missing bound never triggers on that side.

    Returns one of: 'critical_low', 'low', 'normal', 'high', 'critical_high'.
    """
    below = ref_low is not None and value < ref_low
    if below:
        return 'critical_low' if value < ref_low * 0.7 else 'low'

    above = ref_high is not None and value > ref_high
    if above:
        return 'critical_high' if value > ref_high * 1.5 else 'high'

    return 'normal'
| |
|
| |
|
def match_test_to_medlineplus(test_name: str) -> Optional[Dict]:
    """
    Resolve an OCR'd test name to a MedlinePlus map entry.

    Matching cascade, in order: exact key match, exact alias match,
    substring containment (per key, then that key's aliases), and finally
    difflib fuzzy matching at cutoff 0.7. Returns None when nothing fits.
    """
    if not MEDLINEPLUS_MAP:
        return None

    query = test_name.lower().strip()

    # 1. Exact key match.
    if query in MEDLINEPLUS_MAP:
        return MEDLINEPLUS_MAP[query]

    # 2. Exact alias match.
    for data in MEDLINEPLUS_MAP.values():
        if query in (a.lower() for a in data.get('aliases', [])):
            return data

    # 3. Substring containment: each key, then that key's aliases.
    for key, data in MEDLINEPLUS_MAP.items():
        if key in query or query in key:
            return data
        for alias in data.get('aliases', []):
            alias_l = alias.lower()
            if alias_l in query or query in alias_l:
                return data

    # 4. Fuzzy match over every known name (keys plus all aliases).
    candidates = list(MEDLINEPLUS_MAP.keys())
    for data in MEDLINEPLUS_MAP.values():
        candidates.extend(a.lower() for a in data.get('aliases', []))

    best = difflib.get_close_matches(query, candidates, n=1, cutoff=0.7)
    if best:
        hit = best[0]
        if hit in MEDLINEPLUS_MAP:
            return MEDLINEPLUS_MAP[hit]
        # The fuzzy hit was an alias; find its owning entry.
        for data in MEDLINEPLUS_MAP.values():
            if hit in (a.lower() for a in data.get('aliases', [])):
                return data

    return None
| |
|
| |
|
def get_medlineplus_info(slug: str, status: str) -> Dict:
    """
    Get educational info from MedlinePlus cache for a given test slug and status.
    Falls back to fetching from MedlinePlus if not cached.

    Args:
        slug: MedlinePlus lab-test URL slug.
        status: classification string; only the substring 'high' is checked
            when choosing the cached direction ('high' vs 'low').

    Returns:
        {'url', 'description'}; description is '' when nothing was found or
        the fetch failed.
    """
    url = f"https://medlineplus.gov/lab-tests/{slug}/"

    # Cache hit. NOTE(review): MEDLINEPLUS_CACHE is defined elsewhere in this
    # module (not visible here); entries appear to hold per-direction text
    # under 'high'/'low' — confirm its initialization.
    if slug in MEDLINEPLUS_CACHE:
        cached = MEDLINEPLUS_CACHE[slug]
        direction = 'high' if 'high' in status else 'low'
        return {
            'url': cached.get('url', url),
            'description': cached.get(direction, ''),
        }

    # Cache miss: best-effort scrape with a short timeout.
    try:
        response = httpx.get(url, timeout=5.0, follow_redirects=True)
        if response.status_code == 200:
            soup = BeautifulSoup(response.text, 'html.parser')

            # Find the heading whose text mentions both "results" and "mean"
            # (e.g. "What do my results mean?").
            results_section = None
            for heading in soup.find_all(['h2', 'h3']):
                if 'results' in heading.get_text().lower() and 'mean' in heading.get_text().lower():
                    results_section = heading
                    break

            description = ''
            if results_section:
                # Collect sibling text up to the next heading; keep the
                # first three non-empty fragments.
                content_parts = []
                for sibling in results_section.find_next_siblings():
                    if sibling.name in ['h2', 'h3']:
                        break
                    text = sibling.get_text(strip=True)
                    if text:
                        content_parts.append(text)
                description = ' '.join(content_parts[:3])

            # Cache the result; the scrape is not direction-specific, so the
            # same description serves both 'high' and 'low'.
            MEDLINEPLUS_CACHE[slug] = {
                'url': url,
                'high': description,
                'low': description,
                'fetched_at': 'runtime'
            }

            return {
                'url': url,
                'description': description,
            }
    except Exception as e:
        print(f"MedlinePlus fetch failed for {slug}: {e}")

    # Non-200 response or fetch failure: return the URL with no description.
    return {'url': url, 'description': ''}
| |
|
| |
|
def check_lab_values(structured_text: str, table_data: Optional[Dict], words_with_boxes: Optional[List[Dict]] = None) -> List[Dict]:
    """
    Extract lab values from OCR output and check against reference ranges.

    Extraction methods run in priority order and are merged, deduplicated by
    case-insensitive test name:
      1. word-position based (spatial layout from docTR) - most reliable
      2. table based (when a table was detected)
      3. regex over the structured text - fallback

    Returns a list of per-test result dicts with a status classification and
    MedlinePlus educational info for abnormal values.
    """
    # Primary source: spatial word layout.
    extracted: List[Dict] = []
    if words_with_boxes:
        extracted = extract_lab_values_from_words(words_with_boxes)
        print(f"Lab extraction (word-position): found {len(extracted)} values")

    seen = {item['test_name'].lower() for item in extracted}

    def merge(candidates: List[Dict]) -> None:
        # Append only candidates whose test name has not been seen yet.
        for cand in candidates:
            key = cand['test_name'].lower()
            if key not in seen:
                extracted.append(cand)
                seen.add(key)

    # Secondary: structured table cells, if a table was detected.
    if table_data and table_data.get('is_table'):
        from_table = extract_lab_values_from_table(table_data)
        print(f"Lab extraction (table): found {len(from_table)} values")
        merge(from_table)

    # Fallback: regexes over the structured text.
    from_text = extract_lab_values_from_text(structured_text)
    print(f"Lab extraction (text-regex): found {len(from_text)} values")
    merge(from_text)

    # Classify each value and attach educational material.
    results = []
    for item in extracted:
        status = classify_lab_value(item['value'], item['ref_low'], item['ref_high'])

        # Human-readable range string for display.
        low, high = item['ref_low'], item['ref_high']
        if low is not None and high is not None:
            range_display = f"{low} - {high}"
        elif high is not None:
            range_display = f"< {high}"
        elif low is not None:
            range_display = f"> {low}"
        else:
            range_display = item.get('ref_range_str', '')

        description = ''
        medlineplus_url = None
        category = 'General'

        entry = match_test_to_medlineplus(item['test_name'])
        if entry:
            slug = entry.get('slug', '')
            category = entry.get('category', 'General')
            medlineplus_url = f"https://medlineplus.gov/lab-tests/{slug}/"

            # Fetch the detailed explanation only for abnormal results.
            if status != 'normal' and slug:
                info = get_medlineplus_info(slug, status)
                description = info.get('description', '')
                medlineplus_url = info.get('url', medlineplus_url)

        results.append({
            'test_name': item['test_name'],
            'value': item['value'],
            'unit': item['unit'],
            'status': status,
            'ref_low': item['ref_low'],
            'ref_high': item['ref_high'],
            'reference_range': range_display,
            'category': category,
            'description': description,
            'medlineplus_url': medlineplus_url,
            'is_flagged_in_document': item.get('is_flagged_in_document', False),
        })

    return results
| |
|
def map_entities_to_boxes(entities: list, words_with_boxes: list, cleaned_text: str) -> list:
    """
    Attach bounding boxes to NER entities by fuzzy-matching entity tokens
    against OCR words. Each entity gets the union box of all OCR words that
    match any of its tokens by substring, or bbox=None when nothing matches.
    """
    mapped = []

    for entity in entities:
        tokens = entity['word'].lower().strip().split()

        # An OCR word matches when any entity token is a substring of it
        # (or vice versa); collect every matching word's box.
        boxes = []
        for info in words_with_boxes:
            ocr_word = info['word'].lower().strip()
            if any(t in ocr_word or ocr_word in t for t in tokens):
                boxes.append(info['bbox'])

        if boxes:
            # Union box covering all matched words.
            union = [
                [min(b[0][0] for b in boxes), min(b[0][1] for b in boxes)],
                [max(b[1][0] for b in boxes), max(b[1][1] for b in boxes)],
            ]
        else:
            union = None

        mapped.append({
            'entity_group': entity['entity_group'],
            'score': entity['score'],
            'word': entity['word'],
            'bbox': union,
        })

    return mapped
| |
|
| | |
| |
|
@app.get("/")
async def root():
    """Health check: confirm the service is up and identify the API."""
    payload = {"status": "running", "message": "ScanAssured OCR & NER API"}
    return payload
| |
|
@app.get("/models")
async def get_available_models():
    """Return all available OCR and NER models.

    Exposes the OCR presets, the individual detection/recognition model lists,
    the NER model catalog, and the single OCR-correction model description so
    the frontend can populate its selectors.
    """
    ocr_presets = []
    for preset_id, preset_data in OCR_PRESETS.items():
        ocr_presets.append({
            "id": preset_id,
            "name": preset_data["name"],
            "description": preset_data["description"],
        })

    ner_models = {}
    for model_id, model_data in NER_MODELS.items():
        ner_models[model_id] = {
            "name": model_data["name"],
            "description": model_data["description"],
            "entities": model_data["entities"],
        }

    return {
        "ocr_presets": ocr_presets,
        "ocr_detection_models": OCR_DETECTION_MODELS,
        "ocr_recognition_models": OCR_RECOGNITION_MODELS,
        "ner_models": ner_models,
        "ocr_correction_model": {
            "id": "ner-dictionary",
            "name": "NER-Informed Dictionary Correction",
            "description": "Edit-distance correction against medical entity dictionaries, guided by NER entity labels",
        },
    }
| |
|
@app.post("/process")
async def process_image(
    file: UploadFile = File(...),
    ner_model_id: str = Form(...),
    ocr_preset: str = Form("balanced"),
    ocr_det_model: Optional[str] = Form(None),
    ocr_reco_model: Optional[str] = Form(None),
    enable_correction: str = Form("false"),
    correction_threshold: str = Form("0.75"),
):
    """Process an uploaded image with OCR, table detection, and NER.

    Form parameters:
        file: Image to process.
        ner_model_id: Key into NER_MODELS selecting the token-classification model.
        ocr_preset: Key into OCR_PRESETS; used only when explicit det/reco
            models are not BOTH supplied.
        ocr_det_model / ocr_reco_model: Explicit docTR detection/recognition
            architectures; both must be given to override the preset.
        enable_correction: "true"/"false" string flag (form fields arrive as
            strings) enabling NER-informed OCR correction.
        correction_threshold: Confidence threshold for correction, as a string.

    Returns:
        JSON payload with OCR text variants, entities with bounding boxes,
        drug interactions, lab anomalies, table-detection results from four
        methods, and a Docling comparison.

    Error responses: 400 (unknown NER model), 503 (model load failure),
    500 (any processing exception).
    """

    # Explicit model choices override the preset only when BOTH are provided;
    # otherwise fall back to the named preset (defaulting to "balanced").
    if ocr_det_model and ocr_reco_model:
        det_arch = ocr_det_model
        reco_arch = ocr_reco_model
    else:
        preset = OCR_PRESETS.get(ocr_preset, OCR_PRESETS["balanced"])
        det_arch = preset["det"]
        reco_arch = preset["reco"]

    # Validate the requested NER model before doing any heavy work.
    if ner_model_id not in NER_MODELS:
        return JSONResponse(
            status_code=400,
            content={"detail": f"Unknown NER model: {ner_model_id}"}
        )

    # Obtain the OCR predictor; a falsy return means the model failed to load.
    ocr_predictor_instance = get_ocr_predictor(det_arch, reco_arch)
    if not ocr_predictor_instance:
        return JSONResponse(
            status_code=503,
            content={"detail": f"Failed to load OCR model: {det_arch}/{reco_arch}"}
        )

    # Likewise for the NER pipeline.
    ner_pipeline = get_ner_pipeline(ner_model_id)
    if not ner_pipeline:
        return JSONResponse(
            status_code=503,
            content={"detail": f"Failed to load NER model: {ner_model_id}"}
        )

    try:
        # Read the upload and preprocess the image for docTR.
        file_content = await file.read()
        preprocessed_img = preprocess_for_doctr(file_content)

        # Run docTR OCR on the preprocessed image.
        print("Running docTR OCR...")
        # Re-encode the numpy image as PNG bytes: DocumentFile and the table
        # detectors below consume encoded image bytes, not arrays.
        pil_img = Image.fromarray(preprocessed_img)
        img_byte_arr = io.BytesIO()
        pil_img.save(img_byte_arr, format='PNG')
        img_bytes = img_byte_arr.getvalue()

        doc = DocumentFile.from_images([img_bytes])
        result = ocr_predictor_instance(doc)

        # Image dimensions are returned so the client can scale bounding boxes.
        img_height, img_width = preprocessed_img.shape[:2]

        # Derive three text views: layout-aware text, a cleaned variant used
        # as NER input, and per-word entries carrying bounding boxes.
        structured_text = extract_text_structured(result)
        cleaned_text = basic_cleanup(structured_text)
        words_with_boxes = extract_words_with_boxes(result)

        print(f"OCR Structured Text:\n{structured_text[:500]}...")
        print(f"Extracted {len(words_with_boxes)} words with bounding boxes")

        # Render a synthetic image of the recognized text for visual QA.
        print("Generating synthesized document image...")
        synthesized_image = generate_synthesized_image(result)

        # Independent Docling run, returned for side-by-side comparison.
        print("Running Docling pipeline for comparison...")
        docling_result = run_docling_pipeline(file_content)

        # Table detection method 1: img2table with its integrated OCR.
        print("Running img2table for table detection (Method 1: integrated OCR)...")
        table_formatted_text, table_data = extract_text_with_table_detection(
            img_bytes, img_width, img_height
        )

        # Table detection method 2: detect structure, then OCR each cell.
        print("Running two-stage table detection (Method 2: structure + cell OCR)...")
        two_stage_text, two_stage_data = extract_text_two_stage(
            img_bytes, img_width, img_height, ocr_predictor_instance
        )

        # Table detection method 3: cluster text positions (borderless tables).
        print("Running borderless table detection (Method 3: text position analysis)...")
        borderless_text, borderless_data = extract_text_borderless(result)

        # Table detection method 4: group docTR blocks by geometry.
        print("Running block-geometry table detection (Method 4: docTR block analysis)...")
        block_geo_text, block_geo_data = extract_text_block_geometry(result)

        # Select the primary table result by fixed priority:
        # two-stage > img2table > borderless > block-geometry > plain OCR text.
        if two_stage_data.get('is_table'):
            display_text = two_stage_text
            primary_table_data = two_stage_data
            print(f"Using Two-stage: {two_stage_data.get('num_rows', 0)}x{two_stage_data.get('num_columns', 0)} table")
        elif table_data.get('is_table'):
            display_text = table_formatted_text
            primary_table_data = table_data
            print(f"Using img2table: {table_data.get('num_rows', 0)}x{table_data.get('num_columns', 0)} table")
        elif borderless_data.get('is_table'):
            display_text = borderless_text
            primary_table_data = borderless_data
            print(f"Using Borderless: {borderless_data.get('num_rows', 0)}x{borderless_data.get('num_columns', 0)} table")
        elif block_geo_data.get('is_table'):
            display_text = block_geo_text
            primary_table_data = block_geo_data
            print(f"Using Block-geometry: {block_geo_data.get('num_rows', 0)}x{block_geo_data.get('num_columns', 0)} table")
        else:
            display_text = structured_text
            primary_table_data = {'is_table': False}
            print("No table detected by any method, using regular OCR text")

        # The correction flag arrives as a form string; normalize to a bool.
        correction_enabled = enable_correction.lower() == "true"
        correction_result = {'corrected_text': cleaned_text, 'corrections': []}

        # NER always runs on the cleaned (pre-correction) text.
        ner_input_text = cleaned_text

        print("Running NER...")
        entities = ner_pipeline(ner_input_text)

        # Keep entities above a low confidence floor; cast the score to a
        # plain float so the response is JSON-serializable.
        structured_entities = []
        for entity in entities:
            if entity.get('score', 0.0) > 0.1:
                structured_entities.append({
                    'entity_group': entity['entity_group'],
                    'score': float(entity['score']),
                    'word': entity['word'].strip(),
                })

        # Attach bounding boxes to entities via fuzzy word matching.
        entities_with_boxes = map_entities_to_boxes(structured_entities, words_with_boxes, ner_input_text)

        # Optional NER-informed OCR correction; new fixes are appended to any
        # corrections already accumulated in correction_result.
        if correction_enabled:
            ner_corr = correct_with_ner_entities(
                words_with_boxes, structured_entities,
                correction_result['corrected_text'], confidence_threshold=float(correction_threshold))
            if ner_corr['corrections']:
                correction_result['corrections'].extend(ner_corr['corrections'])
                correction_result['corrected_text'] = ner_corr['corrected_text']
                print(f"NER-informed correction: {len(ner_corr['corrections'])} additional fix(es)")

        # Collect drug-like entities and look up pairwise interactions.
        detected_drugs = []
        for entity in structured_entities:
            if entity['entity_group'] in ['CHEM', 'CHEMICAL', 'TREATMENT', 'MEDICATION']:
                detected_drugs.append(entity['word'])

        interactions = check_drug_interactions(detected_drugs) if detected_drugs else []
        print(f"Found {len(interactions)} drug interactions")

        # Scan text/table content for lab values and flag abnormal results.
        lab_anomalies = check_lab_values(structured_text, primary_table_data, words_with_boxes)
        print(f"Found {len(lab_anomalies)} lab values ({sum(1 for a in lab_anomalies if a['status'] != 'normal')} abnormal)")

        return {
            "structured_text": display_text,
            "cleaned_text": cleaned_text,
            "corrected_text": correction_result['corrected_text'] if correction_enabled else None,
            "corrections": correction_result['corrections'] if correction_enabled else [],
            "medical_entities": entities_with_boxes,
            "interactions": interactions,
            "lab_anomalies": lab_anomalies,
            "model_id": NER_MODELS[ner_model_id]["name"],
            "ocr_model": f"{det_arch} + {reco_arch}",
            "image_width": img_width,
            "image_height": img_height,
            "synthesized_image": synthesized_image,
            # Primary (winning) table result, or None when nothing detected.
            "table_detected": primary_table_data.get('is_table', False),
            "table_data": {
                "num_columns": primary_table_data.get('num_columns', 0),
                "num_rows": primary_table_data.get('num_rows', 0),
                "cells": primary_table_data.get('cells', []),
                "method": primary_table_data.get('method', 'unknown')
            } if primary_table_data.get('is_table') else None,
            # Raw per-method results so the client can compare all four.
            "table_comparison": {
                "method1_img2table": {
                    "name": "img2table (line detection + integrated OCR)",
                    "detected": table_data.get('is_table', False),
                    "num_columns": table_data.get('num_columns', 0),
                    "num_rows": table_data.get('num_rows', 0),
                    "cells": table_data.get('cells', []),
                    "formatted_text": table_formatted_text if table_data.get('is_table') else None
                },
                "method2_two_stage": {
                    "name": "Two-stage (structure detection + cell-by-cell OCR)",
                    "detected": two_stage_data.get('is_table', False),
                    "num_columns": two_stage_data.get('num_columns', 0),
                    "num_rows": two_stage_data.get('num_rows', 0),
                    "cells": two_stage_data.get('cells', []),
                    "formatted_text": two_stage_text if two_stage_data.get('is_table') else None
                },
                "method3_borderless": {
                    "name": "Borderless (text position clustering)",
                    "detected": borderless_data.get('is_table', False),
                    "num_columns": borderless_data.get('num_columns', 0),
                    "num_rows": borderless_data.get('num_rows', 0),
                    "cells": borderless_data.get('cells', []),
                    "formatted_text": borderless_text if borderless_data.get('is_table') else None,
                    "fill_ratio": borderless_data.get('fill_ratio', 0)
                },
                "method4_block_geometry": {
                    "name": "Block-geometry (docTR block grouping)",
                    "detected": block_geo_data.get('is_table', False),
                    "num_columns": block_geo_data.get('num_columns', 0),
                    "num_rows": block_geo_data.get('num_rows', 0),
                    "cells": block_geo_data.get('cells', []),
                    "formatted_text": block_geo_text if block_geo_data.get('is_table') else None,
                    "fill_ratio": block_geo_data.get('fill_ratio', 0)
                }
            },
            # Docling comparison payload; the falsy guard keeps the response
            # well-formed even if run_docling_pipeline returns nothing.
            "docling_result": {
                "available": docling_result.get("success", False),
                "markdown_text": docling_result.get("markdown_text", ""),
                "plain_text": docling_result.get("plain_text", ""),
                "table_detected": bool(docling_result.get("tables")),
                "table_data": docling_result.get("primary_table"),
                "error": docling_result.get("error"),
            } if docling_result else {
                "available": False,
                "error": "Docling pipeline did not run",
            }
        }

    except Exception as e:
        # Broad catch at the endpoint boundary: log the full traceback and
        # return a 500 JSON error instead of leaking an unhandled exception.
        print(f"Processing error: {e}")
        import traceback
        traceback.print_exc()
        return JSONResponse(
            status_code=500,
            content={"detail": f"An error occurred during processing: {str(e)}"}
        )
| |
|