hkai20000 commited on
Commit
175eb27
·
verified ·
1 Parent(s): ee0bd33

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +132 -117
main.py CHANGED
@@ -116,11 +116,6 @@ NER_MODELS = {
116
  # --- GLOBAL MODEL CACHES ---
117
  ner_model_cache: Dict[str, Any] = {}
118
  ocr_model_cache: Dict[str, Any] = {}
119
- mlm_corrector_cache: Dict[str, Any] = {}
120
-
121
- # --- OCR CORRECTION MODEL ---
122
- OCR_CORRECTION_MODEL = "hkai20000/bio-clinicalbert-ocr-correction"
123
-
124
  # --- DOCLING CONVERTER CACHE ---
125
  docling_converter_cache: Dict[str, Any] = {}
126
 
@@ -312,122 +307,142 @@ def get_ner_pipeline(model_id: str):
312
  return None
313
 
314
 
315
- # --- OCR CORRECTION MODEL LOADING ---
316
- def get_mlm_corrector():
317
- """Lazy-load the fill-mask pipeline for OCR error correction."""
318
- if OCR_CORRECTION_MODEL in mlm_corrector_cache:
319
- return mlm_corrector_cache[OCR_CORRECTION_MODEL]
 
320
 
321
- try:
322
- print(f"Loading OCR correction model: {OCR_CORRECTION_MODEL}...")
323
- corrector = hf_pipeline("fill-mask", model=OCR_CORRECTION_MODEL)
324
- mlm_corrector_cache[OCR_CORRECTION_MODEL] = corrector
325
- print(f"OCR correction model loaded successfully!")
326
- return corrector
327
- except Exception as e:
328
- print(f"ERROR: Failed to load OCR correction model: {e}")
329
- return None
 
330
 
331
 
332
- def correct_ocr_text(words_with_boxes: list, cleaned_text: str, confidence_threshold: float = 0.75) -> dict:
333
- """
334
- Correct OCR errors using fill-mask MLM model.
335
 
336
- For each word with docTR confidence below the threshold:
337
- 1. Mask the word in the full text context
338
- 2. Run fill-mask to get predictions
339
- 3. Accept correction if MLM confidence > 0.5 and edit distance <= 3
340
 
341
- Returns dict with 'corrected_text' and 'corrections' list.
342
- """
343
- corrector = get_mlm_corrector()
344
- if corrector is None:
345
- return {'corrected_text': cleaned_text, 'corrections': []}
346
 
347
- corrections = []
348
- corrected_text = cleaned_text
349
-
350
- low_confidence_words = [
351
- w for w in words_with_boxes
352
- if w.get('confidence', 1.0) < confidence_threshold
353
- and len(w['word']) >= 4
354
- and w['word'].isalpha() # Only correct purely alphabetic words — skip numbers, units, punctuation
355
- ]
356
-
357
- if not low_confidence_words:
358
- return {'corrected_text': cleaned_text, 'corrections': []}
359
-
360
- for word_info in low_confidence_words:
361
- original_word = word_info['word']
362
- word_confidence = word_info.get('confidence', 0.0)
363
-
364
- pattern = re.escape(original_word)
365
- match = re.search(r'\b' + pattern + r'\b', corrected_text)
366
- if not match:
367
- match = re.search(pattern, corrected_text)
368
- if not match:
369
- continue
370
 
371
- start, end = match.start(), match.end()
372
- masked_text = corrected_text[:start] + "[MASK]" + corrected_text[end:]
 
 
 
 
373
 
374
- mask_pos = masked_text.find("[MASK]")
375
- context_chars = 200
376
- ctx_start = max(0, mask_pos - context_chars)
377
- ctx_end = min(len(masked_text), mask_pos + context_chars)
378
- context = masked_text[ctx_start:ctx_end]
 
 
 
 
 
 
 
 
 
 
 
 
379
 
380
- if "[MASK]" not in context:
381
- continue
382
 
383
- try:
384
- predictions = corrector(context, top_k=5)
385
- except Exception as e:
386
- print(f"MLM correction error for '{original_word}': {e}")
 
387
  continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
388
 
389
- if not predictions:
 
 
 
 
390
  continue
391
 
392
- top = predictions[0]
393
- predicted_word = top['token_str'].strip()
394
- mlm_score = top['score']
 
395
 
396
- edit_dist = _edit_distance(original_word.lower(), predicted_word.lower())
 
 
397
 
398
- if mlm_score > 0.5 and edit_dist <= 3 and predicted_word.lower() != original_word.lower():
399
- corrected_text = corrected_text[:start] + predicted_word + corrected_text[end:]
400
- corrections.append({
401
- 'original': original_word,
402
- 'corrected': predicted_word,
403
- 'confidence': round(mlm_score, 4),
404
- 'ocr_confidence': round(word_confidence, 4),
405
- 'edit_distance': edit_dist,
406
- })
407
 
408
- return {
409
- 'corrected_text': corrected_text,
410
- 'corrections': corrections,
411
- }
 
412
 
 
 
 
 
 
 
 
 
 
 
 
 
413
 
414
- def _edit_distance(s1: str, s2: str) -> int:
415
- """Compute Levenshtein edit distance between two strings."""
416
- if len(s1) < len(s2):
417
- return _edit_distance(s2, s1)
418
- if len(s2) == 0:
419
- return len(s1)
420
 
421
- prev_row = range(len(s2) + 1)
422
- for i, c1 in enumerate(s1):
423
- curr_row = [i + 1]
424
- for j, c2 in enumerate(s2):
425
- insertions = prev_row[j + 1] + 1
426
- deletions = curr_row[j] + 1
427
- substitutions = prev_row[j] + (c1 != c2)
428
- curr_row.append(min(insertions, deletions, substitutions))
429
- prev_row = curr_row
430
- return prev_row[-1]
431
 
432
  # --- IMAGE PREPROCESSING ---
433
  def deskew_image(image: np.ndarray) -> np.ndarray:
@@ -2136,9 +2151,9 @@ async def get_available_models():
2136
  for model_id, model_data in NER_MODELS.items()
2137
  },
