import os

import albumentations as A
import numpy as np
import torch
import torch.nn.functional as F
from evaluate import load
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import (
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
    TrOCRProcessor,
    VisionEncoderDecoderModel,
)

from configs import hf_token, model_path, processor_path

# Let cuDNN autotune kernels for the (fixed) input size used in training.
torch.backends.cudnn.benchmark = True
# Synchronous CUDA launches give precise stack traces at the cost of speed.
# NOTE(review): usually disabled outside of debugging -- confirm this is intended.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Character- and word-error-rate metrics from the HF `evaluate` hub.
cer_metric = load("cer")
wer_metric = load("wer")

# do_rescale=False: this pipeline hands the processor images already scaled
# to [0, 1] (see LineDataset.__getitem__), so the processor must not rescale again.
processor = TrOCRProcessor.from_pretrained(
    processor_path, do_rescale=False, use_fast=True, token=hf_token
)
model = VisionEncoderDecoderModel.from_pretrained(
    model_path, use_safetensors=True, token=hf_token
)


def compute_metrics(eval_pred):
    """Compute CER and WER for a HF Trainer evaluation step.

    Args:
        eval_pred: an ``EvalPrediction``-like pair ``(logits, labels)``;
            ``logits`` may itself be a tuple, in which case the first
            element holds the token logits.

    Returns:
        dict with keys ``"cer"`` and ``"wer"``.
    """
    logits, labels = eval_pred
    if isinstance(logits, tuple):
        logits = logits[0]
    # Greedy decode: argmax over the vocabulary dimension.
    predictions = logits.argmax(-1)
    decoded_preds = processor.tokenizer.batch_decode(
        predictions, skip_special_tokens=True
    )
    # Labels use -100 as the ignore index (see collate_fn); strip it before decoding.
    decoded_labels = [
        processor.tokenizer.decode(
            [token for token in label if token != -100], skip_special_tokens=True
        )
        for label in labels
    ]
    cer_score = cer_metric.compute(predictions=decoded_preds, references=decoded_labels)
    wer_score = wer_metric.compute(predictions=decoded_preds, references=decoded_labels)
    return {"cer": cer_score, "wer": wer_score}


class LineDataset(Dataset):
    """Dataset of single text-line images paired with transcription strings.

    Each item is an encoding dict (``pixel_values``, ``labels``) ready for a
    ``VisionEncoderDecoderModel``. Optionally applies light geometric and
    photometric augmentation suited to OCR line images.
    """

    def __init__(self, processor, model, line_images, texts,
                 target_size=(384, 96), max_length=512, apply_augmentation=False):
        """
        Args:
            processor: TrOCRProcessor used for image/text encoding.
            model: the seq2seq model; its ``config.max_length`` is synced here.
            line_images: sequence of PIL images or numpy arrays (one text line each).
            texts: sequence of transcription strings, parallel to ``line_images``.
            target_size: (width, height) every line image is resized to.
            max_length: max token length for labels and generation.
            apply_augmentation: enable the albumentations pipeline when True.
        """
        self.line_images = line_images
        self.texts = texts
        self.processor = processor
        # Keep processor, tokenizer, and model agreed on the same max length.
        self.processor.image_processor.max_length = max_length
        self.processor.tokenizer.model_max_length = max_length
        self.model = model
        self.model.config.max_length = max_length
        self.target_size = target_size
        self.max_length = max_length
        self.apply_augmentation = apply_augmentation
        if apply_augmentation:
            # One randomly chosen mild distortion per image, applied 70% of the time.
            # NOTE(review): ElasticTransform(alpha_affine=...) and
            # OpticalDistortion(shift_limit=...) were removed in recent
            # albumentations releases -- pin the library version or update these args.
            self.transform = A.Compose([
                A.OneOf([
                    A.Rotate(limit=2, p=1.0),
                    A.ElasticTransform(alpha=0.3, sigma=50.0, alpha_affine=0.3, p=1.0),
                    A.OpticalDistortion(distort_limit=0.03, shift_limit=0.03, p=1.0),
                    A.CLAHE(clip_limit=2, tile_grid_size=(4, 4), p=1.0),
                    A.Affine(scale=(0.95, 1.05), translate_percent=(0.02, 0.02),
                             shear=(-2, 2), p=1.0),
                    A.Perspective(scale=(0.01, 0.03), p=1.0),
                    A.RandomBrightnessContrast(brightness_limit=0.1,
                                               contrast_limit=0.1, p=1.0),
                    A.GaussianBlur(blur_limit=(3, 7), p=1.0),
                    A.GridDistortion(num_steps=3, distort_limit=0.02, p=1.0),
                    A.MedianBlur(blur_limit=3, p=1.0),
                ], p=0.7),
            ])
        else:
            # Identity pipeline keeps the call sites uniform.
            self.transform = A.Compose([])

    def __len__(self):
        return len(self.line_images)

    def __getitem__(self, idx):
        image = self.line_images[idx]
        text = self.texts[idx]
        if isinstance(image, Image.Image):
            image = np.array(image)
        # Grayscale -> 3-channel by replication so downstream code sees RGB.
        if image.ndim == 2:
            image = np.expand_dims(image, axis=-1)
            image = np.repeat(image, 3, axis=-1)
        # Albumentations and PIL expect uint8 in [0, 255]. Only rescale when the
        # input is floating-point (assumed [0, 1]); multiplying an already-uint8
        # image by 255 would wrap around and corrupt it.
        if np.issubdtype(image.dtype, np.floating):
            image = (image * 255).astype(np.uint8)
        else:
            image = image.astype(np.uint8)
        if self.apply_augmentation and self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        # Resize to the fixed (width, height) the model was tuned for.
        image = Image.fromarray(image)
        image = image.resize(self.target_size, Image.LANCZOS)
        # Back to float [0, 1], channels-first -- matches do_rescale=False above.
        image = np.array(image) / 255.0
        image = np.transpose(image, (2, 0, 1))
        encoding = self.processor(images=image, text=text, return_tensors="pt")
        # Hard-truncate labels; the tokenizer limit alone is not guaranteed here.
        encoding['labels'] = encoding['labels'][:, :self.max_length]
        # Drop only the batch dimension. A full .squeeze() would also collapse a
        # length-1 label sequence to a 0-d tensor and break pad_sequence later.
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}
        return encoding


def collate_fn(batch):
    """Stack pixel values and right-pad labels with -100 (HF ignore index)."""
    pixel_values = torch.stack([item['pixel_values'] for item in batch])
    labels = pad_sequence([item['labels'] for item in batch],
                          batch_first=True, padding_value=-100)
    return {'pixel_values': pixel_values, 'labels': labels}