import os
# Redirect library cache directories to persistent storage. These MUST be set
# before importing transformers/matplotlib-dependent modules below, since the
# libraries read the environment at import time.
os.environ['TRANSFORMERS_CACHE'] = '/data/.cache/transformers'
os.environ['HF_HOME'] = '/data/.cache/huggingface'
os.environ['MPLCONFIGDIR'] = '/data/.cache/matplotlib'
import torch
import torch.nn as nn
import yaml
from torchvision import models, transforms
from PIL import Image
import gradio as gr
from transformers import ConvNextV2ForImageClassification
from typing import Dict, Tuple
# UI display name -> checkpoint path for every selectable model.
MODEL_CHECKPOINTS = {
    "ConvNeXt tiny (Best)": "checkpoints/convnext_v2_tiny_best.pth",
    "EfficientNet-B0": "checkpoints/effnet_b0_best.pth",
    "EfficientNet-B3": "checkpoints/effnet_b3_best.pth",
    "Vision Transformer B-16": "checkpoints/vit_b_16_best.pth"
}
# Pre-selected entry in the model dropdown.
DEFAULT_MODEL_NAME = "ConvNeXt tiny (Best)"
# Populated at startup: display name -> (loaded model, index -> class-name map).
MODELS: Dict[str, Tuple[nn.Module, Dict[int, str]]] = {}
# All models and input tensors are moved to this device.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class HFConvNeXtWrapper(nn.Module):
    """Adapter exposing a HuggingFace ConvNeXt V2 classifier as a plain
    ``nn.Module`` that returns a logits tensor instead of an output object."""

    def __init__(self, model_name, num_labels):
        super().__init__()
        # ignore_mismatched_sizes lets us replace the pretrained head with
        # one sized for our own label count.
        self.model = ConvNextV2ForImageClassification.from_pretrained(
            model_name,
            num_labels=num_labels,
            ignore_mismatched_sizes=True,
        )

    def forward(self, x):
        # HF models return a structured output; downstream code expects raw logits.
        return self.model(x).logits
def get_model(model_name: str, num_classes: int) -> nn.Module:
    """Construct an untrained model architecture by name.

    Args:
        model_name: One of ``"efficientnet_b0"``, ``"efficientnet_b3"``,
            ``"vit_b_16"``, or a HuggingFace ConvNeXt V2 identifier
            (any name containing ``"convnextv2"``).
        num_classes: Output size of the replaced classification head.

    Returns:
        The model with its classification head resized to ``num_classes``.

    Raises:
        ValueError: If ``model_name`` matches none of the supported variants.
    """
    # Both EfficientNet variants share the same classifier layout, so resolve
    # the torchvision constructor dynamically instead of duplicating branches.
    if model_name in ("efficientnet_b0", "efficientnet_b3"):
        model = getattr(models, model_name)(weights=None)
        num_ftrs = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_ftrs, num_classes)
    elif model_name == "vit_b_16":
        model = models.vit_b_16(weights=None)
        num_ftrs = model.heads.head.in_features
        model.heads.head = nn.Linear(num_ftrs, num_classes)
    elif "convnextv2" in model_name:
        # HF checkpoints carry their own head; the wrapper resizes it.
        model = HFConvNeXtWrapper(model_name, num_labels=num_classes)
    else:
        raise ValueError(f"Model '{model_name}' not supported.")
    return model
def load_checkpoint(checkpoint_path: str, device: torch.device) -> Tuple[nn.Module, Dict[int, str]]:
    """Restore a trained model and its label mapping from a checkpoint file.

    The checkpoint is expected to contain ``model_name``, ``class_to_idx``,
    and ``state_dict`` entries.

    Args:
        checkpoint_path: Path to the ``.pth`` checkpoint.
        device: Device the model is moved to after loading.

    Returns:
        A tuple of (model in eval mode, index -> class-name mapping).

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found at: {checkpoint_path}")
    ckpt = torch.load(checkpoint_path, map_location=device)
    # num_classes=1: a single logit, interpreted via sigmoid as a binary score.
    net = get_model(ckpt['model_name'], num_classes=1)
    net.load_state_dict(ckpt['state_dict'])
    net.to(device)
    net.eval()
    # Invert class_to_idx so prediction indices map back to class names.
    idx_to_class = {idx: name for name, idx in ckpt['class_to_idx'].items()}
    return net, idx_to_class
# --- Module-level startup: eagerly load every available checkpoint once ---
print("--- Loading all models into memory ---")
for display_name, ckpt_path in MODEL_CHECKPOINTS.items():
    if os.path.exists(ckpt_path):
        model, idx_to_class = load_checkpoint(ckpt_path, DEVICE)
        MODELS[display_name] = (model, idx_to_class)
        print(f"Loaded '{display_name}' on {DEVICE}.")
    else:
        # Missing checkpoints are skipped so the app can still serve the rest.
        print(f"WARNING: Checkpoint for '{display_name}' not found. Skipping.")
if not MODELS:
    raise RuntimeError("No models were loaded. Please check your checkpoints directory.")
# Image size comes from the training config so inference matches training.
with open('staging_config.yaml', 'r') as f:
    config = yaml.safe_load(f)
IMG_SIZE = config['data_params']['image_size']
# Resize + ImageNet mean/std normalization.
# NOTE(review): assumes every checkpointed model was trained with ImageNet
# statistics and a square resize — confirm against the training pipeline.
inference_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
def predict(pil_image, model_name: str):
    """Classify an image with the selected model.

    Args:
        pil_image: Input image from the Gradio widget, or ``None``.
        model_name: Display-name key into the ``MODELS`` registry.

    Returns:
        Mapping of class name -> probability (two entries), or ``None``
        when no image was supplied.
    """
    if pil_image is None:
        return None
    net, idx_to_class = MODELS[model_name]
    rgb = pil_image.convert("RGB")
    batch = inference_transform(rgb).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logit = net(batch)
    # Single-logit binary head: sigmoid gives P(class 1).
    p_positive = torch.sigmoid(logit).item()
    return {
        idx_to_class.get(0, "Class 0"): 1 - p_positive,
        idx_to_class.get(1, "Class 1"): p_positive,
    }
# --- Gradio UI: image + model selector in, class probabilities out ---
# The startup loop above skips models whose checkpoint is missing, so the
# preferred default may not be in MODELS; fall back to the first loaded model
# (MODELS is guaranteed non-empty by the RuntimeError check at startup).
_default_choice = DEFAULT_MODEL_NAME if DEFAULT_MODEL_NAME in MODELS else next(iter(MODELS))
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Dropdown(
            choices=list(MODELS.keys()),  # only models that actually loaded
            value=_default_choice,
            label="Select Model"
        )
    ],
    outputs=gr.Label(num_top_classes=2, label="Predictions"),
    title="Multi-Model Image Classifier",
    description="Upload an image and select a model to see its classification.",
)
iface.launch()