2138
  "ocr_correction_model": {
2139
- "id": OCR_CORRECTION_MODEL,
2140
- "name": "Bio-ClinicalBERT OCR Correction",
2141
- "description": "Fine-tuned Bio_ClinicalBERT for medical OCR error correction using fill-mask MLM",
2142
  }
2143
  }
2144
 
@@ -2264,22 +2279,12 @@ async def process_image(
2264
  primary_table_data = {'is_table': False}
2265
  print("No table detected by any method, using regular OCR text")
2266
 
2267
- # OCR Text Correction (if enabled)
2268
  correction_enabled = enable_correction.lower() == "true"
2269
  correction_result = {'corrected_text': cleaned_text, 'corrections': []}
2270
 
2271
- if correction_enabled:
2272
- print(f"Running OCR text correction with threshold={correction_threshold}...")
2273
- correction_result = correct_ocr_text(words_with_boxes, cleaned_text, confidence_threshold=float(correction_threshold))
2274
- if correction_result['corrections']:
2275
- print(f"Applied {len(correction_result['corrections'])} corrections")
2276
- for c in correction_result['corrections']:
2277
- print(f" '{c['original']}' -> '{c['corrected']}' (MLM={c['confidence']:.2f})")
2278
- else:
2279
- print("No corrections needed")
2280
-
2281
- # Use corrected text for NER if correction was applied
2282
- ner_input_text = correction_result['corrected_text'] if correction_enabled else cleaned_text
2283
 
2284
  # Perform NER on text
2285
  print("Running NER...")
@@ -2298,6 +2303,16 @@ async def process_image(
2298
  # Map entities to bounding boxes
2299
  entities_with_boxes = map_entities_to_boxes(structured_entities, words_with_boxes, ner_input_text)
2300
 
 
 
 
 
 
 
 
 
 
 
2301
  # Check for drug interactions
2302
  detected_drugs = []
2303
  for entity in structured_entities:
 
116
  # --- GLOBAL MODEL CACHES ---
117
  ner_model_cache: Dict[str, Any] = {}
118
  ocr_model_cache: Dict[str, Any] = {}
 
 
 
 
 
119
  # --- DOCLING CONVERTER CACHE ---
120
  docling_converter_cache: Dict[str, Any] = {}
121
 
 
307
  return None
308
 
309
 
310
+ def _edit_distance(s1: str, s2: str) -> int:
311
+ """Compute Levenshtein edit distance between two strings."""
312
+ if len(s1) < len(s2):
313
+ return _edit_distance(s2, s1)
314
+ if len(s2) == 0:
315
+ return len(s1)
316
 
317
+ prev_row = range(len(s2) + 1)
318
+ for i, c1 in enumerate(s1):
319
+ curr_row = [i + 1]
320
+ for j, c2 in enumerate(s2):
321
+ insertions = prev_row[j + 1] + 1
322
+ deletions = curr_row[j] + 1
323
+ substitutions = prev_row[j] + (c1 != c2)
324
+ curr_row.append(min(insertions, deletions, substitutions))
325
+ prev_row = curr_row
326
+ return prev_row[-1]
327
 
328
 
329
# --- NER-INFORMED CORRECTION ---

# Lazily-filled cache mapping NER entity type -> set of known terms,
# populated by _build_entity_dicts() from data loaded at startup.
_entity_dicts: dict[str, set] = {}


def _build_entity_dicts():
    """Build per-entity-type dictionaries from already-loaded DRUG_INTERACTIONS and MEDLINEPLUS_MAP."""
    global _entity_dicts

    # Medication terms come from the drug-interaction table keys; keys may be
    # comma-separated lists of names, so split and keep terms of length >= 4.
    medication_terms: set[str] = set()
    for name in DRUG_INTERACTIONS.keys():
        for piece in str(name).split(','):
            piece = piece.strip().lower()
            if len(piece) >= 4:
                medication_terms.add(piece)

    # Lab-test terms come from MedlinePlus names plus their aliases.
    lab_terms: set[str] = set()
    for test_name, info in MEDLINEPLUS_MAP.items():
        if len(test_name) >= 4:
            lab_terms.add(test_name.lower())
        for alias in info.get('aliases', []):
            if len(alias) >= 4:
                lab_terms.add(alias.lower())

    # Several NER label schemes map onto the same two dictionaries.
    _entity_dicts = {
        'MEDICATION': medication_terms,
        'LAB_VALUE': lab_terms,
        'DIAGNOSTIC_PROCEDURE': lab_terms,
        'TREATMENT': medication_terms,
        'CHEM': medication_terms,
        'CHEMICAL': medication_terms,
    }
    print(f"Entity dicts built: {len(medication_terms)} medication terms, {len(lab_terms)} lab terms")
362
 
 
 
363
 
364
def _find_closest(word: str, dictionary: set) -> tuple:
    """Return (closest_term, edit_distance) for *word* against *dictionary*.

    Scans every candidate term, pruning those whose length differs from the
    word's by more than 3 (the length difference is a lower bound on edit
    distance, so they cannot win within the caller's acceptance threshold).
    Returns (None, 999) when nothing survives the pruning.
    """
    best_match, best_dist = None, 999
    word_lower = word.lower()
    for term in dictionary:
        # |len(a) - len(b)| <= edit_distance(a, b): cheap reject before the O(n*m) DP.
        if abs(len(term) - len(word_lower)) > 3:
            continue
        dist = _edit_distance(word_lower, term)
        if dist < best_dist:
            best_dist = dist
            best_match = term
            if best_dist == 0:
                # Exact match found; no candidate can beat distance 0
                # (sets hold at most one term equal to the word), so stop early.
                break
    return best_match, best_dist
375
+
376
+
377
+ def _match_case(original: str, replacement: str) -> str:
378
+ if original.isupper():
379
+ return replacement.upper()
380
+ if original[0].isupper():
381
+ return replacement.capitalize()
382
+ return replacement.lower()
383
+
384
+
385
def correct_with_ner_entities(
    words_with_boxes: list,
    ner_entities: list,
    text: str,
    confidence_threshold: float = 0.75,
) -> dict:
    """Second-pass OCR correction using NER entity labels as context.

    For each NER entity whose type has a term dictionary (see _entity_dicts),
    every alphabetic token of >= 4 letters whose OCR confidence is below
    *confidence_threshold* is matched against the dictionary; a close match
    (edit distance <= 2) replaces the token's first occurrence in *text*,
    with casing matched to the original token.

    Returns {'corrected_text': str, 'corrections': list-of-dicts}, where each
    correction records original/corrected words, a similarity score, the OCR
    confidence, the edit distance, the source ('ner') and the entity type.
    """
    if not _entity_dicts:
        _build_entity_dicts()

    # Lowest observed OCR confidence per word.  BUG FIX: keys are indexed
    # both raw and with non-letters stripped, because the entity tokens
    # looked up below are letter-only (re.sub strips punctuation) — a raw
    # OCR word like "Metformin," would otherwise never match "metformin",
    # default to confidence 1.0, and be silently skipped for correction.
    word_conf: dict[str, float] = {}
    for w in words_with_boxes:
        conf = w.get('confidence', 1.0)
        raw_key = w['word'].lower()
        word_conf[raw_key] = min(word_conf.get(raw_key, 1.0), conf)
        stripped_key = re.sub(r'[^a-z]', '', raw_key)
        if stripped_key and stripped_key != raw_key:
            word_conf[stripped_key] = min(word_conf.get(stripped_key, 1.0), conf)

    corrections = []
    corrected_text = text

    for entity in ner_entities:
        entity_type = entity.get('entity_group', '')
        entity_word = entity.get('word', '').strip()
        lookup_dict = _entity_dicts.get(entity_type)
        if not lookup_dict or not entity_word:
            continue  # no dictionary for this entity type, or empty span

        for token in entity_word.split():
            clean_token = re.sub(r'[^a-zA-Z]', '', token)
            if not clean_token.isalpha() or len(clean_token) < 4:
                continue  # skip numbers, units, and very short words

            ocr_conf = word_conf.get(clean_token.lower(), 1.0)
            if ocr_conf >= confidence_threshold:
                continue  # OCR was confident enough; leave the token alone

            best_match, best_dist = _find_closest(clean_token, lookup_dict)
            if best_match is None or best_dist > 2:
                continue  # nothing plausibly close in the dictionary
            if best_match.lower() == clean_token.lower():
                continue  # already correct, just low OCR confidence

            replacement = _match_case(clean_token, best_match)
            match = re.search(r'\b' + re.escape(clean_token) + r'\b',
                              corrected_text, re.IGNORECASE)
            if not match:
                continue  # token no longer present (e.g. altered by earlier fix)

            start, end = match.start(), match.end()
            corrected_text = corrected_text[:start] + replacement + corrected_text[end:]
            corrections.append({
                'original': clean_token,
                'corrected': replacement,
                # Similarity score derived from edit distance vs. word length.
                'confidence': round(1.0 - best_dist / max(len(clean_token), len(best_match)), 4),
                'ocr_confidence': round(ocr_conf, 4),
                'edit_distance': best_dist,
                'source': 'ner',
                'entity_type': entity_type,
            })
            # Mark the replacement as trusted so repeated mentions of the
            # same entity don't re-trigger a correction loop.
            word_conf[replacement.lower()] = 1.0

    return {'corrected_text': corrected_text, 'corrections': corrections}
 
 
 
 
 
445
 
 
 
 
 
 
 
 
 
 
 
446
 
447
  # --- IMAGE PREPROCESSING ---
448
  def deskew_image(image: np.ndarray) -> np.ndarray:
 
2151
  for model_id, model_data in NER_MODELS.items()
2152
  },
2153
  "ocr_correction_model": {
2154
+ "id": "ner-dictionary",
2155
+ "name": "NER-Informed Dictionary Correction",
2156
+ "description": "Edit-distance correction against medical entity dictionaries, guided by NER entity labels",
2157
  }
2158
  }
