clean_vs_messy / app.py
Nightfury16's picture
renamed cm_config.yaml
f258d44
import os
# Redirect library cache/config directories under /data. These assignments are
# intentionally placed before the third-party imports below, because those
# libraries may read these variables when they are first imported.
# NOTE(review): /data is presumably a persistent-storage mount on the host --
# confirm it exists and is writable in the deployment environment.
# NOTE(review): newer transformers releases deprecate TRANSFORMERS_CACHE in
# favour of HF_HOME; keeping both should be harmless.
# MPLCONFIGDIR is matplotlib's config/cache directory; presumably set because a
# dependency imports matplotlib -- confirm.
os.environ['TRANSFORMERS_CACHE'] = '/data/.cache/transformers'
os.environ['HF_HOME'] = '/data/.cache/huggingface'
os.environ['MPLCONFIGDIR'] = '/data/.cache/matplotlib'
import torch
import torch.nn as nn
import yaml
from torchvision import models, transforms
from PIL import Image
import gradio as gr
from transformers import ConvNextV2ForImageClassification
from typing import Dict, Tuple
# Dropdown display name -> checkpoint file (paths are relative to the working
# directory). The architecture itself is read from inside each checkpoint.
MODEL_CHECKPOINTS = {
    "convnext_tiny_best": "checkpoints/convnext_v2_tiny_best.pth",
    "efficientnet_b0": "checkpoints/effnet_b0_best.pth",
    "efficientnet_b3": "checkpoints/effnet_b3_best.pth",
    "vit_b_16": "checkpoints/vit_b_16_best.pth"
}
# Preferred initial selection in the UI; only valid if its checkpoint loads.
DEFAULT_MODEL_NAME = "vit_b_16"
# Display name -> loaded, eval-mode model. The loading loop stores the bare
# module (not a (model, class_map) tuple), so the annotation reflects that.
MODELS: Dict[str, nn.Module] = {}
# Run on the GPU when one is visible to torch, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class HFConvNeXtWrapper(nn.Module):
    """Expose a Hugging Face ConvNeXt V2 classifier as a plain logits module.

    The HF model returns an output object; this wrapper unwraps ``.logits`` so
    it can be used interchangeably with the torchvision models.

    Args:
        model_name: Identifier handed to ``from_pretrained``.
        num_labels: Size of the classification head.
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        # ignore_mismatched_sizes lets from_pretrained build a head whose size
        # differs from the pretrained one instead of failing.
        hf_classifier = ConvNextV2ForImageClassification.from_pretrained(
            model_name,
            num_labels=num_labels,
            ignore_mismatched_sizes=True,
        )
        # The attribute name "model" prefixes every state_dict key, so saved
        # checkpoints depend on it staying exactly this.
        self.model = hf_classifier

    def forward(self, x):
        """Return the raw classification logits for input batch ``x``."""
        outputs = self.model(x)
        return outputs.logits
def get_model(model_name: str, num_classes: int) -> nn.Module:
    """Build an (untrained) architecture by name with a ``num_classes`` head.

    torchvision models are created with ``weights=None`` (random init) and get
    their final linear layer swapped for one with ``num_classes`` outputs.
    Any name containing ``"convnextv2"`` is passed to ``HFConvNeXtWrapper``.

    Raises:
        ValueError: if ``model_name`` matches none of the supported models.
    """
    def _efficientnet(factory):
        # EfficientNet's classifier is Sequential(Dropout, Linear); replace the Linear.
        net = factory(weights=None)
        in_features = net.classifier[1].in_features
        net.classifier[1] = nn.Linear(in_features, num_classes)
        return net

    def _vit_b_16():
        net = models.vit_b_16(weights=None)
        in_features = net.heads.head.in_features
        net.heads.head = nn.Linear(in_features, num_classes)
        return net

    builders = {
        "efficientnet_b0": lambda: _efficientnet(models.efficientnet_b0),
        "efficientnet_b3": lambda: _efficientnet(models.efficientnet_b3),
        "vit_b_16": _vit_b_16,
    }

    build = builders.get(model_name)
    if build is not None:
        return build()
    if "convnextv2" in model_name:
        return HFConvNeXtWrapper(model_name, num_labels=num_classes)
    raise ValueError(f"Model '{model_name}' not supported.")
def load_checkpoint(checkpoint_path: str, device: torch.device) -> Tuple[nn.Module, Dict[int, str]]:
    """Restore a trained classifier from ``checkpoint_path`` onto ``device``.

    The checkpoint is expected to be a mapping with a ``'model_name'`` entry
    (passed to ``get_model``) and a ``'state_dict'`` entry. The model is always
    rebuilt with a single output unit (``num_classes=1``).

    Returns:
        ``(model, {})``: the model in eval mode on ``device`` and an always-empty
        dict (presumably a placeholder for an index->label map -- confirm).

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found at: {checkpoint_path}")

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    payload = torch.load(checkpoint_path, map_location=device)
    architecture = payload['model_name']

    net = get_model(architecture, num_classes=1)
    net.load_state_dict(payload['state_dict'])
    net.to(device).eval()
    return net, {}
