File size: 12,957 Bytes
7111e1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
"""
Utility Functions for CRNN+CTC Civil Registry OCR
Includes CTC decoding, metrics calculation, and helper functions
"""

import torch
import numpy as np
def _editdistance(a, b):
    """Levenshtein (edit) distance between two sequences.

    Pure-Python replacement for the `editdistance` C extension. Works on
    strings (character level) and on lists (e.g. word lists for WER).
    """
    previous = list(range(len(b) + 1))
    for row, x in enumerate(a, start=1):
        current = [row]
        for col, y in enumerate(b, start=1):
            substitution_cost = 0 if x == y else 1
            current.append(min(
                previous[col] + 1,                  # deletion
                current[col - 1] + 1,               # insertion
                previous[col - 1] + substitution_cost,  # substitution / match
            ))
        previous = current
    return previous[-1]
from typing import List, Dict, Tuple


def decode_ctc_predictions(outputs, idx_to_char, method='greedy'):
    """Decode CTC model outputs into text strings.

    Args:
        outputs: Model outputs [seq_len, batch, num_chars]
        idx_to_char: Dictionary mapping label indices to characters
        method: 'greedy' (fast) or 'beam_search' (more accurate)

    Returns:
        List of decoded strings, one per batch element.

    Raises:
        ValueError: if `method` is not a known decoding strategy.
    """
    decoders = {
        'greedy': greedy_decode,
        'beam_search': beam_search_decode,
    }
    if method not in decoders:
        raise ValueError(f"Unknown decoding method: {method}")
    return decoders[method](outputs, idx_to_char)


def greedy_decode(outputs, idx_to_char):
    """Fast greedy CTC decoding.

    Takes the most probable class at every timestep, then applies the
    standard CTC collapse: drop consecutive duplicates, drop blanks
    (index 0). Indices absent from `idx_to_char` are skipped as well.

    Args:
        outputs: Model outputs [seq_len, batch, num_chars]
        idx_to_char: Dictionary mapping label indices to characters

    Returns:
        List of decoded strings, one per batch element.
    """
    # [seq_len, batch] -> [batch, seq_len]
    best_path = torch.argmax(outputs, dim=2).t()

    texts = []
    for row in best_path.tolist():
        pieces = []
        last_label = -1
        for label in row:
            # Emit only non-blank labels that differ from the previous step.
            if label != 0 and label != last_label and label in idx_to_char:
                pieces.append(idx_to_char[label])
            last_label = label
        texts.append(''.join(pieces))

    return texts


def beam_search_decode(outputs, idx_to_char, beam_width=10):
    """
    Beam search CTC decoding - slower but more accurate.

    Args:
        outputs: raw model logits [seq_len, batch, num_chars]
        idx_to_char: dict mapping label index -> character (index 0 = CTC blank)
        beam_width: number of candidate sequences kept after each timestep

    Returns:
        List of decoded strings, one per batch element.

    NOTE(review): this is a simplified beam search, not a full CTC prefix
    search. Two consequences worth confirming against requirements:
      * Beams that collapse to the same string are merged with max() below,
        whereas true CTC decoding sums the probabilities of all paths that
        yield the same string.
      * Because a blank leaves the sequence unchanged and a repeated
        character is always merged with the previous one, this decoder can
        never emit doubled characters (e.g. the "ll" in "hello").
    """
    outputs = torch.nn.functional.softmax(outputs, dim=2)
    outputs = outputs.permute(1, 0, 2).cpu().numpy()  # [batch, seq_len, num_chars]

    decoded_texts = []

    for output in outputs:
        # Each beam is (sequence_string, cumulative_probability)
        beams: list = [('', 1.0)]

        for timestep in output:
            new_beams: dict = {}

            for sequence, prob in beams:
                for idx, char_prob in enumerate(timestep):
                    if idx == 0:  # blank token β€” sequence unchanged
                        new_seq = sequence
                    elif idx in idx_to_char:
                        char = idx_to_char[idx]
                        # CTC rule: merge consecutive duplicate characters
                        if sequence and sequence[-1] == char:
                            new_seq = sequence        # duplicate β€” stay the same
                        else:
                            new_seq = sequence + char # append directly to string
                    else:
                        # Index has no character mapping β€” drop this extension.
                        continue

                    new_prob = prob * char_prob
                    # Merge beams that produce the same string.
                    # max() keeps only the single best path; see NOTE(review).
                    if new_seq in new_beams:
                        new_beams[new_seq] = max(new_beams[new_seq], new_prob)
                    else:
                        new_beams[new_seq] = new_prob

            # Keep top-k beams; keys are already strings β€” no list() conversion needed
            beams = sorted(new_beams.items(), key=lambda x: x[1], reverse=True)[:beam_width]

        # Best sequence is the string with highest probability
        best_sequence = max(beams, key=lambda x: x[1])[0]
        decoded_texts.append(best_sequence)

    return decoded_texts


