# Captured from a Hugging Face Space (status at capture time: Runtime error).
# File size: 7,171 Bytes
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
import os
from huggingface_hub import hf_hub_download
# --- CONFIGURATION ---
# !!! IMPORTANT: REPLACE THESE WITH YOUR MODEL'S DETAILS !!!
# Hub repository that hosts the three artifact files named below.
HUGGINGFACE_REPO_ID = "YOUR_USERNAME/YOUR_REPO_NAME" # e.g., "myusername/deepsvdd-anomalydetection"
MODEL_WEIGHTS_FILE = "net_weights.pth" # State dict applied to DeepSVDDNet via load_state_dict
CENTER_FILE = "center_c.pt" # Center vector c, read with torch.load
RADIUS_FILE = "radius_R.txt" # Plain-text file whose content is parsed as a single float R
# Size every input image is resized to before inference.
# The placeholder network's final layer (4 * 7 * 7 inputs) assumes 28x28.
INPUT_IMAGE_SIZE = (28, 28) # Replace with the size used during training (e.g., (128, 128))
# Fallback used when RADIUS_FILE cannot be loaded. It is assigned to R_SQUARED
# as-is (not squared), so it is compared directly with the squared distance.
ANOMALY_THRESHOLD = 0.5 # Default squared-distance threshold
# --- DEEP SVDD MODEL ARCHITECTURE (Placeholder) ---
# !!! IMPORTANT: YOU MUST REPLACE THIS CLASS WITH YOUR ACTUAL DEEP SVDD NETWORK DEFINITION !!!
class DeepSVDDNet(nn.Module):
    """Stand-in convolutional encoder for Deep SVDD.

    Two conv -> batch-norm -> ReLU -> 2x2 max-pool stages followed by a
    bias-free linear projection into a ``rep_dim``-dimensional embedding.
    Replace this class with the exact definition used during training so the
    saved state-dict keys match.
    """

    def __init__(self):
        super().__init__()
        # Example sizing for 28x28 single-channel input (MNIST/FashionMNIST):
        # the two poolings shrink 28 -> 14 -> 7, hence 4 * 7 * 7 features.
        self.rep_dim = 32  # dimensionality of the latent space
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(
            in_channels=1, out_channels=8, kernel_size=5, padding=2, bias=False
        )
        self.bn1 = nn.BatchNorm2d(num_features=8)
        self.conv2 = nn.Conv2d(
            in_channels=8, out_channels=4, kernel_size=5, padding=2, bias=False
        )
        self.bn2 = nn.BatchNorm2d(num_features=4)
        self.fc1 = nn.Linear(
            in_features=4 * 7 * 7, out_features=self.rep_dim, bias=False
        )

    def forward(self, x):
        """Return the latent embedding for a batch of images ``x``."""
        relu = nn.functional.relu
        stage_one = self.pool(relu(self.bn1(self.conv1(x))))
        stage_two = self.pool(relu(self.bn2(self.conv2(stage_one))))
        # Collapse everything except the batch axis before the projection.
        flattened = stage_two.view(stage_two.size(0), -1)
        return self.fc1(flattened)
# --- MODEL LOADING AND INITIALIZATION ---
def load_deepsvdd_assets(repo_id, model_file, center_file, radius_file):
    """Load model weights, center ``c`` and radius ``R`` from the Hugging Face Hub.

    Args:
        repo_id: Hub repository that hosts the artifacts.
        model_file: File name of the saved state dict inside the repo.
        center_file: File name of the serialized center vector ``c``.
        radius_file: File name of a text file containing the radius ``R``.

    Returns:
        A ``(model, center_c)`` tuple: a ``DeepSVDDNet`` in evaluation mode
        with the downloaded weights applied, and the object read from
        ``center_file``.

    Side effects:
        Sets the module-level ``R_SQUARED`` to ``R ** 2``. If the radius file
        cannot be downloaded or parsed, ``R_SQUARED`` falls back to
        ``ANOMALY_THRESHOLD`` and a warning is printed.

    Raises:
        Any exception from ``hf_hub_download``, ``torch.load`` or
        ``load_state_dict`` for the weights and center files; only the radius
        is treated as optional.
    """
    global R_SQUARED
    cpu = torch.device('cpu')
    print("-> Starting model and asset loading...")

    # 1. Download Model Weights
    model_path = hf_hub_download(repo_id=repo_id, filename=model_file)
    print(f" Downloaded model weights to: {model_path}")

    # 2. Download Center Vector c
    # SECURITY NOTE: torch.load unpickles the file. Only point repo_id at a
    # repository you trust; consider weights_only=True if your torch supports it.
    center_path = hf_hub_download(repo_id=repo_id, filename=center_file)
    center_c = torch.load(center_path, map_location=cpu)
    print(f" Downloaded center vector c from: {center_path}")

    # 3. Download Radius R (best effort: any failure falls back to the default)
    try:
        radius_path = hf_hub_download(repo_id=repo_id, filename=radius_file)
        # Explicit encoding so parsing does not depend on the host locale.
        with open(radius_path, 'r', encoding='utf-8') as f:
            R = float(f.read().strip())
        R_SQUARED = R**2
        print(f" Downloaded radius R ({R}). R^2 is: {R_SQUARED}")
    except Exception as e:
        print(f" Warning: Could not load radius R from {radius_file}. Using default anomaly threshold ({ANOMALY_THRESHOLD}). Error: {e}")
        R_SQUARED = ANOMALY_THRESHOLD

    # 4. Initialize and Load Model (same unpickling caveat as the center file)
    model = DeepSVDDNet()
    model.load_state_dict(torch.load(model_path, map_location=cpu))
    model.eval()
    print("-> Model and assets successfully loaded and set to evaluation mode.")
    return model, center_c
