File size: 7,171 Bytes
41e3e77
 
 
e3fec38
e3cb362
 
e3fec38
e3cb362
 
 
 
 
 
 
 
e3fec38
e3cb362
 
 
 
e3fec38
e3cb362
 
e3fec38
e3cb362
 
 
 
e3fec38
e3cb362
 
 
 
 
 
 
 
 
e3fec38
74e3cab
e3cb362
 
 
 
74e3cab
41e3e77
e3cb362
41e3e77
e3cb362
 
 
41e3e77
e3cb362
 
 
41e3e77
e3cb362
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3fec38
 
74e3cab
e3cb362
 
41e3e77
e3cb362
e3fec38
e3cb362
e3fec38
e3cb362
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74e3cab
e3cb362
 
74e3cab
e3cb362
 
41e3e77
e3cb362
41e3e77
e3cb362
41e3e77
e3cb362
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3fec38
 
e3cb362
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
import os
from huggingface_hub import hf_hub_download

# --- CONFIGURATION ---
# !!! IMPORTANT: REPLACE THESE WITH YOUR MODEL'S DETAILS !!!
HUGGINGFACE_REPO_ID = "YOUR_USERNAME/YOUR_REPO_NAME" # e.g., "myusername/deepsvdd-anomalydetection"
MODEL_WEIGHTS_FILE = "net_weights.pth" # Name of your model weights file (a state_dict)
CENTER_FILE = "center_c.pt"            # Name of your center vector file (a saved tensor)
RADIUS_FILE = "radius_R.txt"           # Name of your radius file (plain-text float R)

# Define the image size your model expects (height, width)
INPUT_IMAGE_SIZE = (28, 28) # Replace with the size used during training (e.g., (128, 128))
# If RADIUS_FILE loads, the decision threshold becomes R**2; otherwise this value
# is used directly as the squared-distance threshold (see load_deepsvdd_assets).
ANOMALY_THRESHOLD = 0.5 # A default score threshold (fallback when RADIUS_FILE is unavailable)

# --- DEEP SVDD MODEL ARCHITECTURE (Placeholder) ---
# !!! IMPORTANT: YOU MUST REPLACE THIS CLASS WITH YOUR ACTUAL DEEP SVDD NETWORK DEFINITION !!!
class DeepSVDDNet(nn.Module):
    """
    Placeholder Deep SVDD encoder network.

    Replace this with the exact PyTorch class definition used for training —
    attribute names (conv1, bn1, ...) must match the saved state_dict keys.
    The example maps a 1-channel 28x28 input to a 32-dimensional embedding
    (28 -> 14 -> 7 after the two pooling stages, hence the 4*7*7 linear input).
    """
    def __init__(self):
        super().__init__()
        # Dimension of the latent representation z.
        self.rep_dim = 32
        self.pool = nn.MaxPool2d(2, 2)
        self.conv1 = nn.Conv2d(1, 8, 5, bias=False, padding=2)
        self.bn1 = nn.BatchNorm2d(8)
        self.conv2 = nn.Conv2d(8, 4, 5, bias=False, padding=2)
        self.bn2 = nn.BatchNorm2d(4)
        self.fc1 = nn.Linear(4 * 7 * 7, self.rep_dim, bias=False)

    def forward(self, x):
        relu = nn.functional.relu
        out = self.pool(relu(self.bn1(self.conv1(x))))
        out = self.pool(relu(self.bn2(self.conv2(out))))
        out = out.flatten(1)  # (N, 4*7*7)
        return self.fc1(out)

# --- MODEL LOADING AND INITIALIZATION ---

def load_deepsvdd_assets(repo_id, model_file, center_file, radius_file):
    """Load Deep SVDD assets (weights, center ``c``, radius ``R``) from the Hugging Face Hub.

    Args:
        repo_id: Hub repository id, e.g. ``"user/repo"``.
        model_file: Filename of the saved state_dict inside the repo.
        center_file: Filename of the saved center tensor ``c``.
        radius_file: Filename of the plain-text radius ``R``.

    Returns:
        Tuple of (network in eval mode, center tensor ``c``).

    Side effects:
        Sets the module-global ``R_SQUARED`` to ``R**2`` when the radius file
        loads, or to ``ANOMALY_THRESHOLD`` as a fallback. Prints progress.

    Raises:
        Propagates any download/deserialization error for the weights or the
        center (the caller handles those); only radius failures are tolerated.
    """
    print("-> Starting model and asset loading...")

    # 1. Download Model Weights
    model_path = hf_hub_download(repo_id=repo_id, filename=model_file)
    print(f"   Downloaded model weights to: {model_path}")

    # 2. Download Center Vector c
    # weights_only=True: Hub files are untrusted pickles; restrict deserialization
    # to plain tensors/containers (the center is a saved tensor, so this is safe).
    center_path = hf_hub_download(repo_id=repo_id, filename=center_file)
    center_c = torch.load(center_path, map_location=torch.device('cpu'), weights_only=True)
    print(f"   Downloaded center vector c from: {center_path}")

    # 3. Download Radius R (stored as a plain-text float). A missing or
    # unreadable file degrades to the static ANOMALY_THRESHOLD so the app
    # still runs; this is deliberate best-effort handling.
    global R_SQUARED
    try:
        radius_path = hf_hub_download(repo_id=repo_id, filename=radius_file)
        with open(radius_path, 'r') as f:
            R = float(f.read().strip())
        R_SQUARED = R**2
        print(f"   Downloaded radius R ({R}). R^2 is: {R_SQUARED}")
    except Exception as e:
        print(f"   Warning: Could not load radius R from {radius_file}. Using default anomaly threshold ({ANOMALY_THRESHOLD}). Error: {e}")
        R_SQUARED = ANOMALY_THRESHOLD

    # 4. Initialize and Load Model
    # weights_only=True again: a state_dict loads fine and arbitrary pickled
    # code from a compromised repo does not.
    model = DeepSVDDNet()
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=True))
    model.eval()  # inference mode: freezes BatchNorm running statistics
    print("-> Model and assets successfully loaded and set to evaluation mode.")
    return model, center_c

