"""Gradio demo app for a binary messy-vs-clean room image classifier.

Loads every available trained checkpoint (EfficientNet-B0/B3, ViT-B/16,
ConvNeXt-V2) into memory at startup and serves a web UI where the user
uploads an image and picks which model to run.
"""

import os

# Cache locations must be set BEFORE importing transformers/matplotlib,
# because those libraries read the environment at import time.
os.environ['TRANSFORMERS_CACHE'] = '/data/.cache/transformers'
os.environ['HF_HOME'] = '/data/.cache/huggingface'
os.environ['MPLCONFIGDIR'] = '/data/.cache/matplotlib'

from typing import Dict, Tuple

import gradio as gr
import torch
import torch.nn as nn
import yaml
from PIL import Image
from torchvision import models, transforms
from transformers import ConvNextV2ForImageClassification

# Display name shown in the UI -> checkpoint path on disk.
MODEL_CHECKPOINTS = {
    "convnext_tiny_best": "checkpoints/convnext_v2_tiny_best.pth",
    "efficientnet_b0": "checkpoints/effnet_b0_best.pth",
    "efficientnet_b3": "checkpoints/effnet_b3_best.pth",
    "vit_b_16": "checkpoints/vit_b_16_best.pth"
}
DEFAULT_MODEL_NAME = "vit_b_16"

# Populated at startup with every checkpoint that could actually be loaded.
# NOTE: values are bare modules; the original Dict[str, Tuple[...]] annotation
# did not match what the loading loop stores.
MODELS: Dict[str, nn.Module] = {}
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class HFConvNeXtWrapper(nn.Module):
    """Adapter that makes a HuggingFace ConvNeXt-V2 classifier behave like the
    torchvision models used elsewhere: forward(x) returns a raw logits tensor
    instead of a HF ModelOutput object."""

    def __init__(self, model_name: str, num_labels: int):
        super().__init__()
        self.model = ConvNextV2ForImageClassification.from_pretrained(
            model_name, num_labels=num_labels, ignore_mismatched_sizes=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Expose only the logits so callers can treat all backbones uniformly.
        return self.model(x).logits


def get_model(model_name: str, num_classes: int) -> nn.Module:
    """Build an (unweighted) architecture matching ``model_name``.

    The classifier head is replaced so its output dimension is
    ``num_classes``; weights are then loaded from a checkpoint by the caller.

    Raises:
        ValueError: if ``model_name`` is not a supported architecture.
    """
    if model_name == "efficientnet_b0":
        model = models.efficientnet_b0(weights=None)
        num_ftrs = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_ftrs, num_classes)
    elif model_name == "efficientnet_b3":
        model = models.efficientnet_b3(weights=None)
        num_ftrs = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_ftrs, num_classes)
    elif model_name == "vit_b_16":
        model = models.vit_b_16(weights=None)
        num_ftrs = model.heads.head.in_features
        model.heads.head = nn.Linear(num_ftrs, num_classes)
    elif "convnextv2" in model_name:
        # Checkpoints store a HF hub id (e.g. "facebook/convnextv2-...");
        # presumably always containing "convnextv2" — TODO confirm.
        model = HFConvNeXtWrapper(model_name, num_labels=num_classes)
    else:
        raise ValueError(f"Model '{model_name}' not supported.")
    return model


def load_checkpoint(checkpoint_path: str,
                    device: torch.device) -> Tuple[nn.Module, Dict[int, str]]:
    """Restore a model from ``checkpoint_path`` and put it in eval mode.

    The checkpoint is expected to be a dict with keys ``model_name`` (the
    architecture identifier understood by :func:`get_model`) and
    ``state_dict``. Returns ``(model, {})``; the empty dict is kept for
    interface compatibility with callers expecting a class-name mapping.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found at: {checkpoint_path}")
    # SECURITY: torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources (consider weights_only=True if all
    # checkpoints contain only tensors/primitives).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model_name_from_ckpt = checkpoint['model_name']
    # num_classes=1: single-logit binary head (sigmoid applied at inference).
    model = get_model(model_name_from_ckpt, num_classes=1)
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)
    model.eval()
    return model, {}


# --- Startup: eagerly load every checkpoint that exists on disk ---
print("--- Loading all models into memory ---")
for display_name, ckpt_path in MODEL_CHECKPOINTS.items():
    if os.path.exists(ckpt_path):
        model, _ = load_checkpoint(ckpt_path, DEVICE)
        MODELS[display_name] = model
        print(f"Loaded '{display_name}' on {DEVICE}.")
    else:
        print(f"WARNING: Checkpoint for '{display_name}' not found. Skipping.")

if not MODELS:
    raise RuntimeError("No models were loaded. Please check your checkpoints directory.")

with open('cm_config.yaml', 'r') as f:
    config = yaml.safe_load(f)
IMG_SIZE = config['data_params']['image_size']

# Same normalization statistics as ImageNet pretraining.
inference_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


def predict(pil_image, model_name: str):
    """Classify ``pil_image`` with the model selected in the UI.

    Returns a ``{label: probability}`` dict for ``gr.Label`` (single-logit
    binary head: sigmoid yields P(messy)), or ``None`` when no image was
    provided.
    """
    if pil_image is None:
        return None
    model = MODELS[model_name]
    pil_image = pil_image.convert("RGB")
    image_tensor = inference_transform(pil_image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        output = model(image_tensor)
    prob = torch.sigmoid(output).item()
    return {"clean": 1 - prob, "messy": prob}


# BUGFIX: the dropdown default must be a model that actually loaded — the
# preferred default checkpoint may have been skipped above.
_default_choice = (DEFAULT_MODEL_NAME if DEFAULT_MODEL_NAME in MODELS
                   else next(iter(MODELS)))

iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Dropdown(
            choices=list(MODELS.keys()),
            value=_default_choice,
            label="Select Model"
        )
    ],
    outputs=gr.Label(num_top_classes=2, label="Predictions"),
    title="Messy vs Clean Image Classifier",
    description="Upload an image and select a model to see its classification for 'messy' vs 'clean'.",
)

if __name__ == "__main__":
    iface.launch()