# Eagerly load every checkpoint that is present; missing files are skipped
# with a warning, but at least one model must load for the app to start.
print("--- Loading all models into memory ---")
for display_name, ckpt_path in MODEL_CHECKPOINTS.items():
    if not os.path.exists(ckpt_path):
        print(f"WARNING: Checkpoint for '{display_name}' not found. Skipping.")
        continue
    model, _ = load_checkpoint(ckpt_path, DEVICE)
    MODELS[display_name] = model
    print(f"Loaded '{display_name}' on {DEVICE}.")

if not MODELS:
    raise RuntimeError("No models were loaded. Please check your checkpoints directory.")
# The input resolution is read from the shared YAML config (path is relative
# to the working directory); presumably it must match training-time settings.
with open('cm_config.yaml', 'r') as cfg_file:
    config = yaml.safe_load(cfg_file)

IMG_SIZE = config['data_params']['image_size']

# Resize to a square IMG_SIZE x IMG_SIZE, convert to a tensor, then normalise
# with the widely used ImageNet per-channel mean/std.
inference_transform = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def predict(pil_image, model_name: str):
    """Score one image with the selected model for the Gradio Label output.

    Args:
        pil_image: PIL image from the upload widget, or ``None`` if empty.
        model_name: Key into ``MODELS`` chosen in the dropdown.

    Returns:
        ``None`` when no image is given; otherwise ``{"clean": 1 - p,
        "messy": p}`` where ``p`` is the sigmoid of the model's single logit.
    """
    if pil_image is None:
        return None

    classifier = MODELS[model_name]
    rgb_image = pil_image.convert("RGB")
    # Add a batch dimension of 1 before moving to the inference device.
    batch = inference_transform(rgb_image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logit = classifier(batch)
        messy_prob = torch.sigmoid(logit).item()

    return {"clean": 1 - messy_prob, "messy": messy_prob}
# The preferred default model is skipped above when its checkpoint file is
# missing. The dropdown's initial value must be one of its choices, otherwise
# the first prediction hits MODELS[...] with an unknown key. In that case fall
# back to the first model that actually loaded.
_initial_model = (
    DEFAULT_MODEL_NAME
    if DEFAULT_MODEL_NAME in MODELS
    else next(iter(MODELS), DEFAULT_MODEL_NAME)
)

# Single-page UI: image upload + model selector -> two-class probability label.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Dropdown(
            choices=list(MODELS.keys()),
            value=_initial_model,
            label="Select Model",
        ),
    ],
    outputs=gr.Label(num_top_classes=2, label="Predictions"),
    title="Messy vs Clean Image Classifier",
    description="Upload an image and select a model to see its classification for 'messy' vs 'clean'.",
)
iface.launch()