from typing import Optional, Union from PIL import Image from torchvision import transforms import torch def get_default_transform(img_size: int = 224) -> transforms.Compose: """Transform padrão (Resize+CenterCrop+Normalize) compatível com modelos ImageNet. Args: img_size: Tamanho da imagem de entrada do modelo (default: 224) Returns: Compose de transforms para preprocessamento """ # Resize proporcional: 256 para 224, escala para outros tamanhos resize_size = int(img_size * 256 / 224) return transforms.Compose([ transforms.Resize(resize_size), transforms.CenterCrop(img_size), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) def preprocess_image( image: Union[str, Image.Image], transform: Optional[transforms.Compose] = None, ) -> torch.Tensor: """Carrega e transforma uma imagem (caminho ou PIL) retornando um tensor 1xCxHxW.""" transform = transform or get_default_transform() if isinstance(image, str): img = Image.open(image).convert('RGB') elif isinstance(image, Image.Image): img = image.convert('RGB') else: raise ValueError("Imagem inválida: informe caminho ou PIL.Image") return transform(img).unsqueeze(0)