update
Browse files- template_matcher.py +65 -14
template_matcher.py
CHANGED
|
@@ -93,11 +93,51 @@ def _get_crnn():
|
|
| 93 |
return _crnn_ocr
|
| 94 |
|
| 95 |
|
| 96 |
-
def
|
| 97 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
ocr = _get_crnn()
|
| 99 |
-
if ocr is None
|
| 100 |
-
return ''
|
| 101 |
try:
|
| 102 |
import torch
|
| 103 |
|
|
@@ -111,23 +151,24 @@ def _crnn_read(crop_img: Image.Image) -> str:
|
|
| 111 |
with torch.no_grad():
|
| 112 |
outputs = ocr.model(tensor)
|
| 113 |
|
| 114 |
-
decoded =
|
| 115 |
-
return decoded[0].
|
| 116 |
except Exception as e:
|
| 117 |
print(f'[template_matcher] CRNN+CTC read error: {e}')
|
| 118 |
-
return ''
|
| 119 |
|
| 120 |
|
| 121 |
def _crnn_read_batch(crops: list) -> list:
|
| 122 |
"""
|
| 123 |
Run CRNN+CTC on a list of PIL Image crops in one forward pass.
|
|
|
|
| 124 |
"""
|
| 125 |
if not crops:
|
| 126 |
return []
|
| 127 |
|
| 128 |
ocr = _get_crnn()
|
| 129 |
-
if ocr is None
|
| 130 |
-
return [''
|
| 131 |
|
| 132 |
try:
|
| 133 |
import torch
|
|
@@ -147,8 +188,7 @@ def _crnn_read_batch(crops: list) -> list:
|
|
| 147 |
with torch.no_grad():
|
| 148 |
outputs = ocr.model(batch)
|
| 149 |
|
| 150 |
-
|
| 151 |
-
return [d.strip() for d in decoded]
|
| 152 |
|
| 153 |
except Exception as e:
|
| 154 |
print(f'[template_matcher] CRNN batch error: {e}; falling back to serial')
|
|
@@ -1285,7 +1325,8 @@ def detect_form_type(image_path: str) -> str:
|
|
| 1285 |
img_l = Image.open(image_path).convert('L')
|
| 1286 |
w, h = img_l.size
|
| 1287 |
title_crop = img_l.crop((0, int(h * 0.04), w, int(h * 0.15)))
|
| 1288 |
-
|
|
|
|
| 1289 |
|
| 1290 |
if title:
|
| 1291 |
if 'LIVE BIRTH' in title or ('BIRTH' in title and 'DEATH' not in title and 'MARRIAGE' not in title):
|
|
@@ -1365,12 +1406,21 @@ def extract_fields(image_path: str, form_type: str = None):
|
|
| 1365 |
assist_text = _paddle_read(crop)
|
| 1366 |
assist_texts.append(assist_text)
|
| 1367 |
|
| 1368 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1369 |
|
| 1370 |
-
for field_name, crnn_text, assist_text in zip(field_names, crnn_texts, assist_texts):
|
| 1371 |
final_text = _smart_merge(field_name, crnn_text, assist_text)
|
| 1372 |
if final_text:
|
| 1373 |
fields[field_name] = final_text
|
|
|
|
| 1374 |
|
| 1375 |
print(f'[template_matcher] Extracted: {len(fields)}/{len(template)} fields')
|
| 1376 |
paddle_count = sum(1 for m in debug_methods.values() if m == 'paddle-detect')
|
|
@@ -1382,6 +1432,7 @@ def extract_fields(image_path: str, form_type: str = None):
|
|
| 1382 |
|
| 1383 |
fields['_quality'] = quality
|
| 1384 |
fields['_corrections'] = corrections
|
|
|
|
| 1385 |
return fields
|
| 1386 |
except Exception as e:
|
| 1387 |
print(f'[template_matcher] extract_fields error: {e}')
|
|
|
|
| 93 |
return _crnn_ocr
|
| 94 |
|
| 95 |
|
| 96 |
+
def _decode_ctc_with_confidence(outputs, idx_to_char) -> list:
|
| 97 |
+
"""
|
| 98 |
+
Decode CRNN+CTC logits and compute per-field confidence.
|
| 99 |
+
|
| 100 |
+
Confidence is the mean probability of the kept non-blank characters after
|
| 101 |
+
greedy CTC collapse. This gives a real OCR confidence from the CRNN logits.
|
| 102 |
+
|
| 103 |
+
Assumption: CTC blank index is 0. If your model uses a different blank
|
| 104 |
+
index, change blank_idx below.
|
| 105 |
+
"""
|
| 106 |
+
import torch
|
| 107 |
+
|
| 108 |
+
blank_idx = 0
|
| 109 |
+
probs = torch.softmax(outputs, dim=2)
|
| 110 |
+
max_probs, preds = torch.max(probs, dim=2)
|
| 111 |
+
|
| 112 |
+
results = []
|
| 113 |
+
for pred_seq, prob_seq in zip(preds, max_probs):
|
| 114 |
+
text_chars = []
|
| 115 |
+
char_probs = []
|
| 116 |
+
prev_idx = None
|
| 117 |
+
|
| 118 |
+
for idx, prob in zip(pred_seq.tolist(), prob_seq.tolist()):
|
| 119 |
+
if idx != blank_idx and idx != prev_idx:
|
| 120 |
+
char = idx_to_char.get(idx, '')
|
| 121 |
+
if char:
|
| 122 |
+
text_chars.append(char)
|
| 123 |
+
char_probs.append(float(prob))
|
| 124 |
+
prev_idx = idx
|
| 125 |
+
|
| 126 |
+
text = ''.join(text_chars).strip()
|
| 127 |
+
confidence = sum(char_probs) / len(char_probs) if char_probs else 0.0
|
| 128 |
+
results.append({
|
| 129 |
+
'text': text,
|
| 130 |
+
'confidence': float(confidence),
|
| 131 |
+
})
|
| 132 |
+
|
| 133 |
+
return results
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def _crnn_read(crop_img: Image.Image) -> dict:
|
| 137 |
+
"""Run CRNN+CTC on a single PIL Image crop and return text + confidence."""
|
| 138 |
ocr = _get_crnn()
|
| 139 |
+
if ocr is None:
|
| 140 |
+
return {'text': '', 'confidence': 0.0}
|
| 141 |
try:
|
| 142 |
import torch
|
| 143 |
|
|
|
|
| 151 |
with torch.no_grad():
|
| 152 |
outputs = ocr.model(tensor)
|
| 153 |
|
| 154 |
+
decoded = _decode_ctc_with_confidence(outputs.cpu(), ocr.idx_to_char)
|
| 155 |
+
return decoded[0] if decoded else {'text': '', 'confidence': 0.0}
|
| 156 |
except Exception as e:
|
| 157 |
print(f'[template_matcher] CRNN+CTC read error: {e}')
|
| 158 |
+
return {'text': '', 'confidence': 0.0}
|
| 159 |
|
| 160 |
|
| 161 |
def _crnn_read_batch(crops: list) -> list:
|
| 162 |
"""
|
| 163 |
Run CRNN+CTC on a list of PIL Image crops in one forward pass.
|
| 164 |
+
Returns a list of {'text': str, 'confidence': float}.
|
| 165 |
"""
|
| 166 |
if not crops:
|
| 167 |
return []
|
| 168 |
|
| 169 |
ocr = _get_crnn()
|
| 170 |
+
if ocr is None:
|
| 171 |
+
return [{'text': '', 'confidence': 0.0} for _ in crops]
|
| 172 |
|
| 173 |
try:
|
| 174 |
import torch
|
|
|
|
| 188 |
with torch.no_grad():
|
| 189 |
outputs = ocr.model(batch)
|
| 190 |
|
| 191 |
+
return _decode_ctc_with_confidence(outputs.cpu(), ocr.idx_to_char)
|
|
|
|
| 192 |
|
| 193 |
except Exception as e:
|
| 194 |
print(f'[template_matcher] CRNN batch error: {e}; falling back to serial')
|
|
|
|
| 1325 |
img_l = Image.open(image_path).convert('L')
|
| 1326 |
w, h = img_l.size
|
| 1327 |
title_crop = img_l.crop((0, int(h * 0.04), w, int(h * 0.15)))
|
| 1328 |
+
title_result = _crnn_read(title_crop)
|
| 1329 |
+
title = (title_result.get('text', '') if isinstance(title_result, dict) else str(title_result)).upper()
|
| 1330 |
|
| 1331 |
if title:
|
| 1332 |
if 'LIVE BIRTH' in title or ('BIRTH' in title and 'DEATH' not in title and 'MARRIAGE' not in title):
|
|
|
|
| 1406 |
assist_text = _paddle_read(crop)
|
| 1407 |
assist_texts.append(assist_text)
|
| 1408 |
|
| 1409 |
+
crnn_results = _crnn_read_batch(crops)
|
| 1410 |
+
confidence = {}
|
| 1411 |
+
|
| 1412 |
+
for field_name, crnn_res, assist_text in zip(field_names, crnn_results, assist_texts):
|
| 1413 |
+
if isinstance(crnn_res, dict):
|
| 1414 |
+
crnn_text = crnn_res.get('text', '')
|
| 1415 |
+
crnn_conf = float(crnn_res.get('confidence', 0.0) or 0.0)
|
| 1416 |
+
else:
|
| 1417 |
+
crnn_text = str(crnn_res or '')
|
| 1418 |
+
crnn_conf = 0.0
|
| 1419 |
|
|
|
|
| 1420 |
final_text = _smart_merge(field_name, crnn_text, assist_text)
|
| 1421 |
if final_text:
|
| 1422 |
fields[field_name] = final_text
|
| 1423 |
+
confidence[field_name] = crnn_conf
|
| 1424 |
|
| 1425 |
print(f'[template_matcher] Extracted: {len(fields)}/{len(template)} fields')
|
| 1426 |
paddle_count = sum(1 for m in debug_methods.values() if m == 'paddle-detect')
|
|
|
|
| 1432 |
|
| 1433 |
fields['_quality'] = quality
|
| 1434 |
fields['_corrections'] = corrections
|
| 1435 |
+
fields['_confidence'] = confidence
|
| 1436 |
return fields
|
| 1437 |
except Exception as e:
|
| 1438 |
print(f'[template_matcher] extract_fields error: {e}')
|