# Initialize global variables
# Load the model and assets once at import time so every Gradio request reuses
# them; load_deepsvdd_assets also sets the module-global R_SQUARED.
try:
    model, center_c = load_deepsvdd_assets(HUGGINGFACE_REPO_ID, MODEL_WEIGHTS_FILE, CENTER_FILE, RADIUS_FILE)
except Exception as e:
    # Any failure (bad repo id, missing file, network error) degrades to a
    # "model unavailable" state that predict_anomaly reports to the user.
    print(f"FATAL ERROR during model initialization: {e}")
    # Set placeholders to prevent crash
    model = None
    center_c = None
    R_SQUARED = 1.0  # fallback so the UI description below can still format R^2


# --- PREPROCESSING AND PREDICTION FUNCTION ---

def preprocess_image(pil_img: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a normalized (1, C, H, W) tensor for inference.

    The pipeline mirrors the training-time preprocessing: resize to
    INPUT_IMAGE_SIZE, collapse to one grayscale channel, scale to [0, 1] via
    ToTensor, then normalize with mean 0.5 / std 0.5. Adjust the channel count
    and normalization statistics to match how your model was actually trained.
    """
    pipeline = transforms.Compose([
        transforms.Resize(INPUT_IMAGE_SIZE),
        transforms.Grayscale(num_output_channels=1),  # switch to 3 for RGB models
        transforms.ToTensor(),
        # Use the mean/std from your training setup.
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ])
    batched = pipeline(pil_img).unsqueeze(0)  # prepend the batch dimension
    return batched

def predict_anomaly(input_img: Image.Image):
    """Run Deep SVDD anomaly detection on one uploaded image.

    Returns a 3-tuple (score text, raw squared distance, prediction label)
    matching the three Gradio output components. Failures are reported through
    the outputs instead of raised, so the UI never crashes.
    """
    # Guard: module-level loading may have failed; report instead of crashing.
    if model is None or center_c is None:
        return (f"Error: Model failed to load. Please check your HUGGINGFACE_REPO_ID ({HUGGINGFACE_REPO_ID}) and file paths.",
                0.0,
                "Failed to run")

    try:
        # Preprocess, then embed without tracking gradients.
        batch = preprocess_image(input_img)
        with torch.no_grad():
            embedding = model(batch)

        # Anomaly score: squared Euclidean distance from the learned center c.
        score = torch.sum((embedding - center_c)**2, dim=1).item()

        # Outside the hypersphere (score > R^2) means anomaly.
        verdict = "ANOMALY (Outlier)" if score > R_SQUARED else "NORMAL (Inlier)"
        details = f"Anomaly Score (Distance²): {score:.4f}"

        return details, score, verdict

    except Exception as e:
        # Deliberate catch-all: surface the error in the UI outputs.
        error_msg = f"An error occurred during prediction: {e}"
        print(error_msg)
        return (error_msg, 0.0, "Failed to run")


# --- GRADIO INTERFACE ---

# NOTE: the description (including the R^2 value in the f-string below) is
# formatted once at import time; it will not refresh if R_SQUARED changes later.
title = "Deep SVDD Anomaly Detection on Hugging Face"
description = (
    "Upload an image to check if it is an anomaly based on the Deep SVDD model "
    "loaded from your Hugging Face repository (`" + HUGGINGFACE_REPO_ID + "`)."
    "<br>The model calculates the squared distance of the image's embedding from the "
    "learned center $\\mathbf{c}$ and compares it to the radius threshold $R^2$."
    f"<br>Current Anomaly Threshold ($R^2$): **{R_SQUARED:.4f}**"
)

# Use LaTeX syntax for the description
# Wire the prediction function to a three-output Gradio UI and launch it.
gr.Interface(
    fn=predict_anomaly,
    inputs=gr.Image(type="pil", label=f"Input Image (will be resized to {INPUT_IMAGE_SIZE})"),
    outputs=[
        gr.Textbox(label="Anomaly Score Details", value="Upload an image to start."),
        gr.Number(label="Raw Score (Distance Squared)", precision=4),
        # NOTE(review): render=False is intended to hide this component — confirm
        # the installed Gradio version honors it for Interface outputs.
        gr.Textbox(label="Prediction", value="Pending", render=False) # Prediction text is hidden, score details are shown
    ],
    title=title,
    description=description,
    # NOTE(review): allow_flagging was deprecated/renamed (flagging_mode) in
    # Gradio 4.x — verify against the pinned Gradio version.
    allow_flagging="never",
    # Example images to test the model (replace with actual examples from your normal/anomaly classes)
    examples=[
        # If your model is trained on, say, normal dog images:
        ["https://placehold.co/128x128/99D9EA/white?text=Normal+Dog"],
        ["https://placehold.co/128x128/F0B27A/white?text=Anomaly+Cat"],
    ]
).launch()