def calculate_cer(predictions: List[str], ground_truths: List[str]) -> float:
    """Character Error Rate (CER) as a percentage.

    CER = (substitutions + deletions + insertions) / total GT characters.

    Raises:
        ValueError: if the two lists differ in length.
    """
    if len(predictions) != len(ground_truths):
        raise ValueError("Predictions and ground truths must have same length")

    edits = sum(_editdistance(pred, gt)
                for pred, gt in zip(predictions, ground_truths))
    char_count = sum(len(gt) for gt in ground_truths)

    # Guard against division by zero on empty references.
    return (edits / char_count * 100) if char_count > 0 else 0


def calculate_wer(predictions: List[str], ground_truths: List[str]) -> float:
    """Word Error Rate (WER) as a percentage.

    WER = (substitutions + deletions + insertions) / total GT words,
    with words obtained by whitespace splitting.

    Raises:
        ValueError: if the two lists differ in length.
    """
    if len(predictions) != len(ground_truths):
        raise ValueError("Predictions and ground truths must have same length")

    word_edits = 0
    word_count = 0
    for pred, gt in zip(predictions, ground_truths):
        gt_words = gt.split()
        word_edits += _editdistance(pred.split(), gt_words)
        word_count += len(gt_words)

    # Guard against division by zero on empty references.
    return (word_edits / word_count * 100) if word_count > 0 else 0


def calculate_accuracy(predictions: List[str], ground_truths: List[str]) -> float:
    """Exact-match (sequence-level) accuracy as a percentage.

    Raises:
        ValueError: if the two lists differ in length.
    """
    if len(predictions) != len(ground_truths):
        raise ValueError("Predictions and ground truths must have same length")

    if not predictions:
        return 0

    matches = [pred == gt for pred, gt in zip(predictions, ground_truths)]
    return sum(matches) / len(matches) * 100


