File size: 3,488 Bytes
93ada5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8accbb4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
import torch.nn as nn
import yaml
from torchvision import models, transforms
from PIL import Image
import gradio as gr
import os

# Path to the YAML config uploaded alongside this app (read at startup below).
CONFIG_PATH = 'staging_config.yaml'
# Training checkpoint expected in the working directory; must contain
# 'model_name', 'class_to_idx', and 'state_dict' keys (see load_checkpoint).
CHECKPOINT_FILENAME = 'model.pt'


def get_model(model_name, num_classes):
    """Factory function to create a model shell for loading weights."""
    model = None
    if model_name == "efficientnet_b0":
        model = models.efficientnet_b0(weights=None)
        num_ftrs = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_ftrs, num_classes)
    elif model_name == "resnet50":
        model = models.resnet50(weights=None)
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, num_classes)
    elif model_name == "vit_b_16":
        model = models.vit_b_16(weights=None)
        num_ftrs = model.heads.head.in_features
        model.heads.head = nn.Linear(num_ftrs, num_classes)
    else:
        raise ValueError(f"Model '{model_name}' not supported.")
    return model

def load_checkpoint(checkpoint_path, device):
    """Restore a trained model and its index-to-class-name mapping.

    Expects a checkpoint dict with keys 'model_name', 'class_to_idx', and
    'state_dict'. The model is rebuilt via `get_model` with a single-logit
    head (binary classifier scored with sigmoid in `predict`), moved to
    `device`, and switched to eval mode.

    Args:
        checkpoint_path: Filesystem path to the saved checkpoint.
        device: torch.device the weights should be mapped onto.

    Returns:
        (model, idx_to_class) where idx_to_class maps int index -> class name.

    Raises:
        FileNotFoundError: If `checkpoint_path` does not exist.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found at: {checkpoint_path}")

    # NOTE(review): torch.load unpickles arbitrary Python objects — only load
    # checkpoints from trusted sources (consider weights_only=True if the
    # checkpoint contains only tensors/primitives).
    checkpoint = torch.load(checkpoint_path, map_location=device)

    model = get_model(checkpoint['model_name'], num_classes=1)
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)
    model.eval()

    # Invert the training-time {class_name: index} mapping for display.
    idx_to_class = {idx: name for name, idx in checkpoint['class_to_idx'].items()}
    return model, idx_to_class


# --- Startup: runs once at import time (module-level side effects) ---
# Load the YAML configuration; translate a missing file into a clearer
# RuntimeError so the Space's logs point at the actual deployment problem.
try:
    with open(CONFIG_PATH, 'r') as f:
        config = yaml.safe_load(f)
except FileNotFoundError:
    raise RuntimeError(f"ERROR: Config file not found at '{CONFIG_PATH}'. Make sure it's uploaded to the Space.")

# Prefer GPU when available; all inference tensors are moved to this device.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Model and class mapping are loaded once and reused by every predict() call.
MODEL, IDX_TO_CLASS = load_checkpoint(CHECKPOINT_FILENAME, DEVICE)
print(f"Model loaded successfully on {DEVICE}.")
print(f"Class mapping: {IDX_TO_CLASS}")


# Input side length (pixels) must match what the model was trained with.
IMG_SIZE = config['data_params']['image_size']
# Preprocessing pipeline: resize to the training resolution, convert to a
# CHW float tensor, then normalize with the standard ImageNet mean/std
# (matches torchvision pretrained-backbone conventions).
inference_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

def predict(pil_image):
    """Classify one image and return {class_name: probability}.

    Gradio's `Label` component consumes a dict of class confidences. The
    model has a single-logit binary head: sigmoid(output) is the probability
    of class 1, and its complement is the probability of class 0.

    Args:
        pil_image: A PIL image, or None when no image was supplied.

    Returns:
        Dict mapping both class names to probabilities, or None for no input.
    """
    if pil_image is None:
        return None

    # Force 3-channel RGB (handles grayscale / RGBA uploads) and batch it.
    batch = inference_transform(pil_image.convert("RGB")).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        positive_prob = torch.sigmoid(MODEL(batch)).item()

    name_0 = IDX_TO_CLASS.get(0, "Class 0")
    name_1 = IDX_TO_CLASS.get(1, "Class 1")
    return {name_0: 1 - positive_prob, name_1: positive_prob}



# --- Gradio UI: wire predict() to an image-in / label-out interface ---
title = "Image Classifier API"
description = """
Upload an image and the model will predict its class. 
This model was trained to distinguish between two classes.
The API returns the probabilities for each class.
"""

iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Image"),  # delivered as PIL, matching predict()'s input
    outputs=gr.Label(num_top_classes=2, label="Predictions"),  # shows both binary classes
    title=title,
    description=description,
)

# Blocks here serving HTTP requests until the process is stopped.
iface.launch()