# swin-fish-classification / image_processing_swin.py
# Author: Paul — "Upload processor" (commit a614a9b, verified)
from torchvision import transforms
from transformers import ViTImageProcessor
import torch
from PIL import Image
class MyCustomSwinProcessor(ViTImageProcessor):
    """ViT image processor that letterboxes inputs.

    Each image is resized to fit inside a square canvas while preserving its
    aspect ratio, then center-padded to exactly target_size x target_size
    before being handed to the standard ViT preprocessing pipeline.
    """

    def resize_and_pad(self, image, target_size=224, fill=255):
        """Resize *image* preserving aspect ratio, then pad to a square.

        Args:
            image: a PIL image (``image.size`` is ``(width, height)``).
            target_size: side length of the square output canvas.
            fill: padding pixel value; defaults to 255 (white) because
                it's the dataset's default background color.

        Returns:
            A PIL image of size ``(target_size, target_size)``.
        """
        w, h = image.size
        # Scale so the longer side exactly reaches target_size; the shorter
        # side lands at or below it and gets padded out.
        scale = min(target_size / w, target_size / h)
        # Guard against a zero-sized dimension for extreme aspect ratios
        # (int() truncation could otherwise produce 0 and break resize()).
        new_w = max(1, int(w * scale))
        new_h = max(1, int(h * scale))
        image = image.resize((new_w, new_h), Image.BILINEAR)
        # Split the leftover padding as evenly as possible; any odd pixel
        # goes to the right/bottom edge.
        pad_w = target_size - new_w
        pad_h = target_size - new_h
        left = pad_w // 2
        top = pad_h // 2
        return transforms.functional.pad(
            image, (left, top, pad_w - left, pad_h - top), fill=fill
        )

    def preprocess(self, images, **kwargs):
        """Letterbox the input(s), stack into a tensor batch, then delegate
        to the parent ViT preprocessing (rescaling/normalization, etc.).

        Accepts a single PIL image or an iterable of PIL images.
        """
        # Allow callers to pass a bare image instead of a batch.
        if isinstance(images, Image.Image):
            images = [images]
        images = [self.resize_and_pad(image, target_size=224) for image in images]
        images = torch.stack([transforms.ToTensor()(image) for image in images])
        return super().preprocess(images, **kwargs)