File size: 1,332 Bytes
e11edb1
 
 
 
 
 
98cf39b
 
 
 
 
 
 
 
 
 
 
e11edb1
98cf39b
 
e11edb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from typing import Optional, Union
from PIL import Image
from torchvision import transforms
import torch


def get_default_transform(img_size: int = 224) -> transforms.Compose:
    """Transform padrão (Resize+CenterCrop+Normalize) compatível com modelos ImageNet.
    
    Args:
        img_size: Tamanho da imagem de entrada do modelo (default: 224)
    
    Returns:
        Compose de transforms para preprocessamento
    """
    # Resize proporcional: 256 para 224, escala para outros tamanhos
    resize_size = int(img_size * 256 / 224)
    return transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])


def preprocess_image(
    image: Union[str, Image.Image],
    transform: Optional[transforms.Compose] = None,
) -> torch.Tensor:
    """Carrega e transforma uma imagem (caminho ou PIL) retornando um tensor 1xCxHxW."""
    transform = transform or get_default_transform()

    if isinstance(image, str):
        img = Image.open(image).convert('RGB')
    elif isinstance(image, Image.Image):
        img = image.convert('RGB')
    else:
        raise ValueError("Imagem inválida: informe caminho ou PIL.Image")

    return transform(img).unsqueeze(0)