| |
|
|
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import DataLoader |
| import os |
| from tqdm import tqdm |
| import numpy as np |
| from pathlib import Path |
| import json |
|
|
| from crnn_model import get_crnn_model, initialize_weights |
| from dataset import CivilRegistryDataset, collate_fn |
| from utils import ( |
| decode_ctc_predictions, |
| calculate_cer, |
| calculate_wer, |
| EarlyStopping |
| ) |
|
|
|
|
class CRNNTrainer:
    """
    Trainer for a CRNN + CTC sequence-recognition model.

    Responsibilities:
      - builds train/val datasets and dataloaders from ``config``
      - instantiates the CRNN model on GPU when available
      - optimizes with Adam: linear LR warmup for the first
        ``warmup_epochs`` epochs, then ReduceLROnPlateau on val loss
      - skips NaN/Inf-loss batches and clips gradients at norm 5.0
      - early-stops on stagnant validation loss
      - saves 'latest_checkpoint.pth' every epoch, 'best_model.pth' on
        improvement, periodic per-epoch snapshots, and resumes
        automatically from 'latest_checkpoint.pth' when it exists
    """

    def __init__(self, config):
        """
        Args:
            config (dict): training configuration; see ``main()`` for the
                expected keys (data paths, model hyper-parameters,
                optimization settings, output directories).
        """
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Output directories for checkpoints and JSON history logs.
        self.checkpoint_dir = Path(config['checkpoint_dir'])
        self.log_dir = Path(config['log_dir'])
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.log_dir.mkdir(parents=True, exist_ok=True)

        # Datasets: augmentation is applied only to the training split.
        print("Loading datasets...")
        self.train_dataset = CivilRegistryDataset(
            data_dir=config['train_data_dir'],
            annotations_file=config['train_annotations'],
            img_height=config['img_height'],
            img_width=config['img_width'],
            augment=True,
            form_type=config.get('form_type', 'all')
        )

        self.val_dataset = CivilRegistryDataset(
            data_dir=config['val_data_dir'],
            annotations_file=config['val_annotations'],
            img_height=config['img_height'],
            img_width=config['img_width'],
            augment=False,
            form_type=config.get('form_type', 'all')
        )

        # collate_fn handles batching of variable-length label sequences.
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=config['batch_size'],
            shuffle=True,
            num_workers=config['num_workers'],
            collate_fn=collate_fn,
            pin_memory=False
        )

        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=config['batch_size'],
            shuffle=False,
            num_workers=config['num_workers'],
            collate_fn=collate_fn,
            pin_memory=False
        )

        print(f"Initializing model on {self.device}...")
        self.model = get_crnn_model(
            model_type=config.get('model_type', 'standard'),
            img_height=config['img_height'],
            num_chars=self.train_dataset.num_chars,
            hidden_size=config['hidden_size'],
            num_lstm_layers=config['num_lstm_layers']
        )
        self.model = self.model.to(self.device)

        # CTC loss; class index 0 is reserved for the blank symbol.
        # zero_infinity damps inf losses from target/input length issues.
        self.criterion = nn.CTCLoss(blank=0, zero_infinity=True)

        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config.get('weight_decay', 1e-4)
        )

        # Linear warmup: LR scaled by (epoch+1)/warmup_epochs for the
        # first `warmup_epochs` epochs, then held at the base LR.
        warmup_epochs = config.get('warmup_epochs', 5)

        def warmup_lambda(epoch):
            if epoch < warmup_epochs:
                return (epoch + 1) / warmup_epochs
            return 1.0

        self.warmup_scheduler = optim.lr_scheduler.LambdaLR(
            self.optimizer, lr_lambda=warmup_lambda)

        # After warmup: halve the LR when validation loss plateaus.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=config.get('lr_patience', 5),
            min_lr=1e-6
        )
        self._warmup_epochs = warmup_epochs

        self.early_stopping = EarlyStopping(
            patience=config.get('early_stopping_patience', 10),
            min_delta=config.get('min_delta', 0.001)
        )

        # Per-epoch metric history; serialized by save_history() and
        # round-tripped through checkpoints on resume.
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'val_cer': [],
            'val_wer': [],
            'learning_rates': []
        }

        # Resume from the rolling checkpoint when one exists; otherwise
        # start fresh with freshly initialized weights.
        self.start_epoch = 1
        self.best_val_loss = float('inf')
        resume_path = self.checkpoint_dir / 'latest_checkpoint.pth'

        if resume_path.exists():
            print(f"\n Found checkpoint: {resume_path}")
            print(f" Resuming training from last saved epoch...")
            # weights_only=False: the checkpoint stores plain dicts
            # (config, char maps, history) besides tensors. Only load
            # checkpoints from a trusted source.
            ckpt = torch.load(resume_path, map_location=self.device, weights_only=False)
            self.model.load_state_dict(ckpt['model_state_dict'])
            self.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
            self.scheduler.load_state_dict(ckpt['scheduler_state_dict'])
            if 'warmup_scheduler_state_dict' in ckpt:
                self.warmup_scheduler.load_state_dict(ckpt['warmup_scheduler_state_dict'])
            self.start_epoch = ckpt['epoch'] + 1
            self.best_val_loss = ckpt.get('val_loss', float('inf'))
            self.history = ckpt.get('history', self.history)
            # Use .get() defaults so an older checkpoint missing these
            # keys does not crash the resume banner (the original indexed
            # ckpt['val_loss']/['val_cer'] directly right after guarding
            # the same keys with .get()).
            print(f" β Resumed from Epoch {ckpt['epoch']} "
                  f"(Val Loss: {self.best_val_loss:.4f}, CER: {ckpt.get('val_cer', float('nan')):.2f}%)")
        else:
            print(f" No checkpoint found β starting fresh.")
            initialize_weights(self.model)

        print(f"β Model ready with {sum(p.numel() for p in self.model.parameters()):,} parameters")

    def train_epoch(self, epoch):
        """
        Train for one epoch.

        Returns:
            float: mean loss over the batches that produced a finite loss
            (NaN/Inf batches are skipped and excluded from the average).
        """
        self.model.train()
        total_loss = 0
        processed = 0   # batches that actually contributed to total_loss
        nan_count = 0

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}/{self.config['epochs']}")

        for batch_idx, (images, targets, target_lengths, _) in enumerate(pbar):
            images = images.to(self.device)
            targets = targets.to(self.device)

            self.optimizer.zero_grad()

            # Model output is (T, B, num_chars) — T = outputs.size(0) is
            # used below as the CTC input length for every sample.
            outputs = self.model(images)
            log_probs = nn.functional.log_softmax(outputs, dim=2)

            batch_size = images.size(0)
            input_lengths = torch.full(
                size=(batch_size,),
                fill_value=outputs.size(0),
                dtype=torch.long,
                device=self.device
            )

            loss = self.criterion(
                log_probs,
                targets,
                input_lengths,
                target_lengths
            )

            # Skip pathological batches instead of corrupting weights.
            if torch.isnan(loss) or torch.isinf(loss):
                nan_count += 1
                continue

            loss.backward()

            # Gradient clipping: LSTM + CTC training is prone to spikes.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)

            self.optimizer.step()

            total_loss += loss.item()
            processed += 1

            # Running average over processed batches only — the original
            # divided by (batch_idx + 1), which undercounts the loss when
            # NaN/Inf batches have been skipped.
            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'avg_loss': f'{total_loss / processed:.4f}'
            })

        if nan_count > 0:
            print(f" [WARNING] {nan_count} NaN/Inf batches skipped this epoch.")

        # Average over batches actually optimized (guarding the all-skipped
        # case), not the loader length, so skipped batches don't deflate
        # the reported epoch loss.
        avg_loss = total_loss / max(processed, 1)
        return avg_loss

    def validate(self):
        """
        Evaluate the model on the validation split.

        Returns:
            tuple: (avg_loss, cer, wer, predictions, ground_truths) where
            cer/wer are percentages over the full validation set and
            predictions/ground_truths are parallel lists of strings.
        """
        self.model.eval()
        total_loss = 0
        all_predictions = []
        all_ground_truths = []

        with torch.no_grad():
            for images, targets, target_lengths, texts in tqdm(self.val_loader, desc="Validating"):
                images = images.to(self.device)
                targets_gpu = targets.to(self.device)

                outputs = self.model(images)
                log_probs = nn.functional.log_softmax(outputs, dim=2)

                batch_size = images.size(0)
                input_lengths = torch.full(
                    size=(batch_size,),
                    fill_value=outputs.size(0),
                    dtype=torch.long,
                    device=self.device
                )

                loss = self.criterion(log_probs, targets_gpu, input_lengths, target_lengths)
                total_loss += loss.item()

                # Greedy CTC decode on CPU for metric computation.
                predictions = decode_ctc_predictions(
                    outputs.cpu(),
                    self.train_dataset.idx_to_char
                )

                all_predictions.extend(predictions)
                all_ground_truths.extend(texts)

        avg_loss = total_loss / len(self.val_loader)

        cer = calculate_cer(all_predictions, all_ground_truths)
        wer = calculate_wer(all_predictions, all_ground_truths)

        return avg_loss, cer, wer, all_predictions, all_ground_truths

    def train(self):
        """Main training loop: per-epoch train/validate, LR scheduling,
        history tracking, checkpointing, early stopping, history export."""
        print("\n" + "=" * 70)
        print("Starting Training")
        print("=" * 70)

        best_val_loss = self.best_val_loss  # carried over on resume

        for epoch in range(self.start_epoch, self.config['epochs'] + 1):
            print(f"\nEpoch {epoch}/{self.config['epochs']}")
            print("-" * 70)

            train_loss = self.train_epoch(epoch)

            val_loss, val_cer, val_wer, predictions, ground_truths = self.validate()

            # Warmup scheduler drives the LR for the first epochs; the
            # plateau scheduler takes over afterwards.
            if epoch <= self._warmup_epochs:
                self.warmup_scheduler.step()
            else:
                self.scheduler.step(val_loss)
            current_lr = self.optimizer.param_groups[0]['lr']

            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['val_cer'].append(val_cer)
            self.history['val_wer'].append(val_wer)
            self.history['learning_rates'].append(current_lr)

            print(f"\nMetrics:")
            print(f" Train Loss: {train_loss:.4f}")
            print(f" Val Loss: {val_loss:.4f}")
            print(f" Val CER: {val_cer:.2f}%")
            print(f" Val WER: {val_wer:.2f}%")
            print(f" LR: {current_lr:.6f}")

            # A few decoded samples for qualitative inspection.
            print(f"\nSample Predictions:")
            for i in range(min(3, len(predictions))):
                print(f" GT: {ground_truths[i]}")
                print(f" Pred: {predictions[i]}")
                print()

            # Diagnostic probe on one validation image: per-frame argmax
            # of the raw output to monitor blank-domination and confidence
            # (a collapsed CTC model emits ~100% blanks).
            with torch.no_grad():
                sample_img = self.val_dataset[0][0].unsqueeze(0).to(self.device)
                raw_out = self.model(sample_img)            # (T, 1, C)
                probs = torch.softmax(raw_out, dim=2)
                best_idx = probs[:, 0, :].argmax(dim=1)     # per-frame class id
                best_prob = probs[:, 0, :].max(dim=1).values
                blank_pct = (best_idx == 0).float().mean().item() * 100
                avg_conf = best_prob.mean().item()
                non_blank = [self.train_dataset.idx_to_char.get(i.item(), '?')
                             for i in best_idx if i.item() != 0]
                print(f" blank={blank_pct:.0f}% conf={avg_conf:.3f} "
                      f"chars={''.join(non_blank[:20])!r}")

            is_best = val_loss < best_val_loss
            if is_best:
                best_val_loss = val_loss

            self.save_checkpoint(epoch, val_loss, val_cer, is_best)

            if self.early_stopping(val_loss):
                print(f"\nEarly stopping triggered at epoch {epoch}")
                break

        print("\n" + "=" * 70)
        print("Training Complete!")
        print(f"Best validation loss: {best_val_loss:.4f}")
        print("=" * 70)

        self.save_history()

    def save_checkpoint(self, epoch, val_loss, val_cer, is_best=False):
        """
        Save model checkpoint.

        Always overwrites 'latest_checkpoint.pth' (used for resume);
        additionally writes 'best_model.pth' when ``is_best`` and a
        periodic per-epoch snapshot every ``save_freq`` epochs.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'warmup_scheduler_state_dict': self.warmup_scheduler.state_dict(),
            'val_loss': val_loss,
            'val_cer': val_cer,
            # Char maps are stored so inference can decode without the
            # original dataset files.
            'char_to_idx': self.train_dataset.char_to_idx,
            'idx_to_char': self.train_dataset.idx_to_char,
            'config': self.config,
            'history': self.history
        }

        checkpoint_path = self.checkpoint_dir / 'latest_checkpoint.pth'
        torch.save(checkpoint, checkpoint_path)

        if is_best:
            best_path = self.checkpoint_dir / 'best_model.pth'
            torch.save(checkpoint, best_path)
            print(f" β Best model saved (Val Loss: {val_loss:.4f}, CER: {val_cer:.2f}%)")

        # Periodic snapshots drop 'history' to keep files smaller.
        if epoch % self.config.get('save_freq', 10) == 0:
            epoch_path = self.checkpoint_dir / f'checkpoint_epoch_{epoch}.pth'
            epoch_ckpt = {k: v for k, v in checkpoint.items() if k != 'history'}
            torch.save(epoch_ckpt, epoch_path)

    def save_history(self):
        """Serialize the per-epoch metric history to JSON in log_dir."""
        history_path = self.log_dir / 'training_history.json'
        with open(history_path, 'w') as f:
            json.dump(self.history, f, indent=2)
        print(f"\nβ Training history saved to {history_path}")
|
|
|
|
def main():
    """Assemble the training configuration and launch CRNN+CTC training."""
    # Data locations and form filtering.
    data_cfg = {
        'train_data_dir': 'data/train',
        'train_annotations': 'data/train_annotations.json',
        'val_data_dir': 'data/val',
        'val_annotations': 'data/val_annotations.json',
        'form_type': 'all',
    }

    # Model architecture.
    model_cfg = {
        'model_type': 'standard',
        'img_height': 64,
        'img_width': 512,
        'hidden_size': 128,
        'num_lstm_layers': 1,
    }

    # Optimization schedule and regularization.
    optim_cfg = {
        'batch_size': 32,
        'epochs': 100,
        'learning_rate': 0.0001,
        'weight_decay': 1e-4,
        'num_workers': 0,
        'warmup_epochs': 5,
        'lr_patience': 5,
        'early_stopping_patience': 20,
        'min_delta': 0.001,
    }

    # Output locations and checkpoint cadence.
    io_cfg = {
        'checkpoint_dir': 'checkpoints',
        'log_dir': 'logs',
        'save_freq': 10,
    }

    config = {**data_cfg, **model_cfg, **optim_cfg, **io_cfg}

    CRNNTrainer(config).train()
|
|
|
|
# Script entry point: run training when executed directly.
if __name__ == "__main__":
    main()