import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
import os
from huggingface_hub import hf_hub_download

# --- CONFIGURATION ---
# !!! IMPORTANT: REPLACE THESE WITH YOUR MODEL'S DETAILS !!!
HUGGINGFACE_REPO_ID = "YOUR_USERNAME/YOUR_REPO_NAME"  # e.g., "myusername/deepsvdd-anomalydetection"
MODEL_WEIGHTS_FILE = "net_weights.pth"  # Name of your model weights file
CENTER_FILE = "center_c.pt"             # Name of your center vector file
RADIUS_FILE = "radius_R.txt"            # Name of your radius file

# Define the image size your model expects
INPUT_IMAGE_SIZE = (28, 28)  # Replace with the size used during training (e.g., (128, 128))

# Default anomaly-score threshold, used only when RADIUS_FILE is unavailable.
ANOMALY_THRESHOLD = 0.5

# Squared radius R^2 used as the decision boundary. Defined here so it always
# exists; load_deepsvdd_assets() overwrites it with the trained R^2 (or the
# fallback above) as a side effect.
R_SQUARED = ANOMALY_THRESHOLD


# --- DEEP SVDD MODEL ARCHITECTURE (Placeholder) ---
class DeepSVDDNet(nn.Module):
    """Placeholder Deep SVDD encoder network.

    !!! IMPORTANT: replace this class with the exact PyTorch class definition
    you used for training, otherwise load_state_dict() will fail (or silently
    produce meaningless embeddings if shapes happen to match).

    The example structure below targets a 28x28 single-channel input
    (MNIST/FashionMNIST-like) and maps it to a 32-dimensional embedding.
    """

    def __init__(self):
        super().__init__()
        self.rep_dim = 32  # dimension of the latent space
        self.pool = nn.MaxPool2d(2, 2)
        # Bias terms are omitted (bias=False) throughout — the usual Deep SVDD
        # convention so the trivial constant-output solution is not reachable.
        self.conv1 = nn.Conv2d(1, 8, 5, bias=False, padding=2)
        self.bn1 = nn.BatchNorm2d(8)
        self.conv2 = nn.Conv2d(8, 4, 5, bias=False, padding=2)
        self.bn2 = nn.BatchNorm2d(4)
        # 28x28 input pooled twice (2x2) -> 7x7 spatial grid with 4 channels.
        self.fc1 = nn.Linear(4 * 7 * 7, self.rep_dim, bias=False)

    def forward(self, x):
        """Map a batch of images (N, 1, 28, 28) to embeddings (N, rep_dim)."""
        x = self.pool(nn.functional.relu(self.bn1(self.conv1(x))))
        x = self.pool(nn.functional.relu(self.bn2(self.conv2(x))))
        x = x.view(x.size(0), -1)  # flatten
        return self.fc1(x)


# --- MODEL LOADING AND INITIALIZATION ---
def load_deepsvdd_assets(repo_id, model_file, center_file, radius_file):
    """Load model weights, center 'c', and radius 'R' from the Hugging Face Hub.

    Side effect: updates the module-level R_SQUARED with R**2, falling back to
    ANOMALY_THRESHOLD when the radius file is missing or malformed.

    Returns:
        (model, center_c): the network in eval mode and the center tensor.

    Raises:
        Propagates any error from downloading or loading the weights/center
        (the radius download alone is best-effort).
    """
    global R_SQUARED
    print("-> Starting model and asset loading...")

    # 1. Download model weights.
    model_path = hf_hub_download(repo_id=repo_id, filename=model_file)
    print(f"   Downloaded model weights to: {model_path}")

    # 2. Download center vector c.
    center_path = hf_hub_download(repo_id=repo_id, filename=center_file)
    center_c = torch.load(center_path, map_location=torch.device('cpu'))
    print(f"   Downloaded center vector c from: {center_path}")

    # 3. Download radius R (stored as plain text). Best-effort: keep the app
    # usable with the default threshold if this file can't be read.
    try:
        radius_path = hf_hub_download(repo_id=repo_id, filename=radius_file)
        with open(radius_path, 'r') as f:
            R = float(f.read().strip())
        R_SQUARED = R**2
        print(f"   Downloaded radius R ({R}). R^2 is: {R_SQUARED}")
    except Exception as e:
        print(f"   Warning: Could not load radius R from {radius_file}. Using default anomaly threshold ({ANOMALY_THRESHOLD}). Error: {e}")
        R_SQUARED = ANOMALY_THRESHOLD

    # 4. Initialize and load the model.
    model = DeepSVDDNet()
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    model.eval()  # inference mode: freezes BatchNorm running statistics
    print("-> Model and assets successfully loaded and set to evaluation mode.")
    return model, center_c