# Initialize global variables
# Load everything once at import time. On any failure, fall back to inert
# placeholders so the UI can still start and report the problem to the user.
try:
    model, center_c = load_deepsvdd_assets(
        repo_id=HUGGINGFACE_REPO_ID,
        model_file=MODEL_WEIGHTS_FILE,
        center_file=CENTER_FILE,
        radius_file=RADIUS_FILE,
    )
except Exception as init_error:
    print(f"FATAL ERROR during model initialization: {init_error}")
    # Placeholders keep later module-level code (which formats R_SQUARED) working.
    model = center_c = None
    R_SQUARED = 1.0
# --- PREPROCESSING AND PREDICTION FUNCTION ---
def preprocess_image(pil_img: Image.Image) -> torch.Tensor:
    """Turn a PIL image into a single-image batch tensor for the network.

    The image is resized to ``INPUT_IMAGE_SIZE``, converted to one grayscale
    channel, turned into a tensor and normalized with mean 0.5 / std 0.5.
    A leading batch dimension is then added.

    NOTE: channel count and normalization constants must match whatever was
    used when the model was trained; this example assumes 1 channel.
    """
    steps = [
        transforms.Resize(INPUT_IMAGE_SIZE),
        transforms.Grayscale(num_output_channels=1),  # Change to 3 if using RGB
        transforms.ToTensor(),
        # Use the mean and std you trained with
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
    pipeline = transforms.Compose(steps)
    image_tensor = pipeline(pil_img)
    # Prepend the batch axis expected by the model.
    return image_tensor.unsqueeze(0)
def predict_anomaly(input_img: Image.Image):
    """Run Deep SVDD inference on one image and classify it against ``R_SQUARED``.

    Args:
        input_img: PIL image supplied by the UI, or ``None`` when the user
            submitted without providing an image.

    Returns:
        A 3-tuple ``(details, score, label)``:
        ``details`` is a human-readable score line or an error message,
        ``score`` is the squared distance to the center (``0.0`` on failure),
        ``label`` is ``"ANOMALY (Outlier)"``, ``"NORMAL (Inlier)"`` or
        ``"Failed to run"``.

    This function never raises: every failure is reported through the
    returned tuple so the UI always receives three values.
    """
    if model is None or center_c is None:
        return (f"Error: Model failed to load. Please check your HUGGINGFACE_REPO_ID ({HUGGINGFACE_REPO_ID}) and file paths.",
                0.0,
                "Failed to run")

    # No image supplied: answer explicitly instead of letting the transform
    # pipeline fail on None with an opaque message.
    if input_img is None:
        return ("Error: No image provided. Please upload an image first.",
                0.0,
                "Failed to run")

    try:
        # 1. Preprocess image
        x = preprocess_image(input_img)

        # 2. Forward pass to get latent representation z
        with torch.no_grad():
            z = model(x)

        # 3. Calculate anomaly score (squared distance to center c).
        # as_tensor leaves a tensor untouched but also accepts a center that
        # was saved as a list or NumPy array.
        center = torch.as_tensor(center_c, device=z.device)
        dist_squared = torch.sum((z - center) ** 2, dim=1).item()

        # 4. Determine prediction based on R_SQUARED
        is_anomaly = dist_squared > R_SQUARED

        # 5. Format output
        score_text = f"Anomaly Score (Distance²): {dist_squared:.4f}"
        prediction = "ANOMALY (Outlier)" if is_anomaly else "NORMAL (Inlier)"
        return score_text, dist_squared, prediction
    except Exception as e:
        # UI boundary: surface the failure in the output widgets, not as a crash.
        error_msg = f"An error occurred during prediction: {e}"
        print(error_msg)
        return (error_msg, 0.0, "Failed to run")
# --- GRADIO INTERFACE ---
title = "Deep SVDD Anomaly Detection on Hugging Face"
# The description uses LaTeX syntax ($...$) for the math symbols.
description = (
    "Upload an image to check if it is an anomaly based on the Deep SVDD model "
    "loaded from your Hugging Face repository (`" + HUGGINGFACE_REPO_ID + "`)."
    "<br>The model calculates the squared distance of the image's embedding from the "
    "learned center $\\mathbf{c}$ and compares it to the radius threshold $R^2$."
    f"<br>Current Anomaly Threshold ($R^2$): **{R_SQUARED:.4f}**"
)

# Keep a handle on the app so it can be imported or reloaded without serving.
demo = gr.Interface(
    fn=predict_anomaly,
    inputs=gr.Image(type="pil", label=f"Input Image (will be resized to {INPUT_IMAGE_SIZE})"),
    outputs=[
        gr.Textbox(label="Anomaly Score Details", value="Upload an image to start."),
        gr.Number(label="Raw Score (Distance Squared)", precision=4),
        # NOTE(review): the original comment says this box is meant to be
        # hidden; `render=False` may not achieve that inside an Interface.
        # Confirm whether `visible=False` was intended.
        gr.Textbox(label="Prediction", value="Pending", render=False),
    ],
    title=title,
    description=description,
    allow_flagging="never",
    # The examples below are placeholders, not real samples, so do not run the
    # model on them at startup (example caching is enabled by default on Spaces).
    cache_examples=False,
    # Example images to test the model (replace with actual examples from your normal/anomaly classes)
    examples=[
        # If your model is trained on, say, normal dog images:
        ["https://placehold.co/128x128/99D9EA/white?text=Normal+Dog"],
        ["https://placehold.co/128x128/F0B27A/white?text=Anomaly+Cat"],
    ],
)

# Start the server only when executed as a script (as Spaces runs app.py),
# not as a side effect of importing this module.
if __name__ == "__main__":
    demo.launch()