import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
import os
from huggingface_hub import hf_hub_download
# --- CONFIGURATION ---
# !!! IMPORTANT: REPLACE THESE WITH YOUR MODEL'S DETAILS !!!
HUGGINGFACE_REPO_ID = "YOUR_USERNAME/YOUR_REPO_NAME" # e.g., "myusername/deepsvdd-anomalydetection"
MODEL_WEIGHTS_FILE = "net_weights.pth" # Name of your model weights file (PyTorch state dict)
CENTER_FILE = "center_c.pt" # Name of your center vector file (torch-saved tensor c)
RADIUS_FILE = "radius_R.txt" # Name of your radius file (plain text holding the scalar R)
# Define the image size your model expects; preprocess_image resizes to this.
INPUT_IMAGE_SIZE = (28, 28) # Replace with the size used during training (e.g., (128, 128))
# Threshold (R) is loaded dynamically, but you can set a default anomaly score threshold here.
# NOTE(review): this default is used in place of R**2 when the radius file is
# missing — confirm whether it is meant to be a squared distance.
ANOMALY_THRESHOLD = 0.5 # A default score threshold (will be replaced by R if RADIUS_FILE is available)
# --- DEEP SVDD MODEL ARCHITECTURE (Placeholder) ---
# !!! IMPORTANT: YOU MUST REPLACE THIS CLASS WITH YOUR ACTUAL DEEP SVDD NETWORK DEFINITION !!!
class DeepSVDDNet(nn.Module):
    """Placeholder Deep SVDD encoder for 28x28 single-channel images.

    Replace this class with the exact network definition used during
    training. The attribute names (conv1, bn1, conv2, bn2, fc1) must match
    the keys in the saved state dict for load_state_dict to succeed.
    """

    def __init__(self):
        super().__init__()
        # Dimensionality of the latent representation z.
        self.rep_dim = 32
        # Shared 2x2 max pooling halves spatial size: 28 -> 14 -> 7.
        self.pool = nn.MaxPool2d(2, 2)
        # Two conv + batch-norm stages (bias-free, 'same' padding for 5x5 kernels).
        self.conv1 = nn.Conv2d(1, 8, 5, bias=False, padding=2)
        self.bn1 = nn.BatchNorm2d(8)
        self.conv2 = nn.Conv2d(8, 4, 5, bias=False, padding=2)
        self.bn2 = nn.BatchNorm2d(4)
        # Projects the flattened 4x7x7 feature map into the latent space.
        self.fc1 = nn.Linear(4 * 7 * 7, self.rep_dim, bias=False)

    def forward(self, x):
        """Map a (B, 1, 28, 28) batch to its (B, rep_dim) embedding."""
        relu = nn.functional.relu
        out = self.pool(relu(self.bn1(self.conv1(x))))
        out = self.pool(relu(self.bn2(self.conv2(out))))
        flat = out.view(out.size(0), -1)
        return self.fc1(flat)
# --- MODEL LOADING AND INITIALIZATION ---
def load_deepsvdd_assets(repo_id, model_file, center_file, radius_file):
    """Loads model weights, center 'c', and radius 'R' from the Hugging Face Hub.

    Args:
        repo_id: Hub repository id, e.g. "user/repo".
        model_file: Filename of the PyTorch state dict inside the repo.
        center_file: Filename of the torch-saved center tensor c.
        radius_file: Filename of a plain-text file holding the scalar radius R.

    Returns:
        (model, center_c): the network in eval mode and the center tensor.

    Side effects:
        Sets the module-level global R_SQUARED to R**2, or to
        ANOMALY_THRESHOLD if the radius file cannot be loaded.
    """
    print("-> Starting model and asset loading...")
    # 1. Download Model Weights
    model_path = hf_hub_download(repo_id=repo_id, filename=model_file)
    print(f" Downloaded model weights to: {model_path}")
    # 2. Download Center Vector c
    center_path = hf_hub_download(repo_id=repo_id, filename=center_file)
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from repositories you trust (consider weights_only=True
    # if your torch version supports it).
    center_c = torch.load(center_path, map_location=torch.device('cpu'))
    print(f" Downloaded center vector c from: {center_path}")
    # 3. Download Radius R (stored as text/numpy or a simple file)
    global R_SQUARED
    try:
        radius_path = hf_hub_download(repo_id=repo_id, filename=radius_file)
        with open(radius_path, 'r') as f:
            R = float(f.read().strip())
        R_SQUARED = R**2
        print(f" Downloaded radius R ({R}). R^2 is: {R_SQUARED}")
    except Exception as e:
        # Best-effort fallback: keep the app usable without a radius file.
        # NOTE(review): the fallback assigns ANOMALY_THRESHOLD where a
        # *squared* radius is expected — confirm the intended scale.
        print(f" Warning: Could not load radius R from {radius_file}. Using default anomaly threshold ({ANOMALY_THRESHOLD}). Error: {e}")
        R_SQUARED = ANOMALY_THRESHOLD
    # 4. Initialize and Load Model
    model = DeepSVDDNet()
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    model.eval()
    print("-> Model and assets successfully loaded and set to evaluation mode.")
    return model, center_c
# Initialize global variables at import time so the Gradio callback can use them.
try:
    model, center_c = load_deepsvdd_assets(HUGGINGFACE_REPO_ID, MODEL_WEIGHTS_FILE, CENTER_FILE, RADIUS_FILE)
except Exception as e:
    print(f"FATAL ERROR during model initialization: {e}")
    # Set placeholders to prevent crash; predict_anomaly checks for None
    # and returns a user-facing error instead of raising.
    model = None
    center_c = None
    R_SQUARED = 1.0
# --- PREPROCESSING AND PREDICTION FUNCTION ---
def preprocess_image(pil_img: Image.Image, size=None) -> torch.Tensor:
    """Preprocess a PIL image into a (1, 1, H, W) tensor for Deep SVDD inference.

    Args:
        pil_img: Input image in any PIL mode; it is converted to grayscale.
        size: Optional (H, W) target size. Defaults to the module-level
            INPUT_IMAGE_SIZE, preserving the original behavior.

    Returns:
        A float tensor of shape (1, 1, H, W) normalized to roughly [-1, 1].

    NOTE(review): single-channel conversion and mean/std of 0.5 assume the
    model was trained on grayscale inputs normalized this way — confirm
    against your training pipeline (use num_output_channels=3 and 3-tuple
    mean/std for RGB models).
    """
    if size is None:
        size = INPUT_IMAGE_SIZE
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.Grayscale(num_output_channels=1),  # change to 3 if using RGB
        transforms.ToTensor(),
        # Use the mean and std you trained with.
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ])
    # unsqueeze(0) adds the batch dimension the network expects.
    return transform(pil_img).unsqueeze(0)
