File size: 3,488 Bytes
93ada5f 8accbb4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import torch
import torch.nn as nn
import yaml
from torchvision import models, transforms
from PIL import Image
import gradio as gr
import os
CONFIG_PATH = 'staging_config.yaml'
CHECKPOINT_FILENAME = 'model.pt'
def get_model(model_name, num_classes):
    """Factory function to create a model shell for loading weights.

    Args:
        model_name: One of "efficientnet_b0", "resnet50", or "vit_b_16".
        num_classes: Output dimension of the replacement classification head.

    Returns:
        A torchvision model with randomly-initialized weights (``weights=None``)
        whose final classification layer is replaced by
        ``nn.Linear(in_features, num_classes)``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    if model_name == "efficientnet_b0":
        model = models.efficientnet_b0(weights=None)
        # EfficientNet's classifier is Sequential(Dropout, Linear); swap the Linear.
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
    elif model_name == "resnet50":
        model = models.resnet50(weights=None)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif model_name == "vit_b_16":
        # ViT exposes its classifier under model.heads.head.
        model = models.vit_b_16(weights=None)
        model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)
    else:
        raise ValueError(f"Model '{model_name}' not supported.")
    return model
def load_checkpoint(checkpoint_path, device):
    """Loads a checkpoint and returns the model and class mapping.

    The checkpoint is expected to be a dict with keys 'model_name',
    'class_to_idx', and 'state_dict' (all three are read below).

    Args:
        checkpoint_path: Filesystem path to the saved checkpoint file.
        device: torch.device the weights are mapped to and the model moved onto.

    Returns:
        Tuple of (model in eval mode on `device`, idx_to_class dict mapping
        integer index -> class name).

    Raises:
        FileNotFoundError: If `checkpoint_path` does not exist.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found at: {checkpoint_path}")
    # NOTE(review): torch.load unpickles arbitrary objects unless
    # weights_only=True; only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model_name = checkpoint['model_name']
    class_to_idx = checkpoint['class_to_idx']
    # num_classes=1 builds a single-logit head — presumably a binary classifier
    # scored with sigmoid; load_state_dict would fail if the checkpoint's head
    # had a different shape, so this acts as an implicit consistency check.
    model = get_model(model_name, num_classes=1)
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)
    model.eval()
    # Invert the mapping so prediction indices can be turned back into names.
    idx_to_class = {v: k for k, v in class_to_idx.items()}
    return model, idx_to_class
# --- Startup: runs once at import time (e.g. when the Space boots) ---
try:
    with open(CONFIG_PATH, 'r') as f:
        config = yaml.safe_load(f)
except FileNotFoundError:
    # Re-raise with a deployment-specific hint; the original FileNotFoundError
    # is still attached automatically as __context__.
    raise RuntimeError(f"ERROR: Config file not found at '{CONFIG_PATH}'. Make sure it's uploaded to the Space.")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Module-level globals consumed by predict() below.
MODEL, IDX_TO_CLASS = load_checkpoint(CHECKPOINT_FILENAME, DEVICE)
print(f"Model loaded successfully on {DEVICE}.")
print(f"Class mapping: {IDX_TO_CLASS}")
# assumes the config contains data_params.image_size — confirm against staging_config.yaml
IMG_SIZE = config['data_params']['image_size']
# Standard ImageNet mean/std normalization; presumably matches the
# training-time preprocessing — verify against the training pipeline.
inference_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
def predict(pil_image):
    """
    Run inference on a single PIL image.

    Returns a {class_name: probability} dict — the format Gradio's `Label`
    component expects — or None when no image was provided.
    """
    if pil_image is None:
        return None
    # Force 3-channel input regardless of the uploaded image's mode.
    rgb = pil_image.convert("RGB")
    batch = inference_transform(rgb).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        # Single-logit binary head: sigmoid gives P(class 1).
        positive = torch.sigmoid(MODEL(batch)).item()
    return {
        IDX_TO_CLASS.get(0, "Class 0"): 1 - positive,
        IDX_TO_CLASS.get(1, "Class 1"): positive,
    }
title = "Image Classifier API"
description = """
Upload an image and the model will predict its class.
This model was trained to distinguish between two classes.
The API returns the probabilities for each class.
"""

# Wire predict() into a minimal Gradio UI: one image upload in, a
# two-class probability label out.
image_input = gr.Image(type="pil", label="Upload Image")
label_output = gr.Label(num_top_classes=2, label="Predictions")
iface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    title=title,
    description=description,
)

# Start the web server (blocks until shut down).
iface.launch()