# ViTViz / utils/preprocessing.py
# Author: lucasddmc
# Commit message: feat: adds different models capability
# Commit: 98cf39b
import os
from typing import Callable, Optional, Sequence, Union

import torch
from PIL import Image
from torchvision import transforms
def get_default_transform(
    img_size: int = 224,
    *,
    mean: Sequence[float] = (0.485, 0.456, 0.406),
    std: Sequence[float] = (0.229, 0.224, 0.225),
) -> transforms.Compose:
    """Build the standard eval transform: Resize -> CenterCrop -> ToTensor -> Normalize.

    With the defaults this reproduces the classic ImageNet evaluation
    pipeline (resize to 256, center-crop 224). For other input sizes the
    resize target is scaled by the same 256/224 ratio, so the crop keeps
    the same relative margin.

    Args:
        img_size: Side length of the square center crop fed to the model.
        mean: Per-channel mean passed to ``transforms.Normalize``. Defaults
            to the ImageNet statistics; override it for models trained with
            different normalization (e.g. 0.5/0.5 for some ViT checkpoints).
        std: Per-channel standard deviation passed to
            ``transforms.Normalize``. Defaults to the ImageNet statistics.

    Returns:
        A ``transforms.Compose`` mapping a PIL image to a normalized tensor.

    Raises:
        ValueError: If ``img_size`` is not positive.
    """
    if img_size <= 0:
        raise ValueError(f"img_size must be positive, got {img_size}")
    # Proportional resize: 256 for a 224 crop, same ratio for other sizes.
    resize_size = int(img_size * 256 / 224)
    return transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        # list(...) keeps the stored values (and repr) identical to the
        # original list literals even when tuples are passed in.
        transforms.Normalize(mean=list(mean), std=list(std)),
    ])
def preprocess_image(
    image: Union[str, os.PathLike, Image.Image],
    transform: Optional[Callable[[Image.Image], torch.Tensor]] = None,
) -> torch.Tensor:
    """Load an image (path or PIL), transform it and add a batch dimension.

    Args:
        image: Path to an image file (``str`` or ``os.PathLike``) or an
            already-loaded ``PIL.Image.Image``. In both cases the image is
            converted to RGB before the transform is applied.
        transform: Callable applied to the RGB PIL image; it must return a
            tensor. When ``None``, ``get_default_transform()`` (224 px crop)
            is used.

    Returns:
        The transformed tensor with a leading batch dimension of size 1
        (``1 x C x H x W`` for transforms that yield ``C x H x W``).

    Raises:
        ValueError: If ``image`` is neither a path nor a ``PIL.Image.Image``.
    """
    # Explicit None check: a user-supplied transform that happens to be
    # falsy (e.g. an empty container module) must not be silently replaced.
    if transform is None:
        transform = get_default_transform()
    if isinstance(image, (str, os.PathLike)):
        # Context manager guarantees the underlying file handle is released;
        # convert() returns a new in-memory image, so it outlives the block.
        with Image.open(image) as src:
            img = src.convert('RGB')
    elif isinstance(image, Image.Image):
        img = image.convert('RGB')
    else:
        raise ValueError("Imagem inválida: informe caminho ou PIL.Image")
    return transform(img).unsqueeze(0)