2159
 
 
2279
  primary_table_data = {'is_table': False}
2280
  print("No table detected by any method, using regular OCR text")
2281
 
2282
+ # OCR Text Correction (NER-informed dictionary pass)
2283
  correction_enabled = enable_correction.lower() == "true"
2284
  correction_result = {'corrected_text': cleaned_text, 'corrections': []}
2285
 
2286
+ # Use cleaned text for NER input (NER correction runs after NER, see below)
2287
+ ner_input_text = cleaned_text
 
 
 
 
 
 
 
 
 
 
2288
 
2289
  # Perform NER on text
2290
  print("Running NER...")
 
2303
  # Map entities to bounding boxes
2304
  entities_with_boxes = map_entities_to_boxes(structured_entities, words_with_boxes, ner_input_text)
2305
 
2306
+ # NER-informed correction (second pass: fix low-confidence tokens matching entity dicts)
2307
+ if correction_enabled:
2308
+ ner_corr = correct_with_ner_entities(
2309
+ words_with_boxes, structured_entities,
2310
+ correction_result['corrected_text'], confidence_threshold=float(correction_threshold))
2311
+ if ner_corr['corrections']:
2312
+ correction_result['corrections'].extend(ner_corr['corrections'])
2313
+ correction_result['corrected_text'] = ner_corr['corrected_text']
2314
+ print(f"NER-informed correction: {len(ner_corr['corrections'])} additional fix(es)")
2315
+
2316
  # Check for drug interactions
2317
  detected_drugs = []
2318
  for entity in structured_entities: