# app/predict.py
"""Inference helper for the cat-vs-dog ResNet-50 classifier.

Importing this module builds ``CatvsDogResNet50``, loads its trained weights
from ``cd.pt`` located next to this file, and moves the model to the GPU when
one is available (CPU otherwise).
"""
from pathlib import Path

import albumentations as A
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from PIL import Image

from .model import CatvsDogResNet50

# Preprocessing pipeline: resize to 224x224, normalize with the ImageNet
# per-channel mean/std, then convert the HWC numpy array to a CHW tensor.
transform = A.Compose(
    [
        A.Resize(224, 224),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)

device = "cuda" if torch.cuda.is_available() else "cpu"  # Fallback to CPU if no GPU

# Resolve the checkpoint relative to this module instead of the process's
# current working directory, so the app can be launched from anywhere.
MODEL_PATH = Path(__file__).resolve().parent / "cd.pt"

model = CatvsDogResNet50(freeze_backbone=True)
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# Because this is a plain state dict, passing weights_only=True (torch >= 1.13)
# would restrict loading to tensors; confirm the deployed torch version first.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model = model.to(device).eval()


def predict_image(image: Image.Image, threshold: float = 0.5) -> str:
    """Classify a PIL image as ``"cat"`` or ``"dog"``.

    Args:
        image: Input image in any PIL mode. It is converted to RGB first, so
            grayscale, palette and RGBA images are accepted.
        threshold: Decision boundary applied to the sigmoid of the model
            output. Probabilities strictly above it are labelled ``"dog"``;
            everything else is labelled ``"cat"``. Defaults to 0.5.

    Returns:
        ``"dog"`` or ``"cat"``.
    """
    rgb = np.array(image.convert("RGB"))
    # Apply the preprocessing pipeline, then add a batch dimension of 1.
    batch = transform(image=rgb)["image"].unsqueeze(0).to(device)
    with torch.inference_mode():
        logits = model(batch)
    # Assumes the model emits a single logit where higher means "dog" — TODO
    # confirm against CatvsDogResNet50 and the training labels. .item() raises
    # if the output has more than one element.
    prob = torch.sigmoid(logits).item()
    return "dog" if prob > threshold else "cat"