def predict_anomaly(input_img: Image.Image):
    """Run Deep SVDD inference on one image and report its anomaly verdict.

    Returns a (details_text, raw_score, label) triple matching the three
    Gradio output components; on failure, a human-readable error message
    takes the place of the details text.
    """
    # Guard: module-level asset loading may have failed at startup.
    if model is None or center_c is None:
        return (f"Error: Model failed to load. Please check your HUGGINGFACE_REPO_ID ({HUGGINGFACE_REPO_ID}) and file paths.",
                0.0,
                "Failed to run")
    try:
        # Preprocess and embed the image without tracking gradients.
        batch = preprocess_image(input_img)
        with torch.no_grad():
            embedding = model(batch)
        # Anomaly score = squared Euclidean distance from the learned center c.
        dist_squared = torch.sum((embedding - center_c) ** 2, dim=1).item()
        # Compare against the (squared) radius threshold loaded at startup.
        label = "ANOMALY (Outlier)" if dist_squared > R_SQUARED else "NORMAL (Inlier)"
        details = f"Anomaly Score (Distance²): {dist_squared:.4f}"
        return details, dist_squared, label
    except Exception as e:
        error_msg = f"An error occurred during prediction: {e}"
        print(error_msg)
        return (error_msg, 0.0, "Failed to run")
# --- GRADIO INTERFACE ---
title = "Deep SVDD Anomaly Detection on Hugging Face"

# FIX: the original description contained raw newlines inside string
# literals, which is a SyntaxError in Python — rebuilt with explicit \n\n
# paragraph breaks and an f-string for the repo id.
description = (
    "Upload an image to check if it is an anomaly based on the Deep SVDD model "
    f"loaded from your Hugging Face repository (`{HUGGINGFACE_REPO_ID}`).\n\n"
    "The model calculates the squared distance of the image's embedding from the "
    "learned center $\\mathbf{c}$ and compares it to the radius threshold $R^2$.\n\n"
    f"Current Anomaly Threshold ($R^2$): **{R_SQUARED:.4f}**"
)

# Build and launch the Gradio UI (LaTeX syntax is allowed in the description).
gr.Interface(
    fn=predict_anomaly,
    inputs=gr.Image(type="pil", label=f"Input Image (will be resized to {INPUT_IMAGE_SIZE})"),
    outputs=[
        gr.Textbox(label="Anomaly Score Details", value="Upload an image to start."),
        gr.Number(label="Raw Score (Distance Squared)", precision=4),
        # render=False keeps this component out of the visible layout; the
        # details textbox above carries the user-facing result.
        gr.Textbox(label="Prediction", value="Pending", render=False),
    ],
    title=title,
    description=description,
    allow_flagging="never",
    # Example images to test the model (replace with actual examples from
    # your normal/anomaly classes).
    examples=[
        ["https://placehold.co/128x128/99D9EA/white?text=Normal+Dog"],
        ["https://placehold.co/128x128/F0B27A/white?text=Anomaly+Cat"],
    ],
).launch()