# Hugging Face Spaces app — Deep SVDD anomaly detection demo.
# NOTE: the deployed Space reported "Runtime error" at startup (twice) on its
# Spaces page; the failure happens before the UI loads.
import os

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
# --- CONFIGURATION ---
# !!! IMPORTANT: REPLACE THESE WITH YOUR MODEL'S DETAILS !!!
HUGGINGFACE_REPO_ID = "YOUR_USERNAME/YOUR_REPO_NAME"  # e.g., "myusername/deepsvdd-anomalydetection"
MODEL_WEIGHTS_FILE = "net_weights.pth"  # Name of your model weights file in the repo
CENTER_FILE = "center_c.pt"             # Name of your saved center vector (tensor) file
RADIUS_FILE = "radius_R.txt"            # Name of your radius file (plain text holding R)

# Define the image size your model expects
INPUT_IMAGE_SIZE = (28, 28)  # Replace with the size used during training (e.g., (128, 128))

# Threshold is loaded dynamically (R_SQUARED = R**2 when RADIUS_FILE loads).
ANOMALY_THRESHOLD = 0.5  # Fallback: used DIRECTLY as the squared-distance threshold, not squared again
| # --- DEEP SVDD MODEL ARCHITECTURE (Placeholder) --- | |
| # !!! IMPORTANT: YOU MUST REPLACE THIS CLASS WITH YOUR ACTUAL DEEP SVDD NETWORK DEFINITION !!! | |
class DeepSVDDNet(nn.Module):
    """
    Placeholder Deep SVDD encoder network.

    Replace with the exact PyTorch class used during training. The layer
    attribute names (conv1, bn1, ...) must match the saved state dict so
    ``load_state_dict`` can map the weights. Maps a 1x28x28 input (e.g.
    MNIST/FashionMNIST) to a ``rep_dim``-dimensional embedding.
    """

    def __init__(self):
        super().__init__()
        self.rep_dim = 32  # dimensionality of the latent space
        self.pool = nn.MaxPool2d(2, 2)
        self.conv1 = nn.Conv2d(1, 8, 5, bias=False, padding=2)
        self.bn1 = nn.BatchNorm2d(8)
        self.conv2 = nn.Conv2d(8, 4, 5, bias=False, padding=2)
        self.bn2 = nn.BatchNorm2d(4)
        self.fc1 = nn.Linear(4 * 7 * 7, self.rep_dim, bias=False)

    def forward(self, x):
        """Encode a batch of images into latent vectors of size rep_dim."""
        out = self.conv1(x)
        out = self.pool(nn.functional.relu(self.bn1(out)))
        out = self.conv2(out)
        out = self.pool(nn.functional.relu(self.bn2(out)))
        out = out.flatten(1)  # (N, 4*7*7)
        return self.fc1(out)
| # --- MODEL LOADING AND INITIALIZATION --- | |
def load_deepsvdd_assets(repo_id, model_file, center_file, radius_file):
    """Load Deep SVDD assets (weights, center c, radius R) from the Hugging Face Hub.

    Args:
        repo_id: Hub repository id, e.g. "user/repo".
        model_file: Filename of the network's state-dict weights in the repo.
        center_file: Filename of the saved center tensor c.
        radius_file: Filename of a plain-text file containing the radius R.

    Returns:
        (model, center_c): the network in eval mode and the center tensor.

    Side effects:
        Sets module-level R_SQUARED to R**2 when the radius file loads;
        otherwise falls back to ANOMALY_THRESHOLD (used directly as a
        squared-distance threshold).
    """
    print("-> Starting model and asset loading...")

    # 1. Download model weights.
    model_path = hf_hub_download(repo_id=repo_id, filename=model_file)
    print(f"   Downloaded model weights to: {model_path}")

    # 2. Download center vector c.
    # weights_only=True: these files come from a remote repo; plain torch.load
    # unpickles arbitrary objects and can execute code. Tensors/state dicts
    # load fine under the restricted loader.
    center_path = hf_hub_download(repo_id=repo_id, filename=center_file)
    center_c = torch.load(center_path, map_location="cpu", weights_only=True)
    print(f"   Downloaded center vector c from: {center_path}")

    # 3. Download radius R (plain text); on any failure fall back to the default.
    global R_SQUARED
    try:
        radius_path = hf_hub_download(repo_id=repo_id, filename=radius_file)
        with open(radius_path, "r") as f:
            R = float(f.read().strip())
        R_SQUARED = R ** 2
        print(f"   Downloaded radius R ({R}). R^2 is: {R_SQUARED}")
    except Exception as e:
        print(f"   Warning: Could not load radius R from {radius_file}. Using default anomaly threshold ({ANOMALY_THRESHOLD}). Error: {e}")
        R_SQUARED = ANOMALY_THRESHOLD

    # 4. Initialize the network and load the trained weights (restricted loader, see above).
    model = DeepSVDDNet()
    model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))
    model.eval()  # inference mode: fixes BatchNorm statistics
    print("-> Model and assets successfully loaded and set to evaluation mode.")
    return model, center_c
