hanz245 committed on
Commit
0cdc412
·
1 Parent(s): 41cfb0b
Files changed (1) hide show
  1. template_matcher.py +14 -65
template_matcher.py CHANGED
@@ -93,51 +93,11 @@ def _get_crnn():
93
  return _crnn_ocr
94
 
95
 
96
- def _decode_ctc_with_confidence(outputs, idx_to_char) -> list:
97
- """
98
- Decode CRNN+CTC logits and compute per-field confidence.
99
-
100
- Confidence is the mean probability of the kept non-blank characters after
101
- greedy CTC collapse. This gives a real OCR confidence from the CRNN logits.
102
-
103
- Assumption: CTC blank index is 0. If your model uses a different blank
104
- index, change blank_idx below.
105
- """
106
- import torch
107
-
108
- blank_idx = 0
109
- probs = torch.softmax(outputs, dim=2)
110
- max_probs, preds = torch.max(probs, dim=2)
111
-
112
- results = []
113
- for pred_seq, prob_seq in zip(preds, max_probs):
114
- text_chars = []
115
- char_probs = []
116
- prev_idx = None
117
-
118
- for idx, prob in zip(pred_seq.tolist(), prob_seq.tolist()):
119
- if idx != blank_idx and idx != prev_idx:
120
- char = idx_to_char.get(idx, '')
121
- if char:
122
- text_chars.append(char)
123
- char_probs.append(float(prob))
124
- prev_idx = idx
125
-
126
- text = ''.join(text_chars).strip()
127
- confidence = sum(char_probs) / len(char_probs) if char_probs else 0.0
128
- results.append({
129
- 'text': text,
130
- 'confidence': float(confidence),
131
- })
132
-
133
- return results
134
-
135
-
136
- def _crnn_read(crop_img: Image.Image) -> dict:
137
- """Run CRNN+CTC on a single PIL Image crop and return text + confidence."""
138
  ocr = _get_crnn()
139
- if ocr is None:
140
- return {'text': '', 'confidence': 0.0}
141
  try:
142
  import torch
143
 
@@ -151,24 +111,23 @@ def _crnn_read(crop_img: Image.Image) -> dict:
151
  with torch.no_grad():
152
  outputs = ocr.model(tensor)
153
 
154
- decoded = _decode_ctc_with_confidence(outputs.cpu(), ocr.idx_to_char)
155
- return decoded[0] if decoded else {'text': '', 'confidence': 0.0}
156
  except Exception as e:
157
  print(f'[template_matcher] CRNN+CTC read error: {e}')
158
- return {'text': '', 'confidence': 0.0}
159
 
160
 
161
  def _crnn_read_batch(crops: list) -> list:
162
  """
163
  Run CRNN+CTC on a list of PIL Image crops in one forward pass.
164
- Returns a list of {'text': str, 'confidence': float}.
165
  """
166
  if not crops:
167
  return []
168
 
169
  ocr = _get_crnn()
170
- if ocr is None:
171
- return [{'text': '', 'confidence': 0.0} for _ in crops]
172
 
173
  try:
174
  import torch
@@ -188,7 +147,8 @@ def _crnn_read_batch(crops: list) -> list:
188
  with torch.no_grad():
189
  outputs = ocr.model(batch)
190
 
191
- return _decode_ctc_with_confidence(outputs.cpu(), ocr.idx_to_char)
 
192
 
193
  except Exception as e:
194
  print(f'[template_matcher] CRNN batch error: {e}; falling back to serial')
@@ -1325,8 +1285,7 @@ def detect_form_type(image_path: str) -> str:
1325
  img_l = Image.open(image_path).convert('L')
1326
  w, h = img_l.size
1327
  title_crop = img_l.crop((0, int(h * 0.04), w, int(h * 0.15)))
1328
- title_result = _crnn_read(title_crop)
1329
- title = (title_result.get('text', '') if isinstance(title_result, dict) else str(title_result)).upper()
1330
 
1331
  if title:
1332
  if 'LIVE BIRTH' in title or ('BIRTH' in title and 'DEATH' not in title and 'MARRIAGE' not in title):
@@ -1406,21 +1365,12 @@ def extract_fields(image_path: str, form_type: str = None):
1406
  assist_text = _paddle_read(crop)
