# app.py — Deep SVDD anomaly detection Gradio Space
# (Hugging Face upload by ash12321, revision e3cb362)
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
import os
from huggingface_hub import hf_hub_download
# --- CONFIGURATION ---
# !!! IMPORTANT: REPLACE THESE WITH YOUR MODEL'S DETAILS !!!
HUGGINGFACE_REPO_ID = "YOUR_USERNAME/YOUR_REPO_NAME" # e.g., "myusername/deepsvdd-anomalydetection"
MODEL_WEIGHTS_FILE = "net_weights.pth" # Name of your model weights file (a PyTorch state dict)
CENTER_FILE = "center_c.pt" # Name of your center vector file (torch-saved tensor c)
RADIUS_FILE = "radius_R.txt" # Name of your radius file (plain text containing a single float R)
# Define the image size your model expects.
# NOTE: DeepSVDDNet.fc1 assumes 28x28 inputs (4 * 7 * 7 after two 2x2 pools);
# changing this size also requires changing the network architecture.
INPUT_IMAGE_SIZE = (28, 28) # Replace with the size used during training (e.g., (128, 128))
# Threshold (R) is loaded dynamically, but you can set a default anomaly score threshold here.
# CAVEAT: when the radius file is missing, this raw value is used directly as the
# squared-distance cutoff (it is NOT squared again) — see load_deepsvdd_assets.
ANOMALY_THRESHOLD = 0.5 # A default score threshold (will be replaced by R if RADIUS_FILE is available)
# --- DEEP SVDD MODEL ARCHITECTURE (Placeholder) ---
# !!! IMPORTANT: YOU MUST REPLACE THIS CLASS WITH YOUR ACTUAL DEEP SVDD NETWORK DEFINITION !!!
class DeepSVDDNet(nn.Module):
    """
    Placeholder Deep SVDD encoder network.

    Swap this class for the exact architecture used during training.
    The submodule names (conv1, bn1, conv2, bn2, fc1) must match the keys
    in the saved weights file or load_state_dict will fail.
    """

    def __init__(self):
        super().__init__()
        # Example encoder for 28x28 single-channel inputs (MNIST/FashionMNIST-style).
        self.rep_dim = 32  # dimensionality of the latent embedding
        self.pool = nn.MaxPool2d(2, 2)
        self.conv1 = nn.Conv2d(1, 8, 5, bias=False, padding=2)
        self.bn1 = nn.BatchNorm2d(8)
        self.conv2 = nn.Conv2d(8, 4, 5, bias=False, padding=2)
        self.bn2 = nn.BatchNorm2d(4)
        # 28x28 -> pool -> 14x14 -> pool -> 7x7, with 4 channels.
        self.fc1 = nn.Linear(4 * 7 * 7, self.rep_dim, bias=False)

    def forward(self, x):
        """Map a batch of images (N, 1, 28, 28) to (N, rep_dim) embeddings."""
        out = self.conv1(x)
        out = self.pool(nn.functional.relu(self.bn1(out)))
        out = self.conv2(out)
        out = self.pool(nn.functional.relu(self.bn2(out)))
        out = torch.flatten(out, 1)  # collapse C*H*W into a single feature axis
        return self.fc1(out)
# --- MODEL LOADING AND INITIALIZATION ---
def load_deepsvdd_assets(repo_id, model_file, center_file, radius_file):
    """Fetch Deep SVDD artifacts (weights, center c, radius R) from the HF Hub.

    Returns (model, center_c). As a side effect, sets the module-global
    R_SQUARED to R**2 when the radius file is available; otherwise falls
    back to ANOMALY_THRESHOLD (used directly, not squared).
    """
    global R_SQUARED
    print("-> Starting model and asset loading...")

    # Model weights (PyTorch state dict).
    weights_path = hf_hub_download(repo_id=repo_id, filename=model_file)
    print(f" Downloaded model weights to: {weights_path}")

    # Hypersphere center vector c.
    center_path = hf_hub_download(repo_id=repo_id, filename=center_file)
    center_c = torch.load(center_path, map_location=torch.device('cpu'))
    print(f" Downloaded center vector c from: {center_path}")

    # Radius R, stored as plain text. Optional: any failure (missing file,
    # unparsable content) falls back to the default threshold rather than aborting.
    try:
        radius_path = hf_hub_download(repo_id=repo_id, filename=radius_file)
        with open(radius_path, 'r') as f:
            R = float(f.read().strip())
        R_SQUARED = R**2
        print(f" Downloaded radius R ({R}). R^2 is: {R_SQUARED}")
    except Exception as e:
        print(f" Warning: Could not load radius R from {radius_file}. Using default anomaly threshold ({ANOMALY_THRESHOLD}). Error: {e}")
        R_SQUARED = ANOMALY_THRESHOLD

    # Instantiate the network, restore trained parameters, freeze for inference.
    model = DeepSVDDNet()
    model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
    model.eval()
    print("-> Model and assets successfully loaded and set to evaluation mode.")
    return model, center_c
