import os
from typing import Optional, Union

import torch
from PIL import Image
from torchvision import transforms
| |
|
| |
|
def get_default_transform(img_size: int = 224) -> transforms.Compose:
    """Build the default ImageNet-compatible preprocessing pipeline.

    Steps, in order:
      1. ``Resize`` to ``int(img_size * 256 / 224)``.
      2. ``CenterCrop`` to ``img_size``.
      3. ``ToTensor``.
      4. ``Normalize`` with the ImageNet per-channel mean/std.

    Args:
        img_size: Side length of the model's input crop (default: 224).

    Returns:
        A ``transforms.Compose`` applying the steps above.
    """
    # Preserve the classic 256:224 resize-to-crop ratio for any crop size.
    short_side = int(img_size * 256 / 224)

    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    steps = [
        transforms.Resize(short_side),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return transforms.Compose(steps)
| |
|
| |
|
def preprocess_image(
    image: Union[str, os.PathLike, Image.Image],
    transform: Optional[transforms.Compose] = None,
) -> torch.Tensor:
    """Load (if needed) and transform an image, returning a batched tensor.

    Args:
        image: Path to an image file (``str`` or any ``os.PathLike``, e.g.
            ``pathlib.Path``) or an already-loaded ``PIL.Image.Image``.
            In both cases the image is converted to RGB before transforming.
        transform: Callable applied to the RGB PIL image. When ``None``,
            ``get_default_transform()`` is used.

    Returns:
        The transform's output with a leading batch dimension of size 1
        (``1 x C x H x W`` for the default pipeline).

    Raises:
        ValueError: If ``image`` is neither a path nor a ``PIL.Image.Image``.
    """
    # Explicit None check: a transform object that defines __len__ (e.g. an
    # empty nn.Sequential) is falsy and would otherwise be silently replaced
    # by the default pipeline.
    if transform is None:
        transform = get_default_transform()

    if isinstance(image, (str, os.PathLike)):
        # Context manager guarantees the file handle is released; Pillow keeps
        # the file open after load() for multi-frame formats. convert() returns
        # a fully loaded, independent image, so it stays valid after close.
        with Image.open(image) as opened:
            img = opened.convert('RGB')
    elif isinstance(image, Image.Image):
        img = image.convert('RGB')
    else:
        # ValueError (not TypeError) is kept for backward compatibility with
        # existing callers that catch it.
        raise ValueError("Imagem inválida: informe caminho ou PIL.Image")

    return transform(img).unsqueeze(0)
| |
|