# --- GLOBAL INITIALIZATION ---
# Load everything once at import time so every Gradio request reuses the assets.
try:
    model, center_c = load_deepsvdd_assets(HUGGINGFACE_REPO_ID, MODEL_WEIGHTS_FILE, CENTER_FILE, RADIUS_FILE)
except Exception as e:
    print(f"FATAL ERROR during model initialization: {e}")
    # Placeholders keep the app importable; predict_anomaly reports the failure.
    model = None
    center_c = None
    # Use the configured default threshold (was a hard-coded 1.0, inconsistent
    # with the loader's documented ANOMALY_THRESHOLD fallback).
    R_SQUARED = ANOMALY_THRESHOLD
| # --- PREPROCESSING AND PREDICTION FUNCTION --- | |
def preprocess_image(pil_img: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a normalized (1, C, H, W) tensor for inference.

    NOTE: the channel count (1 vs 3) and the Normalize mean/std must mirror
    the training pipeline; this version assumes single-channel (grayscale)
    training data.
    """
    pipeline = transforms.Compose([
        transforms.Resize(INPUT_IMAGE_SIZE),
        transforms.Grayscale(num_output_channels=1),  # change to 3 if using RGB
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),  # use your training statistics
    ])
    tensor = pipeline(pil_img)
    return tensor.unsqueeze(0)  # prepend the batch dimension
def predict_anomaly(input_img: Image.Image):
    """Run Deep SVDD anomaly detection on one uploaded image.

    Returns a (score_text, raw_score, prediction_label) tuple matching the
    three Gradio output components.
    """
    # Guard clause: startup initialization may have failed (model is None).
    if model is None or center_c is None:
        return (f"Error: Model failed to load. Please check your HUGGINGFACE_REPO_ID ({HUGGINGFACE_REPO_ID}) and file paths.",
                0.0,
                "Failed to run")
    try:
        batch = preprocess_image(input_img)
        # Inference only — no autograd bookkeeping needed.
        with torch.no_grad():
            latent = model(batch)
        # Anomaly score = squared Euclidean distance from the learned center c.
        score = torch.sum((latent - center_c) ** 2, dim=1).item()
        # Outside the hypersphere (score > R^2) means anomaly.
        label = "ANOMALY (Outlier)" if score > R_SQUARED else "NORMAL (Inlier)"
        return f"Anomaly Score (Distance²): {score:.4f}", score, label
    except Exception as e:
        error_msg = f"An error occurred during prediction: {e}"
        print(error_msg)
        return (error_msg, 0.0, "Failed to run")
# --- GRADIO INTERFACE ---
title = "Deep SVDD Anomaly Detection on Hugging Face"
# Description uses LaTeX syntax ($...$) which Gradio renders in Markdown.
description = (
    "Upload an image to check if it is an anomaly based on the Deep SVDD model "
    "loaded from your Hugging Face repository (`" + HUGGINGFACE_REPO_ID + "`)."
    "<br>The model calculates the squared distance of the image's embedding from the "
    "learned center $\\mathbf{c}$ and compares it to the radius threshold $R^2$."
    f"<br>Current Anomaly Threshold ($R^2$): **{R_SQUARED:.4f}**"
)

gr.Interface(
    fn=predict_anomaly,
    inputs=gr.Image(type="pil", label=f"Input Image (will be resized to {INPUT_IMAGE_SIZE})"),
    outputs=[
        gr.Textbox(label="Anomaly Score Details", value="Upload an image to start."),
        gr.Number(label="Raw Score (Distance Squared)", precision=4),
        # render=False removed: Interface must render all of its output
        # components itself, and fn returns three values mapped onto these three.
        gr.Textbox(label="Prediction", value="Pending"),
    ],
    title=title,
    description=description,
    # `allow_flagging` was removed in Gradio 5 (renamed `flagging_mode`); the
    # old kwarg raises TypeError at startup — the likely cause of the Space's
    # "Runtime error".
    flagging_mode="never",
    # Example images to test the model (replace with actual examples from your
    # normal/anomaly classes).
    examples=[
        ["https://placehold.co/128x128/99D9EA/white?text=Normal+Dog"],
        ["https://placehold.co/128x128/F0B27A/white?text=Anomaly+Cat"],
    ],
).launch()