testing_spaces / app /predict.py
Klasta's picture
First commit
5c6e546
raw
history blame contribute delete
923 Bytes
# app/predict.py
import torch
import numpy as np
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
from .model import CatvsDogResNet50
# Preprocessing applied to every image before it reaches the model:
# resize to 224x224, normalise each channel, then convert the array to a
# torch tensor. The mean/std values match the widely used ImageNet
# channel statistics.
_NORMALIZE_MEAN = (0.485, 0.456, 0.406)
_NORMALIZE_STD = (0.229, 0.224, 0.225)

_preprocessing_steps = [
    A.Resize(224, 224),
    A.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
    ToTensorV2(),
]

transform = A.Compose(_preprocessing_steps)
# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the classifier, restore its weights from the checkpoint file
# (tensors are mapped straight onto the selected device), then move the
# model to that device and switch it to evaluation mode.
model = CatvsDogResNet50(freeze_backbone=True)
_checkpoint_state = torch.load(r"app/cd.pt", map_location=device)
model.load_state_dict(_checkpoint_state)
model = model.to(device)
model = model.eval()
def predict_image(image: Image.Image) -> str:
    """Classify a single image as ``"dog"`` or ``"cat"``.

    The image is converted to RGB, run through the module-level
    ``transform`` pipeline, given a leading batch dimension and moved to
    ``device``. The module-level ``model`` is then evaluated under
    ``torch.inference_mode()`` and its output is passed through a sigmoid.

    Args:
        image: The PIL image to classify.

    Returns:
        ``"dog"`` when the sigmoid of the model output is strictly greater
        than 0.5, otherwise ``"cat"``.
    """
    rgb_pixels = np.array(image.convert("RGB"))
    tensor = transform(image=rgb_pixels)["image"]
    # Add a leading batch dimension of size one before moving to the device.
    batch = tensor.unsqueeze(0).to(device)

    with torch.inference_mode():
        logits = model(batch)
        # ``.item()` requires the model output to hold exactly one element.
        probability = torch.sigmoid(logits).item()

    if probability > 0.5:
        return "dog"
    return "cat"