# Initialize global inference state. A failure here is reported but does not
# crash the process, so the UI can still come up and show an error message.
try:
    model, center_c = load_deepsvdd_assets(
        HUGGINGFACE_REPO_ID, MODEL_WEIGHTS_FILE, CENTER_FILE, RADIUS_FILE
    )
except Exception as e:
    print(f"FATAL ERROR during model initialization: {e}")
    # Set placeholders to prevent crash
    model = None
    center_c = None
    R_SQUARED = 1.0


# --- PREPROCESSING AND PREDICTION FUNCTION ---
def preprocess_image(pil_img: Image.Image) -> torch.Tensor:
    """Preprocess a PIL image into a (1, C, H, W) tensor for Deep SVDD inference.

    NOTE: Adjust channel handling (1 or 3) and normalization values (mean/std)
    to match how your model was trained. This assumes 1-channel grayscale.
    """
    transform = transforms.Compose([
        transforms.Resize(INPUT_IMAGE_SIZE),
        transforms.Grayscale(num_output_channels=1),  # change to 3 if using RGB
        transforms.ToTensor(),
        # Use the mean and std you trained with.
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ])
    # Apply transform and add a batch dimension.
    return transform(pil_img).unsqueeze(0)


def predict_anomaly(input_img: Image.Image):
    """Perform Deep SVDD anomaly detection inference on the input image.

    Returns a 3-tuple matching the Gradio outputs:
        (score_text, raw_squared_distance, prediction_label)
    """
    if model is None or center_c is None:
        return (f"Error: Model failed to load. Please check your HUGGINGFACE_REPO_ID ({HUGGINGFACE_REPO_ID}) and file paths.",
                0.0, "Failed to run")
    try:
        # 1. Preprocess image.
        x = preprocess_image(input_img)

        # 2. Forward pass to get latent representation z.
        with torch.no_grad():
            z = model(x)

        # 3. Anomaly score: squared Euclidean distance to the center c.
        dist_squared = torch.sum((z - center_c)**2, dim=1).item()

        # 4. Points outside the hypersphere (distance^2 > R^2) are anomalies.
        is_anomaly = dist_squared > R_SQUARED

        # 5. Format output.
        score_text = f"Anomaly Score (Distance²): {dist_squared:.4f}"
        prediction = "ANOMALY (Outlier)" if is_anomaly else "NORMAL (Inlier)"
        return score_text, dist_squared, prediction
    except Exception as e:
        error_msg = f"An error occurred during prediction: {e}"
        print(error_msg)
        return (error_msg, 0.0, "Failed to run")


# --- GRADIO INTERFACE ---
title = "Deep SVDD Anomaly Detection on Hugging Face"
description = (
    "Upload an image to check if it is an anomaly based on the Deep SVDD model "
    "loaded from your Hugging Face repository (`" + HUGGINGFACE_REPO_ID + "`).\n\n"
    "The model calculates the squared distance of the image's embedding from the "
    "learned center $\\mathbf{c}$ and compares it to the radius threshold $R^2$.\n\n"
    f"Current Anomaly Threshold ($R^2$): **{R_SQUARED:.4f}**"
)

# Bound to a module-level `demo` (the name Hugging Face Spaces discovers) and
# launched only when run as a script, so importing this module has no UI side
# effects.
demo = gr.Interface(
    fn=predict_anomaly,
    inputs=gr.Image(type="pil", label=f"Input Image (will be resized to {INPUT_IMAGE_SIZE})"),
    outputs=[
        gr.Textbox(label="Anomaly Score Details", value="Upload an image to start."),
        gr.Number(label="Raw Score (Distance Squared)", precision=4),
        # NOTE(review): render=False keeps the prediction label out of the UI
        # (original behavior); drop it if the ANOMALY/NORMAL verdict should
        # be visible to users.
        gr.Textbox(label="Prediction", value="Pending", render=False),
    ],
    title=title,
    description=description,
    allow_flagging="never",
    # Example images to test the model (replace with actual examples from
    # your normal/anomaly classes).
    examples=[
        ["https://placehold.co/128x128/99D9EA/white?text=Normal+Dog"],
        ["https://placehold.co/128x128/F0B27A/white?text=Anomaly+Cat"],
    ],
)

if __name__ == "__main__":
    demo.launch()