# Hugging Face Spaces Gradio app: messy-vs-clean image classifier.
# Redirect cache directories onto the writable /data volume (Spaces containers
# have a read-only home dir).  These env vars must be set BEFORE importing the
# libraries that read them (transformers, matplotlib).
import os

os.environ['TRANSFORMERS_CACHE'] = '/data/.cache/transformers'  # NOTE(review): deprecated in recent transformers releases; HF_HOME below should cover it — confirm
os.environ['HF_HOME'] = '/data/.cache/huggingface'
os.environ['MPLCONFIGDIR'] = '/data/.cache/matplotlib'

import torch
import torch.nn as nn
import yaml
from torchvision import models, transforms
from PIL import Image
import gradio as gr
from transformers import ConvNextV2ForImageClassification
from typing import Dict, Tuple
# Display name -> checkpoint path for every model the app can serve.
MODEL_CHECKPOINTS = {
    "convnext_tiny_best": "checkpoints/convnext_v2_tiny_best.pth",
    "efficientnet_b0": "checkpoints/effnet_b0_best.pth",
    "efficientnet_b3": "checkpoints/effnet_b3_best.pth",
    "vit_b_16": "checkpoints/vit_b_16_best.pth"
}

# Model pre-selected in the Gradio dropdown (if its checkpoint loaded).
DEFAULT_MODEL_NAME = "vit_b_16"

# Populated at startup; maps display name -> loaded, eval-mode model.
# Annotation fixed: the loading loop stores the bare nn.Module, not a
# (model, class-map) tuple as the old Dict[str, Tuple[...]] hint claimed.
MODELS: Dict[str, nn.Module] = {}

# Prefer GPU when available; all models and inputs are moved to this device.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class HFConvNeXtWrapper(nn.Module):
    """Adapter exposing a HuggingFace ConvNeXt-V2 classifier behind the plain
    ``forward(x) -> logits`` tensor interface used by the torchvision models.
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        # ignore_mismatched_sizes lets us swap in a fresh num_labels-wide head
        # in place of the pretrained classification head.
        self.model = ConvNextV2ForImageClassification.from_pretrained(
            model_name,
            num_labels=num_labels,
            ignore_mismatched_sizes=True,
        )

    def forward(self, x):
        # HF models return an output object; callers only need the logits.
        return self.model(x).logits
def get_model(model_name: str, num_classes: int) -> nn.Module:
    """Build an untrained architecture for ``model_name``.

    The final classification layer is replaced with a fresh
    ``num_classes``-wide head in every case.

    Args:
        model_name: one of the torchvision names handled below, or any
            HuggingFace id containing "convnextv2".
        num_classes: output width of the classification head.

    Raises:
        ValueError: if ``model_name`` matches none of the supported models.
    """
    # The two EfficientNet variants differ only in the backbone factory, so
    # the previously duplicated branches are merged.  The dead `model = None`
    # initializer is also gone: every path either binds `model` or raises.
    if model_name in ("efficientnet_b0", "efficientnet_b3"):
        factory = (models.efficientnet_b0 if model_name == "efficientnet_b0"
                   else models.efficientnet_b3)
        model = factory(weights=None)
        in_features = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(in_features, num_classes)
    elif model_name == "vit_b_16":
        model = models.vit_b_16(weights=None)
        in_features = model.heads.head.in_features
        model.heads.head = nn.Linear(in_features, num_classes)
    elif "convnextv2" in model_name:
        model = HFConvNeXtWrapper(model_name, num_labels=num_classes)
    else:
        raise ValueError(f"Model '{model_name}' not supported.")
    return model
def load_checkpoint(checkpoint_path: str, device: torch.device) -> Tuple[nn.Module, Dict[int, str]]:
    """Restore a trained model from ``checkpoint_path`` onto ``device``.

    The checkpoint dict records its own architecture under 'model_name';
    the matching network is rebuilt and its weights loaded.  Returns the
    eval-mode model together with an (always empty) index->label mapping,
    kept only for interface compatibility.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found at: {checkpoint_path}")
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location=device)
    arch_name = ckpt['model_name']
    net = get_model(arch_name, num_classes=1)  # single-logit binary head
    net.load_state_dict(ckpt['state_dict'])
    net.to(device)
    net.eval()
    return net, {}
# Eagerly load every model whose checkpoint exists; missing checkpoints are
# skipped with a warning so the app can still serve the remaining models.
print("--- Loading all models into memory ---")
for display_name, ckpt_path in MODEL_CHECKPOINTS.items():
    if not os.path.exists(ckpt_path):
        print(f"WARNING: Checkpoint for '{display_name}' not found. Skipping.")
        continue
    loaded_model, _ = load_checkpoint(ckpt_path, DEVICE)
    MODELS[display_name] = loaded_model
    print(f"Loaded '{display_name}' on {DEVICE}.")

# With zero models the UI would be unusable — fail fast at startup instead.
if not MODELS:
    raise RuntimeError("No models were loaded. Please check your checkpoints directory.")
# Read the image size from the shared config so inference preprocessing
# matches whatever size the pipeline was configured with.
with open('cm_config.yaml', 'r') as cfg_file:
    config = yaml.safe_load(cfg_file)
IMG_SIZE = config['data_params']['image_size']

# Resize plus the standard ImageNet mean/std normalization used by the
# torchvision pretrained-model family.
inference_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
def predict(pil_image, model_name: str):
    """Classify an uploaded image with the selected model.

    Returns a label->probability dict for Gradio's Label component, or
    ``None`` when no image was supplied.
    """
    if pil_image is None:
        return None
    chosen = MODELS[model_name]
    rgb = pil_image.convert("RGB")
    batch = inference_transform(rgb).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = chosen(batch)
    # Single-logit binary head: sigmoid of the logit is the "messy" score.
    messy_prob = torch.sigmoid(logits).item()
    return {"clean": 1 - messy_prob, "messy": messy_prob}
# The dropdown's choices come from MODELS, which only contains models whose
# checkpoints actually loaded.  If the preferred default's checkpoint was
# missing, fall back to the first loaded model so the dropdown never starts
# on an invalid value.  (MODELS is guaranteed non-empty by the startup check.)
_initial_model = DEFAULT_MODEL_NAME if DEFAULT_MODEL_NAME in MODELS else next(iter(MODELS))

iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Dropdown(
            choices=list(MODELS.keys()),
            value=_initial_model,
            label="Select Model"
        )
    ],
    outputs=gr.Label(num_top_classes=2, label="Predictions"),
    title="Messy vs Clean Image Classifier",
    description="Upload an image and select a model to see its classification for 'messy' vs 'clean'.",
)
iface.launch()