hanz245 commited on
Commit
41cfb0b
·
1 Parent(s): c7a86ba
Files changed (1) hide show
  1. template_matcher.py +65 -14
template_matcher.py CHANGED
@@ -93,11 +93,51 @@ def _get_crnn():
93
  return _crnn_ocr
94
 
95
 
96
- def _crnn_read(crop_img: Image.Image) -> str:
97
- """Run CRNN+CTC on a single PIL Image crop and return decoded text."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  ocr = _get_crnn()
99
- if ocr is None or _crnn_decode is None:
100
- return ''
101
  try:
102
  import torch
103
 
@@ -111,23 +151,24 @@ def _crnn_read(crop_img: Image.Image) -> str:
111
  with torch.no_grad():
112
  outputs = ocr.model(tensor)
113
 
114
- decoded = _crnn_decode(outputs.cpu(), ocr.idx_to_char, method='greedy')
115
- return decoded[0].strip()
116
  except Exception as e:
117
  print(f'[template_matcher] CRNN+CTC read error: {e}')
118
- return ''
119
 
120
 
121
  def _crnn_read_batch(crops: list) -> list:
122
  """
123
  Run CRNN+CTC on a list of PIL Image crops in one forward pass.
 
124
  """
125
  if not crops:
126
  return []
127
 
128
  ocr = _get_crnn()
129
- if ocr is None or _crnn_decode is None:
130
- return [''] * len(crops)
131
 
132
  try:
133
  import torch
@@ -147,8 +188,7 @@ def _crnn_read_batch(crops: list) -> list:
147
  with torch.no_grad():
148
  outputs = ocr.model(batch)
149
 
150
- decoded = _crnn_decode(outputs.cpu(), ocr.idx_to_char, method='greedy')
151
- return [d.strip() for d in decoded]
152
 
153
  except Exception as e:
154
  print(f'[template_matcher] CRNN batch error: {e}; falling back to serial')
@@ -1285,7 +1325,8 @@ def detect_form_type(image_path: str) -> str:
1285
  img_l = Image.open(image_path).convert('L')
1286
  w, h = img_l.size
1287
  title_crop = img_l.crop((0, int(h * 0.04), w, int(h * 0.15)))
1288
- title = _crnn_read(title_crop).upper()
 
1289
 
1290
  if title:
1291
  if 'LIVE BIRTH' in title or ('BIRTH' in title and 'DEATH' not in title and 'MARRIAGE' not in title):
@@ -1365,12 +1406,21 @@ def extract_fields(image_path: str, form_type: str = None):
1365
  assist_text = _paddle_read(crop)
1366
  assist_texts.append(assist_text)
1367
 
1368
- crnn_texts = _crnn_read_batch(crops)
 
 
 
 
 
 
 
 
 
1369
 
1370
- for field_name, crnn_text, assist_text in zip(field_names, crnn_texts, assist_texts):
1371
  final_text = _smart_merge(field_name, crnn_text, assist_text)
1372
  if final_text:
1373
  fields[field_name] = final_text
 
1374
 
1375
  print(f'[template_matcher] Extracted: {len(fields)}/{len(template)} fields')
1376
  paddle_count = sum(1 for m in debug_methods.values() if m == 'paddle-detect')
@@ -1382,6 +1432,7 @@ def extract_fields(image_path: str, form_type: str = None):
1382
 
1383
  fields['_quality'] = quality
1384
  fields['_corrections'] = corrections
 
1385
  return fields
1386
  except Exception as e:
1387
  print(f'[template_matcher] extract_fields error: {e}')
 
93
  return _crnn_ocr
94
 
95
 
96
+ def _decode_ctc_with_confidence(outputs, idx_to_char) -> list:
97
+ """
98
+ Decode CRNN+CTC logits and compute per-field confidence.
99
+
100
+ Confidence is the mean probability of the kept non-blank characters after
101
+ greedy CTC collapse. This gives a real OCR confidence from the CRNN logits.
102
+
103
+ Assumption: CTC blank index is 0. If your model uses a different blank
104
+ index, change blank_idx below.
105
+ """
106
+ import torch
107
+
108
+ blank_idx = 0
109
+ probs = torch.softmax(outputs, dim=2)
110
+ max_probs, preds = torch.max(probs, dim=2)
111
+
112
+ results = []
113
+ for pred_seq, prob_seq in zip(preds, max_probs):
114
+ text_chars = []
115
+ char_probs = []
116
+ prev_idx = None
117
+
118
+ for idx, prob in zip(pred_seq.tolist(), prob_seq.tolist()):
119
+ if idx != blank_idx and idx != prev_idx:
120
+ char = idx_to_char.get(idx, '')
121
+ if char:
122
+ text_chars.append(char)
123
+ char_probs.append(float(prob))
124
+ prev_idx = idx
125
+
126
+ text = ''.join(text_chars).strip()
127
+ confidence = sum(char_probs) / len(char_probs) if char_probs else 0.0
128
+ results.append({
129
+ 'text': text,
130
+ 'confidence': float(confidence),
131
+ })
132
+
133
+ return results
134
+
135
+
136
+ def _crnn_read(crop_img: Image.Image) -> dict:
137
+ """Run CRNN+CTC on a single PIL Image crop and return text + confidence."""
138
  ocr = _get_crnn()
139
+ if ocr is None:
140
+ return {'text': '', 'confidence': 0.0}
141
  try:
142
  import torch
143
 
 
151
  with torch.no_grad():
152
  outputs = ocr.model(tensor)
153
 
154
+ decoded = _decode_ctc_with_confidence(outputs.cpu(), ocr.idx_to_char)
155
+ return decoded[0] if decoded else {'text': '', 'confidence': 0.0}
156
  except Exception as e:
157
  print(f'[template_matcher] CRNN+CTC read error: {e}')
158
+ return {'text': '', 'confidence': 0.0}
159
 
160
 
161
  def _crnn_read_batch(crops: list) -> list:
162
  """