class EarlyStopping:
    """Stop training once validation loss plateaus.

    Stopping is flagged after `patience` consecutive epochs without an
    improvement of more than `min_delta` over the best loss seen so far.
    """

    def __init__(self, patience=10, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0          # epochs since the last real improvement
        self.best_loss = None     # lowest validation loss observed
        self.early_stop = False   # latched True once patience runs out

    def __call__(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if self.best_loss is None:
            # First observation establishes the baseline.
            self.best_loss = val_loss
            return self.early_stop

        if val_loss <= self.best_loss - self.min_delta:
            # Meaningful improvement: record it and restart the countdown.
            self.best_loss = val_loss
            self.counter = 0
        else:
            # No (sufficient) improvement this epoch.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

        return self.early_stop


class AverageMeter:
    """Track the most recent value of a metric together with its running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0    # most recent value passed to update()
        self.avg = 0    # running mean over all observations
        self.sum = 0    # weighted sum of observations
        self.count = 0  # number of observations

    def update(self, val, n=1):
        """Record `n` observations whose (mean) value is `val`."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def calculate_confusion_matrix(predictions: List[str], ground_truths: List[str], char_set: List[str]) -> np.ndarray:
    """Character-level confusion matrix from position-wise alignment.

    Rows index ground-truth characters, columns predicted characters. The
    shorter string of each pair is right-padded with spaces, so counts are
    purely positional (no edit-distance alignment). Characters outside
    `char_set` are ignored.

    Args:
        predictions: List of predicted strings
        ground_truths: List of ground truth strings
        char_set: List of all possible characters (order defines the axes)

    Returns:
        Integer matrix of shape [len(char_set), len(char_set)].
    """
    index_of = {char: pos for pos, char in enumerate(char_set)}
    size = len(char_set)
    matrix = np.zeros((size, size), dtype=np.int64)

    for pred, gt in zip(predictions, ground_truths):
        width = max(len(pred), len(gt))
        # ljust pads with spaces, matching the original simple alignment.
        for p_char, g_char in zip(pred.ljust(width), gt.ljust(width)):
            if g_char in index_of and p_char in index_of:
                matrix[index_of[g_char], index_of[p_char]] += 1

    return matrix


def extract_form_fields(text: str, form_type: str) -> Dict[str, str]:
    """Extract fields from recognized text based on the civil-registry form type.

    Currently only tags the document type; per-form field extraction
    (regex / NER against each layout) is still to be implemented.

    Args:
        text: Recognized text (not yet used)
        form_type: 'form1a', 'form2a', 'form3a', 'form90'

    Returns:
        Dictionary of extracted fields; empty for unknown form types.
    """
    type_labels = {
        'form1a': 'Birth Certificate',
        'form2a': 'Death Certificate',
        'form3a': 'Marriage Certificate',
        'form90': 'Marriage License Application',
    }

    fields: Dict[str, str] = {}
    label = type_labels.get(form_type)
    if label is not None:
        fields['type'] = label

    # TODO: add real field extraction logic per form layout.
    return fields


def validate_extracted_data(data: Dict[str, str], form_type: str) -> Tuple[bool, List[str]]:
    """Check that all required fields for `form_type` are present and non-empty.

    Args:
        data: Extracted field dictionary
        form_type: Form identifier ('form1a', 'form2a', 'form3a', 'form90')

    Returns:
        (is_valid, errors) where errors names each missing or empty field.
        Unknown form types have no requirements and validate as True.
    """
    # Required fields per form type.
    required_fields = {
        'form1a': ['name', 'date_of_birth', 'place_of_birth'],
        'form2a': ['name', 'date_of_death', 'place_of_death'],
        'form3a': ['husband_name', 'wife_name', 'date_of_marriage'],
        'form90': ['husband_name', 'wife_name', 'date_of_application']
    }

    # A field fails when absent or falsy (empty string etc.).
    errors = [
        f"Missing required field: {field}"
        for field in required_fields.get(form_type, [])
        if not data.get(field)
    ]

    # Format-level checks (dates, names, ...) could be added here.
    return not errors, errors


def load_checkpoint(checkpoint_path, model, optimizer=None, device='cpu'):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    Args:
        checkpoint_path: Path to checkpoint file
        model: Model instance whose state_dict will be overwritten
        optimizer: Optimizer instance to restore, if its state was saved (optional)
        device: map_location passed to torch.load

    Returns:
        (model, optimizer, checkpoint_dict)
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Restore the optimizer only when both the instance and its state exist.
    if optimizer is not None and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # Summarize what was loaded; prefer CER if recorded, else fall back to loss.
    print(f"βœ“ Loaded checkpoint from {checkpoint_path}")
    print(f"  Epoch: {checkpoint.get('epoch', 'N/A')}")
    if 'val_cer' in checkpoint:
        print(f"  Val CER  : {checkpoint['val_cer']:.4f}%")
    elif 'val_loss' in checkpoint:
        print(f"  Val Loss : {checkpoint['val_loss']:.4f}  (run compare_live_cer.py for true CER)")
    else:
        print(f"  Val CER  : N/A  (run compare_live_cer.py for true CER)")

    return model, optimizer, checkpoint


def save_predictions_to_file(predictions: List[str], ground_truths: List[str], output_file: str):
    """Write a tab-separated ground-truth / prediction table for error analysis."""
    lines = ["Ground Truth\tPrediction\tMatch\n", "=" * 80 + "\n"]
    for gt, pred in zip(ground_truths, predictions):
        marker = "βœ“" if gt == pred else "βœ—"
        lines.append(f"{gt}\t{pred}\t{marker}\n")

    with open(output_file, 'w', encoding='utf-8') as f:
        f.writelines(lines)

    print(f"βœ“ Predictions saved to {output_file}")


if __name__ == "__main__":
    # Test utility functions
    print("=" * 60)
    print("Testing Utility Functions")
    print("=" * 60)
    
    # Test CER calculation
    predictions = ["Hello World", "Test", "Sample Text"]
    ground_truths = ["Hello World", "Tset", "Sample Txt"]
    
    cer = calculate_cer(predictions, ground_truths)
    wer = calculate_wer(predictions, ground_truths)
    accuracy = calculate_accuracy(predictions, ground_truths)
    
    print(f"\nMetrics:")
    print(f"  CER: {cer:.2f}%")
    print(f"  WER: {wer:.2f}%")
    print(f"  Accuracy: {accuracy:.2f}%")
    
    # Test early stopping
    print("\nTesting Early Stopping:")
    early_stopping = EarlyStopping(patience=3, min_delta=0.001)
    
    val_losses = [1.0, 0.9, 0.85, 0.84, 0.84, 0.84, 0.84]
    for epoch, loss in enumerate(val_losses, 1):
        should_stop = early_stopping(loss)
        print(f"  Epoch {epoch}: Loss = {loss:.2f}, Stop = {should_stop}")
        if should_stop:
            break