# Initialize global variables
# Runs once at import time so the Gradio callback can reuse the loaded assets.
try:
    model, center_c = load_deepsvdd_assets(HUGGINGFACE_REPO_ID, MODEL_WEIGHTS_FILE, CENTER_FILE, RADIUS_FILE)
except Exception as e:
    print(f"FATAL ERROR during model initialization: {e}")
    # Set placeholders to prevent crash; predict_anomaly checks for None
    # and returns an error message instead of raising.
    model = None
    center_c = None
    # Fallback so later reads of R_SQUARED (e.g. in the UI description)
    # don't raise NameError when loading failed before it was set.
    R_SQUARED = 1.0
# --- PREPROCESSING AND PREDICTION FUNCTION ---
def preprocess_image(pil_img: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a normalized, batched tensor for inference.

    NOTE: the resize target, channel count, and normalization statistics
    must mirror the training pipeline. This version assumes one grayscale
    channel normalized with mean 0.5 / std 0.5.
    """
    pipeline = transforms.Compose([
        transforms.Resize(INPUT_IMAGE_SIZE),
        transforms.Grayscale(num_output_channels=1),  # Change to 3 if using RGB
        transforms.ToTensor(),
        # Use the mean and std you trained with
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ])
    tensor = pipeline(pil_img)
    return tensor.unsqueeze(0)  # prepend the batch dimension
def predict_anomaly(input_img: Image.Image):
    """
    Run Deep SVDD inference on one image and report an anomaly verdict.

    Returns a (details_text, raw_score, label) tuple matching the three
    Gradio output components.
    """
    # Guard: the model or center failed to load at startup.
    if model is None or center_c is None:
        return (f"Error: Model failed to load. Please check your HUGGINGFACE_REPO_ID ({HUGGINGFACE_REPO_ID}) and file paths.",
                0.0,
                "Failed to run")
    try:
        batch = preprocess_image(input_img)
        with torch.no_grad():
            embedding = model(batch)
        # Anomaly score: squared Euclidean distance from the learned center c.
        score = torch.sum((embedding - center_c)**2, dim=1).item()
        # Flag as anomaly when the score falls outside the R^2 boundary.
        verdict = "ANOMALY (Outlier)" if score > R_SQUARED else "NORMAL (Inlier)"
        details = f"Anomaly Score (Distance²): {score:.4f}"
        return details, score, verdict
    except Exception as e:
        error_msg = f"An error occurred during prediction: {e}"
        print(error_msg)
        return (error_msg, 0.0, "Failed to run")
# --- GRADIO INTERFACE ---
title = "Deep SVDD Anomaly Detection on Hugging Face"
# NOTE: R_SQUARED is interpolated once, at import time — if the radius file
# failed to load, the description shows the fallback threshold, not the
# trained R^2.
description = (
    "Upload an image to check if it is an anomaly based on the Deep SVDD model "
    "loaded from your Hugging Face repository (`" + HUGGINGFACE_REPO_ID + "`)."
    "<br>The model calculates the squared distance of the image's embedding from the "
    "learned center $\\mathbf{c}$ and compares it to the radius threshold $R^2$."
    f"<br>Current Anomaly Threshold ($R^2$): **{R_SQUARED:.4f}**"
)
# Use LaTeX syntax for the description
gr.Interface(
    fn=predict_anomaly,
    inputs=gr.Image(type="pil", label=f"Input Image (will be resized to {INPUT_IMAGE_SIZE})"),
    # Three outputs matching predict_anomaly's (details, score, label) tuple.
    outputs=[
        gr.Textbox(label="Anomaly Score Details", value="Upload an image to start."),
        gr.Number(label="Raw Score (Distance Squared)", precision=4),
        gr.Textbox(label="Prediction", value="Pending", render=False) # Prediction text is hidden, score details are shown
    ],
    title=title,
    description=description,
    # NOTE(review): allow_flagging is deprecated in Gradio 4.x in favor of
    # flagging_mode — confirm against the installed Gradio version.
    allow_flagging="never",
    # Example images to test the model (replace with actual examples from your normal/anomaly classes)
    examples=[
        # If your model is trained on, say, normal dog images:
        ["https://placehold.co/128x128/99D9EA/white?text=Normal+Dog"],
        ["https://placehold.co/128x128/F0B27A/white?text=Anomaly+Cat"],
    ]
).launch()