1407
  assist_texts.append(assist_text)
1408
 
1409
- crnn_results = _crnn_read_batch(crops)
1410
- confidence = {}
1411
-
1412
- for field_name, crnn_res, assist_text in zip(field_names, crnn_results, assist_texts):
1413
- if isinstance(crnn_res, dict):
1414
- crnn_text = crnn_res.get('text', '')
1415
- crnn_conf = float(crnn_res.get('confidence', 0.0) or 0.0)
1416
- else:
1417
- crnn_text = str(crnn_res or '')
1418
- crnn_conf = 0.0
1419
 
 
1420
  final_text = _smart_merge(field_name, crnn_text, assist_text)
1421
  if final_text:
1422
  fields[field_name] = final_text
1423
- confidence[field_name] = crnn_conf
1424
 
1425
  print(f'[template_matcher] Extracted: {len(fields)}/{len(template)} fields')
1426
  paddle_count = sum(1 for m in debug_methods.values() if m == 'paddle-detect')
@@ -1432,7 +1382,6 @@ def extract_fields(image_path: str, form_type: str = None):
1432
 
1433
  fields['_quality'] = quality
1434
  fields['_corrections'] = corrections
1435
- fields['_confidence'] = confidence
1436
  return fields
1437
  except Exception as e:
1438
  print(f'[template_matcher] extract_fields error: {e}')
 
93
  return _crnn_ocr
94
 
95
 
96
+ def _crnn_read(crop_img: Image.Image) -> str:
97
+ """Run CRNN+CTC on a single PIL Image crop and return decoded text."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  ocr = _get_crnn()
99
+ if ocr is None or _crnn_decode is None:
100
+ return ''
101
  try:
102
  import torch
103
 
 
111
  with torch.no_grad():
112
  outputs = ocr.model(tensor)
113
 
114
+ decoded = _crnn_decode(outputs.cpu(), ocr.idx_to_char, method='greedy')
115
+ return decoded[0].strip()
116
  except Exception as e:
117
  print(f'[template_matcher] CRNN+CTC read error: {e}')
118
+ return ''
119
 
120
 
121
  def _crnn_read_batch(crops: list) -> list:
122
  """
123
  Run CRNN+CTC on a list of PIL Image crops in one forward pass.
 
124
  """
125
  if not crops:
126
  return []
127
 
128
  ocr = _get_crnn()
129
+ if ocr is None or _crnn_decode is None:
130
+ return [''] * len(crops)
131
 
132
  try:
133
  import torch
 
147
  with torch.no_grad():
148
  outputs = ocr.model(batch)
149
 
150
+ decoded = _crnn_decode(outputs.cpu(), ocr.idx_to_char, method='greedy')
151
+ return [d.strip() for d in decoded]
152
 
153
  except Exception as e:
154
  print(f'[template_matcher] CRNN batch error: {e}; falling back to serial')
 
1285
  img_l = Image.open(image_path).convert('L')
1286
  w, h = img_l.size
1287
  title_crop = img_l.crop((0, int(h * 0.04), w, int(h * 0.15)))
1288
+ title = _crnn_read(title_crop).upper()
 
1289
 
1290
  if title:
1291
  if 'LIVE BIRTH' in title or ('BIRTH' in title and 'DEATH' not in title and 'MARRIAGE' not in title):
 
1365
  assist_text = _paddle_read(crop)
1366
  assist_texts.append(assist_text)
1367
 
1368
+ crnn_texts = _crnn_read_batch(crops)
 
 
 
 
 
 
 
 
 
1369
 
1370
+ for field_name, crnn_text, assist_text in zip(field_names, crnn_texts, assist_texts):
1371
  final_text = _smart_merge(field_name, crnn_text, assist_text)
1372
  if final_text:
1373
  fields[field_name] = final_text
 
1374
 
1375
  print(f'[template_matcher] Extracted: {len(fields)}/{len(template)} fields')
1376
  paddle_count = sum(1 for m in debug_methods.values() if m == 'paddle-detect')
 
1382
 
1383
  fields['_quality'] = quality
1384
  fields['_corrections'] = corrections
 
1385
  return fields
1386
  except Exception as e:
1387
  print(f'[template_matcher] extract_fields error: {e}')