163
  Run CRNN+CTC on a list of PIL Image crops in one forward pass.
164
+ Returns a list of {'text': str, 'confidence': float}.
165
  """
166
  if not crops:
167
  return []
168
 
169
  ocr = _get_crnn()
170
+ if ocr is None:
171
+ return [{'text': '', 'confidence': 0.0} for _ in crops]
172
 
173
  try:
174
  import torch
 
188
  with torch.no_grad():
189
  outputs = ocr.model(batch)
190
 
191
+ return _decode_ctc_with_confidence(outputs.cpu(), ocr.idx_to_char)
 
192
 
193
  except Exception as e:
194
  print(f'[template_matcher] CRNN batch error: {e}; falling back to serial')
 
1325
  img_l = Image.open(image_path).convert('L')
1326
  w, h = img_l.size
1327
  title_crop = img_l.crop((0, int(h * 0.04), w, int(h * 0.15)))
1328
+ title_result = _crnn_read(title_crop)
1329
+ title = (title_result.get('text', '') if isinstance(title_result, dict) else str(title_result)).upper()
1330
 
1331
  if title:
1332
  if 'LIVE BIRTH' in title or ('BIRTH' in title and 'DEATH' not in title and 'MARRIAGE' not in title):
 
1406
  assist_text = _paddle_read(crop)
1407
  assist_texts.append(assist_text)
1408
 
1409
+ crnn_results = _crnn_read_batch(crops)
1410
+ confidence = {}
1411
+
1412
+ for field_name, crnn_res, assist_text in zip(field_names, crnn_results, assist_texts):
1413
+ if isinstance(crnn_res, dict):
1414
+ crnn_text = crnn_res.get('text', '')
1415
+ crnn_conf = float(crnn_res.get('confidence', 0.0) or 0.0)
1416
+ else:
1417
+ crnn_text = str(crnn_res or '')
1418
+ crnn_conf = 0.0
1419
 
 
1420
  final_text = _smart_merge(field_name, crnn_text, assist_text)
1421
  if final_text:
1422
  fields[field_name] = final_text
1423
+ confidence[field_name] = crnn_conf
1424
 
1425
  print(f'[template_matcher] Extracted: {len(fields)}/{len(template)} fields')
1426
  paddle_count = sum(1 for m in debug_methods.values() if m == 'paddle-detect')
 
1432
 
1433
  fields['_quality'] = quality
1434
  fields['_corrections'] = corrections
1435
+ fields['_confidence'] = confidence
1436
  return fields
1437
  except Exception as e:
1438
  print(f'[template_matcher] extract_fields error: {e}')