"""Gradio inference app.

Loads a trained binary image classifier from a checkpoint (`model.pt`),
reads preprocessing settings from a YAML config, and serves single-image
predictions through a Gradio `Interface`.
"""

import os

import gradio as gr
import torch
import torch.nn as nn
import yaml
from PIL import Image
from torchvision import models, transforms

CONFIG_PATH = 'staging_config.yaml'
CHECKPOINT_FILENAME = 'model.pt'


def get_model(model_name, num_classes):
    """Factory function to create a model shell for loading weights.

    Args:
        model_name: One of "efficientnet_b0", "resnet50", "vit_b_16".
        num_classes: Output width of the replaced classification head.

    Returns:
        A torchvision model (random weights, head resized) ready for
        `load_state_dict`.

    Raises:
        ValueError: If `model_name` is not one of the supported names.
    """
    if model_name == "efficientnet_b0":
        model = models.efficientnet_b0(weights=None)
        num_ftrs = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_ftrs, num_classes)
    elif model_name == "resnet50":
        model = models.resnet50(weights=None)
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, num_classes)
    elif model_name == "vit_b_16":
        model = models.vit_b_16(weights=None)
        num_ftrs = model.heads.head.in_features
        model.heads.head = nn.Linear(num_ftrs, num_classes)
    else:
        raise ValueError(f"Model '{model_name}' not supported.")
    return model


def load_checkpoint(checkpoint_path, device):
    """Loads a checkpoint and returns the eval-mode model and class mapping.

    The checkpoint is expected to contain the keys 'model_name',
    'class_to_idx', and 'state_dict'.

    Args:
        checkpoint_path: Path to the serialized checkpoint file.
        device: Torch device to map the weights onto.

    Returns:
        Tuple of (model in eval mode on `device`, dict mapping index -> class name).

    Raises:
        FileNotFoundError: If `checkpoint_path` does not exist.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found at: {checkpoint_path}")
    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # checkpoints from a trusted source. (weights_only=True would reject the
    # non-tensor metadata this checkpoint carries.)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model_name = checkpoint['model_name']
    class_to_idx = checkpoint['class_to_idx']
    # Single-logit head: the binary decision is made via sigmoid in predict(),
    # so the head width is deliberately 1, not len(class_to_idx).
    model = get_model(model_name, num_classes=1)
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)
    model.eval()
    # Invert class_to_idx so prediction indices map back to class names.
    idx_to_class = {v: k for k, v in class_to_idx.items()}
    return model, idx_to_class


# --- Module-level setup: config, model, and preprocessing pipeline ---
try:
    with open(CONFIG_PATH, 'r') as f:
        config = yaml.safe_load(f)
except FileNotFoundError as err:
    # Chain the original exception so the traceback keeps the root cause.
    raise RuntimeError(
        f"ERROR: Config file not found at '{CONFIG_PATH}'. "
        "Make sure it's uploaded to the Space."
    ) from err

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL, IDX_TO_CLASS = load_checkpoint(CHECKPOINT_FILENAME, DEVICE)
print(f"Model loaded successfully on {DEVICE}.")
print(f"Class mapping: {IDX_TO_CLASS}")

IMG_SIZE = config['data_params']['image_size']
# Standard ImageNet normalization; must match the training-time transform.
inference_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


def predict(pil_image):
    """
    Runs prediction on a single PIL image and returns a dictionary of
    class probabilities. Gradio's `Label` component expects a dictionary
    format for its output.

    Args:
        pil_image: Input image (any mode; converted to RGB), or None.

    Returns:
        Dict mapping class name -> probability, or None if no image given.
    """
    if pil_image is None:
        return None

    pil_image = pil_image.convert("RGB")
    image_tensor = inference_transform(pil_image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        output = MODEL(image_tensor)
        # Single-logit binary head: sigmoid gives P(class 1).
        prob = torch.sigmoid(output).item()

    class_0_name = IDX_TO_CLASS.get(0, "Class 0")
    class_1_name = IDX_TO_CLASS.get(1, "Class 1")
    confidences = {
        class_0_name: 1 - prob,
        class_1_name: prob
    }
    return confidences


title = "Image Classifier API"
description = """
Upload an image and the model will predict its class.
This model was trained to distinguish between two classes.
The API returns the probabilities for each class.
"""

iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Label(num_top_classes=2, label="Predictions"),
    title=title,
    description=description,
)

iface.launch()