Spaces:
Sleeping
Sleeping
| # app/predict.py | |
from pathlib import Path

import albumentations as A
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from PIL import Image

from .model import CatvsDogResNet50
# Per-channel (R, G, B) mean and std used for normalization. These are the
# standard ImageNet statistics; presumably they match how the ResNet-50
# backbone was trained — confirm against the training script.
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

# Preprocessing applied to every image before inference: resize to 224x224,
# normalize with the statistics above, then convert the numpy array to a
# torch tensor. Called as ``transform(image=array)["image"]``.
transform = A.Compose(
    [
        A.Resize(224, 224),
        A.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
        ToTensorV2(),
    ]
)
# Run on the GPU when one is visible to torch; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Resolve the checkpoint next to this module rather than relative to the
# process's current working directory. The old literal "app/cd.pt" only worked
# when the server was launched from the repository root.
_WEIGHTS_PATH = Path(__file__).resolve().parent / "cd.pt"

# freeze_backbone only affects training; at inference every parameter is
# fixed anyway.
model = CatvsDogResNet50(freeze_backbone=True)
# The checkpoint is a plain state_dict (tensors only), so weights_only=True is
# enough. It also refuses to unpickle arbitrary Python objects from the file.
# Requires torch >= 1.13.
state_dict = torch.load(_WEIGHTS_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
# Move to the target device and switch to eval mode (fixed BatchNorm stats,
# dropout disabled) for deterministic predictions.
model = model.to(device).eval()
def predict_image(image: Image.Image) -> str:
    """Classify a PIL image as ``"cat"`` or ``"dog"``.

    The image is forced to RGB, run through the module-level ``transform``
    pipeline, given a batch dimension of size 1, and scored by ``model`` on
    ``device`` with autograd disabled. The sigmoid of the model output is the
    "dog" score. This assumes the model emits a single logit per image, since
    ``.item()`` requires a one-element tensor.

    Args:
        image: Input image in any mode PIL can convert to RGB.

    Returns:
        ``"dog"`` when the sigmoid score is strictly greater than 0.5,
        otherwise ``"cat"``.
    """
    rgb_pixels = np.array(image.convert("RGB"))
    batch = transform(image=rgb_pixels)["image"].unsqueeze(0).to(device)

    with torch.inference_mode():
        logit = model(batch)
        dog_score = torch.sigmoid(logit).item()

    if dog_score > 0.5:
        return "dog"
    return "cat"