# RenAI / vit.py — Arsh124
# Added support for HF-Model (commit 9a88738)
import os

import albumentations as A
import numpy as np
import torch
import torch.nn.functional as F
from evaluate import load
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import (
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
    TrOCRProcessor,
    VisionEncoderDecoderModel,
)

from configs import hf_token, model_path, processor_path
# cuDNN autotuner: benchmark conv algorithms for fixed-size inputs.
# NOTE(review): the original comment claimed this enables "mixed precision
# training", but torch.backends.cudnn.benchmark does not enable AMP — confirm intent.
torch.backends.cudnn.benchmark = True
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 forces synchronous CUDA kernel launches.
# It is a debugging aid and slows training — presumably leftover from debugging;
# confirm before keeping it in production runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# Character- and word-error-rate metrics from the HF `evaluate` hub.
cer_metric = load("cer")
wer_metric = load("wer")
# do_rescale=False because the dataset below already scales pixels to [0, 1].
processor = TrOCRProcessor.from_pretrained(processor_path, do_rescale=False,use_fast=True, token=hf_token)
model = VisionEncoderDecoderModel.from_pretrained(model_path,use_safetensors=True, token=hf_token)
def compute_metrics(eval_pred):
    """Compute CER and WER for a HF Trainer evaluation step.

    Args:
        eval_pred: ``(logits, labels)`` pair from the Trainer; ``labels``
            uses ``-100`` as the ignore/padding marker.

    Returns:
        dict with ``"cer"`` and ``"wer"`` scores.
    """
    logits, labels = eval_pred
    # Some models return a tuple of outputs; the token logits come first.
    logits = logits[0] if isinstance(logits, tuple) else logits
    pred_ids = logits.argmax(-1)
    pred_texts = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    # Strip the -100 loss-ignore markers before decoding each reference.
    ref_texts = [
        processor.tokenizer.decode(
            [tok for tok in seq if tok != -100], skip_special_tokens=True
        )
        for seq in labels
    ]
    return {
        "cer": cer_metric.compute(predictions=pred_texts, references=ref_texts),
        "wer": wer_metric.compute(predictions=pred_texts, references=ref_texts),
    }
class LineDataset(Dataset):
    """Dataset of text-line images paired with transcriptions for TrOCR fine-tuning.

    Each item is the processor's encoding dict (``pixel_values``, ``labels``),
    with optional albumentations augmentation applied before encoding.
    """

    def __init__(self, processor, model, line_images, texts, target_size=(384, 96), max_length=512, apply_augmentation=False):
        """
        Args:
            processor: TrOCRProcessor that encodes images and tokenizes texts.
            model: VisionEncoderDecoderModel; its config.max_length is synced here.
            line_images: sequence of PIL images or numpy arrays (H, W[, C]).
                Float arrays are assumed to be scaled to [0, 1] — TODO confirm.
            texts: ground-truth transcriptions, parallel to line_images.
            target_size: (width, height) each image is resized to (PIL order).
            max_length: maximum token length for the label sequences.
            apply_augmentation: when True, apply the random augmentation pipeline.
        """
        self.line_images = line_images
        self.texts = texts
        self.processor = processor
        # Keep the processor/tokenizer/model length limits in sync.
        self.processor.image_processor.max_length = max_length
        self.processor.tokenizer.model_max_length = max_length
        self.model = model
        self.model.config.max_length = max_length
        self.target_size = target_size
        self.max_length = max_length
        self.apply_augmentation = apply_augmentation
        if apply_augmentation:
            # Mild geometric/photometric distortions; exactly one of the
            # listed transforms is applied with overall probability 0.7.
            self.transform = A.Compose([
                A.OneOf([
                    A.Rotate(limit=2, p=1.0),
                    A.ElasticTransform(alpha=0.3, sigma=50.0, alpha_affine=0.3, p=1.0),
                    A.OpticalDistortion(distort_limit=0.03, shift_limit=0.03, p=1.0),
                    A.CLAHE(clip_limit=2, tile_grid_size=(4, 4), p=1.0),
                    A.Affine(scale=(0.95, 1.05), translate_percent=(0.02, 0.02), shear=(-2, 2), p=1.0),
                    A.Perspective(scale=(0.01, 0.03), p=1.0),
                    A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=1.0),
                    A.GaussianBlur(blur_limit=(3, 7), p=1.0),
                    A.GridDistortion(num_steps=3, distort_limit=0.02, p=1.0),
                    A.MedianBlur(blur_limit=3, p=1.0),
                ], p=0.7),
            ])
        else:
            # Identity pipeline (never invoked, since __getitem__ also gates
            # on apply_augmentation), kept so self.transform always exists.
            self.transform = A.Compose([])

    def __len__(self):
        return len(self.line_images)

    def __getitem__(self, idx):
        image = self.line_images[idx]
        text = self.texts[idx]
        if isinstance(image, Image.Image):
            image = np.array(image)  # yields uint8 HWC (or HW for grayscale)
        if image.ndim == 2:
            # Grayscale -> 3 channels by replication.
            image = np.expand_dims(image, axis=-1)
            image = np.repeat(image, 3, axis=-1)
        # Bring the image to uint8 [0, 255] for albumentations.
        # BUGFIX: only rescale non-uint8 (float [0, 1]) input; uint8 arrays
        # coming from the PIL branch above were previously multiplied by 255
        # and wrapped around (uint8 overflow), corrupting the image.
        if image.dtype != np.uint8:
            image = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
        if self.apply_augmentation and self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        image = Image.fromarray(image)
        image = image.resize(self.target_size, Image.LANCZOS)  # PIL takes (width, height)
        # Normalize to [0, 1] and convert HWC -> CHW; the processor was
        # constructed with do_rescale=False, so it will not rescale again.
        image = np.array(image) / 255.0
        image = np.transpose(image, (2, 0, 1))
        encoding = self.processor(images=image, text=text, return_tensors="pt")
        encoding['labels'] = encoding['labels'][:, :self.max_length]
        # BUGFIX: squeeze only the leading batch dim. A bare squeeze() also
        # collapses genuine size-1 dims (e.g. a single-token label became a
        # 0-d tensor, which breaks pad_sequence in collate_fn).
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}
        return encoding
def collate_fn(batch):
    """Collate dataset samples: stack images, pad labels to a common length.

    Labels are right-padded with -100 so the padded positions are ignored
    by the cross-entropy loss.
    """
    images = [sample['pixel_values'] for sample in batch]
    label_seqs = [sample['labels'] for sample in batch]
    padded_labels = pad_sequence(label_seqs, batch_first=True, padding_value=-100)
    return {'pixel_values': torch.stack(images), 'labels': padded_labels}