MultivexAI commited on
Commit
ec15aec
·
verified ·
1 Parent(s): 39b401a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -0
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ from PIL import Image
7
+ from scipy.ndimage import rotate, gaussian_filter
8
+ import gradio as gr
9
+ from huggingface_hub import hf_hub_download
10
+
11
+ hf_hub_download(repo_id="MultivexAI/RobustMNIST-v1.0", filename="model.py", local_dir=".")
12
+ hf_hub_download(repo_id="MultivexAI/RobustMNIST-v1.0", filename="model.pt", local_dir=".")
13
+
14
+ sys.path.append(os.path.abspath("."))
15
+ from model import HierarchicalNetwork
16
+
17
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
+ model = HierarchicalNetwork(out_dims=11).to(DEVICE)
19
+ model.load_state_dict(torch.load("model.pt", map_location=DEVICE))
20
+ model.eval()
21
+
22
+ def preprocess_and_predict(sketch_data, rotation_val, noise_val, blur_val):
23
+ if sketch_data is None:
24
+ return None, {}
25
+
26
+ if isinstance(sketch_data, dict):
27
+ img_array = sketch_data.get("composite", None)
28
+ if img_array is None:
29
+ layers = sketch_data.get("layers", [])
30
+ img_array = layers[0] if layers else None
31
+ else:
32
+ img_array = sketch_data
33
+
34
+ if img_array is None:
35
+ return None, {}
36
+
37
+ pil_img = Image.fromarray(img_array.astype('uint8'))
38
+ if pil_img.mode == 'RGBA':
39
+ _, _, _, alpha_channel = pil_img.split()
40
+ if np.array(alpha_channel).max() > 0:
41
+ gray_img = alpha_channel
42
+ else:
43
+ gray_img = pil_img.convert('L')
44
+ else:
45
+ gray_img = pil_img.convert('L')
46
+
47
+ resized_img = gray_img.resize((28, 28), Image.Resampling.LANCZOS)
48
+ np_img = np.array(resized_img).astype(np.float32)
49
+
50
+ border_average = (np_img[0, :].mean() + np_img[-1, :].mean() + np_img[:, 0].mean() + np_img[:, -1].mean()) / 4.0
51
+ if border_average > 127.5:
52
+ np_img = 255.0 - np_img
53
+
54
+ if rotation_val > 0:
55
+ np_img = rotate(np_img, rotation_val, reshape=False, order=1, mode='constant', cval=0.0)
56
+
57
+ if blur_val > 0:
58
+ np_img = gaussian_filter(np_img, sigma=blur_val)
59
+
60
+ if noise_val > 0:
61
+ variance_scale = noise_val * 255.0
62
+ additive_noise = np.random.normal(0, variance_scale, np_img.shape)
63
+ np_img = np.clip(np_img + additive_noise, 0.0, 255.0)
64
+
65
+ normalized_array = np_img / 255.0
66
+ tensor_input = torch.tensor(normalized_array, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(DEVICE)
67
+
68
+ with torch.inference_mode():
69
+ logits = model(tensor_input)
70
+ probabilities = F.softmax(logits, dim=1).cpu().numpy()[0]
71
+
72
+ class_labels = [str(i) for i in range(10)] + ["Unknown"]
73
+ distribution = {class_labels[i]: float(probabilities[i]) for i in range(11)}
74
+ preview_output = np.clip(np_img, 0, 255).astype(np.uint8)
75
+
76
+ return preview_output, distribution
77
+
78
+ with gr.Blocks(title="Robust MNIST Classifier") as interface:
79
+ gr.Markdown("## Robust Hierarchical Classifier")
80
+ gr.Markdown("Draw a single digit, adjust the sliders to apply synthetic environmental distortions, and observe the robustness profile.")
81
+
82
+ with gr.Row():
83
+ with gr.Column():
84
+ canvas = gr.Sketchpad(
85
+ label="Draw Digit",
86
+ type="numpy",
87
+ crop_to_bbox=False
88
+ )
89
+ rotation = gr.Slider(minimum=0, maximum=180, value=0, step=1, label="Rotation Angle (Degrees)")
90
+ noise = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.05, label="Gaussian Noise Level")
91
+ blur = gr.Slider(minimum=0.0, maximum=5.0, value=0.0, step=0.1, label="Gaussian Blur (Sigma)")
92
+ run_btn = gr.Button("Evaluate Signature", variant="primary")
93
+
94
+ with gr.Column():
95
+ preview = gr.Image(label="Model-View Reconstruction (28x28)", image_mode="L")
96
+ probabilities_output = gr.Label(num_top_classes=5, label="Probability Map Output")
97
+
98
+ run_btn.click(
99
+ fn=preprocess_and_predict,
100
+ inputs=[canvas, rotation, noise, blur],
101
+ outputs=[preview, probabilities_output]
102
+ )
103
+
104
+ if __name__ == "__main__":
105
+ interface.launch()