|
|
import os |
|
|
import numpy as np |
|
|
from transformers import TrOCRProcessor, VisionEncoderDecoderModel |
|
|
import torch |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
import torch |
|
|
from torch.utils.data import Dataset, DataLoader, random_split |
|
|
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Trainer, TrainingArguments, EarlyStoppingCallback |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
from torch.nn.utils.rnn import pad_sequence |
|
|
from torch.optim import AdamW |
|
|
import torch.nn.functional as F |
|
|
from evaluate import load |
|
|
import albumentations as A |
|
|
import os |
|
|
|
|
|
from configs import model_path, processor_path, hf_token |
|
|
|
|
|
|
|
|
# --- Runtime / model setup (executed at import time) -------------------------

# Let cuDNN benchmark convolution algorithms for fixed input shapes.
torch.backends.cudnn.benchmark = True

# NOTE(review): CUDA_LAUNCH_BLOCKING=1 forces synchronous kernel launches —
# it is a debugging aid and serializes GPU work, which defeats the intent of
# cudnn.benchmark above. Confirm whether this is still needed and remove for
# production training runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Character- and word-error-rate metrics from the `evaluate` hub
# (downloads the metric script on first use).
cer_metric = load("cer")
wer_metric = load("wer")

# TrOCR processor (image preprocessor + tokenizer). do_rescale=False because
# the dataset below hands the processor images already scaled to [0, 1].
processor = TrOCRProcessor.from_pretrained(processor_path, do_rescale=False,use_fast=True, token=hf_token)
# Encoder-decoder OCR model; safetensors checkpoint loading only.
model = VisionEncoderDecoderModel.from_pretrained(model_path,use_safetensors=True, token=hf_token)
|
|
|
|
|
def compute_metrics(eval_pred):
    """Compute character (CER) and word (WER) error rates for an eval batch.

    Args:
        eval_pred: ``(predictions, label_ids)`` pair from the HF ``Trainer``.
            ``predictions`` are token logits (possibly wrapped in a tuple);
            ``label_ids`` use ``-100`` for positions to ignore.

    Returns:
        dict with ``"cer"`` and ``"wer"`` scores.
    """
    logits, labels = eval_pred
    # Some models return (logits, *extras); keep only the logits.
    if isinstance(logits, tuple):
        logits = logits[0]

    # Greedy token choice per position (teacher-forced eval, no generation).
    pred_ids = logits.argmax(-1)
    pred_texts = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)

    # Drop the -100 ignore-index padding before decoding each reference.
    ref_texts = [
        processor.tokenizer.decode(
            [tok for tok in seq if tok != -100], skip_special_tokens=True
        )
        for seq in labels
    ]

    return {
        "cer": cer_metric.compute(predictions=pred_texts, references=ref_texts),
        "wer": wer_metric.compute(predictions=pred_texts, references=ref_texts),
    }
|
|
|
|
|
class LineDataset(Dataset):
    """Dataset of text-line images and transcriptions for TrOCR fine-tuning.

    Each item is preprocessed to a fixed size, optionally augmented, and
    encoded with the shared ``TrOCRProcessor`` into ``pixel_values`` and
    tokenized ``labels``.

    Args:
        processor: shared ``TrOCRProcessor`` (image processor + tokenizer).
        model: the ``VisionEncoderDecoderModel`` being trained.
        line_images: sequence of line images (PIL images or numpy arrays).
        texts: sequence of transcription strings, parallel to ``line_images``.
        target_size: (width, height) passed to ``PIL.Image.resize``.
        max_length: max token length for labels / generation.
        apply_augmentation: enable the albumentations pipeline when True.
    """

    def __init__(self, processor, model, line_images, texts, target_size=(384, 96), max_length=512, apply_augmentation=False):
        self.line_images = line_images
        self.texts = texts
        self.processor = processor
        # NOTE(review): these mutate the SHARED processor/model objects, not
        # per-dataset copies — every dataset constructed last wins. Confirm
        # this is intentional.
        self.processor.image_processor.max_length = max_length
        self.processor.tokenizer.model_max_length = max_length
        self.model = model
        self.model.config.max_length = max_length
        self.target_size = target_size
        self.max_length = max_length
        self.apply_augmentation = apply_augmentation

        if apply_augmentation:
            # Exactly one of the transforms below is applied with p=0.7;
            # otherwise the image passes through unchanged.
            # NOTE(review): ElasticTransform's `alpha_affine` kwarg was
            # removed in recent albumentations releases — pin the version or
            # drop the kwarg if upgrading.
            self.transform = A.Compose([
                A.OneOf([
                    A.Rotate(limit=2, p=1.0),
                    A.ElasticTransform(alpha=0.3, sigma=50.0, alpha_affine=0.3, p=1.0),
                    A.OpticalDistortion(distort_limit=0.03, shift_limit=0.03, p=1.0),
                    A.CLAHE(clip_limit=2, tile_grid_size=(4, 4), p=1.0),
                    A.Affine(scale=(0.95, 1.05), translate_percent=(0.02, 0.02), shear=(-2, 2), p=1.0),
                    A.Perspective(scale=(0.01, 0.03), p=1.0),
                    A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=1.0),
                    A.GaussianBlur(blur_limit=(3, 7), p=1.0),
                    A.GridDistortion(num_steps=3, distort_limit=0.02, p=1.0),
                    A.MedianBlur(blur_limit=3, p=1.0),
                ], p=0.7),
            ])
        else:
            # No-op pipeline so self.transform is always callable.
            self.transform = A.Compose([])

    def __len__(self):
        return len(self.line_images)

    def __getitem__(self, idx):
        """Return ``{'pixel_values': (C,H,W) tensor, 'labels': (L,) tensor}``."""
        image = self.line_images[idx]
        text = self.texts[idx]

        if isinstance(image, Image.Image):
            image = np.array(image)

        # Grayscale -> 3-channel by replicating the single channel.
        if image.ndim == 2:
            image = np.expand_dims(image, axis=-1)
            image = np.repeat(image, 3, axis=-1)

        # Albumentations expects uint8 here.
        # NOTE(review): assumes the stored images are floats in [0, 1];
        # a uint8 input would overflow — confirm against the data loader.
        image = (image * 255).astype(np.uint8)

        if self.apply_augmentation and self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']

        # Resize to (width, height) = target_size, back to float [0, 1], CHW.
        image = Image.fromarray(image)
        image = image.resize(self.target_size, Image.LANCZOS)
        image = np.array(image) / 255.0
        image = np.transpose(image, (2, 0, 1))

        encoding = self.processor(images=image, text=text, return_tensors="pt")
        # Truncate labels to the configured maximum length.
        encoding['labels'] = encoding['labels'][:, :self.max_length]
        # BUGFIX: squeeze only the batch dimension. A bare .squeeze() also
        # collapsed a length-1 label sequence to a 0-d tensor, which breaks
        # pad_sequence in collate_fn.
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}
        return encoding
|
|
|
|
|
def collate_fn(batch):
    """Collate dataset items into a batch for the Trainer.

    Stacks the fixed-size images and right-pads the variable-length label
    sequences with -100 (the ignore index for the loss).

    Args:
        batch: list of ``{'pixel_values': tensor, 'labels': tensor}`` dicts.

    Returns:
        dict with batched ``pixel_values`` and padded ``labels``.
    """
    images = [sample['pixel_values'] for sample in batch]
    label_seqs = [sample['labels'] for sample in batch]
    return {
        'pixel_values': torch.stack(images),
        'labels': pad_sequence(label_seqs, batch_first=True, padding_value=-100),
    }
|
|
|