File size: 4,508 Bytes
20ec8a2
 
 
 
 
 
9642a08
 
 
174d2d0
9642a08
 
56be534
174d2d0
9642a08
174d2d0
20ec8a2
174d2d0
 
 
56be534
20ec8a2
174d2d0
 
56be534
9642a08
56be534
 
 
 
174d2d0
56be534
 
c9bdaab
174d2d0
 
 
9642a08
174d2d0
 
 
56be534
174d2d0
 
 
9642a08
174d2d0
 
56be534
 
9642a08
 
 
 
174d2d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9642a08
174d2d0
 
 
56be534
174d2d0
 
 
9642a08
174d2d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56be534
174d2d0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os

# Cache locations are redirected to /data before the heavy libraries below are
# imported — transformers (and matplotlib, if pulled in transitively) read
# these environment variables at import time, so order matters here.
os.environ['TRANSFORMERS_CACHE'] = '/data/.cache/transformers'
os.environ['HF_HOME'] = '/data/.cache/huggingface'
os.environ['MPLCONFIGDIR'] = '/data/.cache/matplotlib'

import torch
import torch.nn as nn
import yaml
from torchvision import models, transforms
from PIL import Image
import gradio as gr
from transformers import ConvNextV2ForImageClassification
from typing import Dict, Tuple

# UI display name -> checkpoint path for every model selectable in the app.
MODEL_CHECKPOINTS = {
    "ConvNeXt tiny (Best)": "checkpoints/convnext_v2_tiny_best.pth",
    "EfficientNet-B0": "checkpoints/effnet_b0_best.pth",
    "EfficientNet-B3": "checkpoints/effnet_b3_best.pth",
    "Vision Transformer B-16": "checkpoints/vit_b_16_best.pth"
}
DEFAULT_MODEL_NAME = "ConvNeXt tiny (Best)"

# Populated at startup: display name -> (loaded eval-mode model, idx -> class-name map).
MODELS: Dict[str, Tuple[nn.Module, Dict[int, str]]] = {}
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class HFConvNeXtWrapper(nn.Module):
    """Adapter that makes a Hugging Face ConvNeXt-V2 classifier look like a
    plain ``nn.Module``: ``forward`` returns raw logits instead of the HF
    output object."""

    def __init__(self, model_name, num_labels):
        super().__init__()
        self.model = ConvNextV2ForImageClassification.from_pretrained(
            model_name, num_labels=num_labels, ignore_mismatched_sizes=True)

    def forward(self, x):
        # Unwrap the HF ImageClassifierOutput so callers get a bare tensor.
        outputs = self.model(x)
        return outputs.logits

def get_model(model_name: str, num_classes: int) -> nn.Module:
    """Build an untrained architecture by name with a fresh classifier head.

    Args:
        model_name: Architecture identifier as stored in a checkpoint —
            "efficientnet_b0", "efficientnet_b3", "vit_b_16", or any
            Hugging Face model id containing "convnextv2".
        num_classes: Output dimension of the replaced classification head.

    Returns:
        The model with random weights (``weights=None``), ready for
        ``load_state_dict``.

    Raises:
        ValueError: If ``model_name`` matches no supported architecture.
    """
    # The two EfficientNet variants differ only in the constructor, so share
    # the head-replacement logic instead of duplicating it per branch.
    if model_name in ("efficientnet_b0", "efficientnet_b3"):
        factory = models.efficientnet_b0 if model_name == "efficientnet_b0" else models.efficientnet_b3
        model = factory(weights=None)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
    elif model_name == "vit_b_16":
        model = models.vit_b_16(weights=None)
        model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)
    elif "convnextv2" in model_name:
        model = HFConvNeXtWrapper(model_name, num_labels=num_classes)
    else:
        raise ValueError(f"Model '{model_name}' not supported.")
    return model

def load_checkpoint(checkpoint_path: str, device: torch.device) -> Tuple[nn.Module, Dict[int, str]]:
    """Restore a trained model and its index-to-class-name mapping.

    Args:
        checkpoint_path: Path to a .pth file containing 'model_name',
            'class_to_idx', and 'state_dict' entries.
        device: Device the restored model is moved to.

    Returns:
        Tuple of (eval-mode model on *device*, {index: class_name} mapping).

    Raises:
        FileNotFoundError: If *checkpoint_path* does not exist.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found at: {checkpoint_path}")
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted
    # checkpoints (consider weights_only=True where the torch version allows).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    arch_name = checkpoint['model_name']
    class_to_idx = checkpoint['class_to_idx']
    # Single-logit head: the app treats this as a binary sigmoid classifier.
    net = get_model(arch_name, num_classes=1)
    net.load_state_dict(checkpoint['state_dict'])
    net.to(device)
    net.eval()
    # Invert the mapping so prediction indices can be turned back into labels.
    index_to_label = {idx: label for label, idx in class_to_idx.items()}
    return net, index_to_label

# Eagerly load every available checkpoint at startup so requests never pay
# model-loading latency.
print("--- Loading all models into memory ---")
for display_name, ckpt_path in MODEL_CHECKPOINTS.items():
    if os.path.exists(ckpt_path):
        model, idx_to_class = load_checkpoint(ckpt_path, DEVICE)
        MODELS[display_name] = (model, idx_to_class)
        print(f"Loaded '{display_name}' on {DEVICE}.")
    else:
        # Missing checkpoints are skipped so the app can still start with a
        # partial model set.
        print(f"WARNING: Checkpoint for '{display_name}' not found. Skipping.")

# With zero models the UI would be unusable — fail fast instead.
if not MODELS:
    raise RuntimeError("No models were loaded. Please check your checkpoints directory.")

# Image size is read from the training config so inference resizing matches
# what the models were trained on.
with open('staging_config.yaml', 'r') as f:
    config = yaml.safe_load(f)
IMG_SIZE = config['data_params']['image_size']
# Standard ImageNet mean/std normalization — presumably the same values used
# by the training pipeline; TODO confirm against the training transforms.
inference_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

def predict(pil_image, model_name: str):
    """Classify *pil_image* with the model selected in the UI.

    Args:
        pil_image: Uploaded PIL image, or None when nothing was provided.
        model_name: Key into the module-level MODELS registry.

    Returns:
        {class_name: probability} dict for Gradio's Label component, or
        None when no image was given.
    """
    # Guard clause: Gradio passes None when the image input is empty.
    if pil_image is None:
        return None

    model, idx_to_class = MODELS[model_name]
    rgb_image = pil_image.convert("RGB")
    batch = inference_transform(rgb_image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logit = model(batch)
        # Single-logit binary head: sigmoid gives P(class 1).
        positive_prob = torch.sigmoid(logit).item()

    negative_label = idx_to_class.get(0, "Class 0")
    positive_label = idx_to_class.get(1, "Class 1")
    return {negative_label: 1 - positive_prob, positive_label: positive_prob}

# Gradio UI: an image upload plus a model-selector dropdown feed predict();
# the output Label renders the two class probabilities.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Dropdown(
            # Only successfully loaded models are offered as choices.
            choices=list(MODELS.keys()),
            value=DEFAULT_MODEL_NAME,
            label="Select Model"
        )
    ],
    outputs=gr.Label(num_top_classes=2, label="Predictions"),
    title="Multi-Model Image Classifier",
    description="Upload an image and select a model to see its classification.",
)

iface.launch()