| import torch | |
| import torch.nn as nn | |
| import yaml | |
| from torchvision import models, transforms | |
| from PIL import Image | |
| import gradio as gr | |
| import os | |
# Path to the YAML run configuration read at startup (see load below).
CONFIG_PATH = 'staging_config.yaml'
# Trained checkpoint; load_checkpoint expects it to contain the keys
# 'model_name', 'class_to_idx', and 'state_dict'.
CHECKPOINT_FILENAME = 'model.pt'
def get_model(model_name, num_classes):
    """Instantiate an untrained backbone with a fresh classification head.

    The architecture-specific final layer is replaced by a new
    ``nn.Linear`` sized to ``num_classes`` so that fine-tuned checkpoint
    weights can be loaded on top.

    Args:
        model_name: One of ``"efficientnet_b0"``, ``"resnet50"``,
            or ``"vit_b_16"``.
        num_classes: Output dimension of the replacement head.

    Returns:
        A ``torch.nn.Module`` with randomly initialized weights.

    Raises:
        ValueError: If ``model_name`` is not a supported architecture.
    """
    if model_name == "efficientnet_b0":
        net = models.efficientnet_b0(weights=None)
        net.classifier[1] = nn.Linear(net.classifier[1].in_features, num_classes)
    elif model_name == "resnet50":
        net = models.resnet50(weights=None)
        net.fc = nn.Linear(net.fc.in_features, num_classes)
    elif model_name == "vit_b_16":
        net = models.vit_b_16(weights=None)
        net.heads.head = nn.Linear(net.heads.head.in_features, num_classes)
    else:
        raise ValueError(f"Model '{model_name}' not supported.")
    return net
def load_checkpoint(checkpoint_path, device):
    """Restore the trained classifier and its class-name mapping.

    Args:
        checkpoint_path: Path to a ``torch.save``-produced file holding
            ``model_name``, ``class_to_idx``, and ``state_dict``.
        device: Target ``torch.device``; tensors are mapped onto it.

    Returns:
        Tuple ``(model, idx_to_class)`` where the model is in eval mode
        on ``device`` and ``idx_to_class`` maps output index -> class name.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found at: {checkpoint_path}")
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints you trust.
    ckpt = torch.load(checkpoint_path, map_location=device)
    # Single-logit head: binary classification decoded via sigmoid at
    # inference time, hence num_classes=1 rather than len(class_to_idx).
    net = get_model(ckpt['model_name'], num_classes=1)
    net.load_state_dict(ckpt['state_dict'])
    net.to(device)
    net.eval()
    # Invert the training-time class->index mapping for decoding outputs.
    idx_to_class = {index: label for label, index in ckpt['class_to_idx'].items()}
    return net, idx_to_class
# --- One-time startup (runs at import): config, weights, preprocessing ---
try:
    with open(CONFIG_PATH, 'r') as f:
        config = yaml.safe_load(f)
except FileNotFoundError:
    raise RuntimeError(f"ERROR: Config file not found at '{CONFIG_PATH}'. Make sure it's uploaded to the Space.")
# Prefer GPU when available; checkpoint tensors are mapped onto this device.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL, IDX_TO_CLASS = load_checkpoint(CHECKPOINT_FILENAME, DEVICE)
print(f"Model loaded successfully on {DEVICE}.")
print(f"Class mapping: {IDX_TO_CLASS}")
# Square input side length (pixels) taken from the training config so
# inference preprocessing matches training.
IMG_SIZE = config['data_params']['image_size']
# Resize -> tensor -> ImageNet mean/std normalization (standard for the
# torchvision backbones built by get_model).
inference_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
def predict(pil_image):
    """Classify one image and return ``{class_name: probability}``.

    Gradio's ``Label`` component consumes the returned dictionary
    directly. Returns ``None`` when no image was provided.
    """
    if pil_image is None:
        return None
    # Force 3-channel RGB, preprocess, add batch dim, move to the model's device.
    batch = inference_transform(pil_image.convert("RGB")).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logit = MODEL(batch)
    # Single sigmoid output: p is the probability of class index 1,
    # and class index 0 gets the complement.
    p = torch.sigmoid(logit).item()
    return {
        IDX_TO_CLASS.get(0, "Class 0"): 1 - p,
        IDX_TO_CLASS.get(1, "Class 1"): p,
    }
# --- Gradio UI definition and launch ---
title = "Image Classifier API"
description = """
Upload an image and the model will predict its class.
This model was trained to distinguish between two classes.
The API returns the probabilities for each class.
"""
iface = gr.Interface(
    fn=predict,
    # PIL input matches predict()'s expected argument type.
    inputs=gr.Image(type="pil", label="Upload Image"),
    # Label renders the {class_name: probability} dict from predict().
    outputs=gr.Label(num_top_classes=2, label="Predictions"),
    title=title,
    description=description,
)
# Launched at module level (no __main__ guard), as is conventional for
# Hugging Face Spaces app scripts.
iface.launch()