badenlt commited on
Commit
da37721
·
verified ·
1 Parent(s): 97b3779

asdflkj

Files changed (1) hide show
  1. app.py +94 -0
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from PIL import Image
5
+
6
+ # ---- Paths to model + stats (put these files in the same directory) ----
7
+ MODEL_PATH = "small32cnn_mlbt_mmat.keras"
8
+ STATS_PATH = "preproc_stats_smallcnn.npz"
9
+
10
+ # ---- Load model and stats ----
11
+ model = tf.keras.models.load_model(MODEL_PATH, compile=False)
12
+
13
+ stats = np.load(STATS_PATH)
14
+ MEAN = float(stats["MEAN"])
15
+ STD = float(stats["STD"])
16
+
17
+ CLASS_NAMES = ["MMAT", "MLBT"] # 0 -> MMAT, 1 -> MLBT
18
+
19
+
20
+ # ---- Preprocessing: Image -> (1, 32, 32, 1) normalized ----
21
+ def preprocess(img: np.ndarray) -> np.ndarray:
22
+ """
23
+ img: H x W x C in [0, 255] from Gradio (numpy)
24
+ returns: (1, 32, 32, 1) float32, z-score normalized
25
+ """
26
+
27
+ # Convert to grayscale
28
+ if img.ndim == 3 and img.shape[2] == 3:
29
+ # RGB -> grayscale
30
+ img_gray = np.dot(img[..., :3], [0.2989, 0.5870, 0.1140])
31
+ elif img.ndim == 3 and img.shape[2] == 1:
32
+ img_gray = img[..., 0]
33
+ elif img.ndim == 2:
34
+ img_gray = img
35
+ else:
36
+ # fallback: take first channel
37
+ img_gray = img[..., 0]
38
+
39
+ # Resize to 32x32
40
+ pil_img = Image.fromarray(img_gray.astype("uint8"))
41
+ pil_img = pil_img.resize((32, 32), Image.BILINEAR)
42
+ arr = np.array(pil_img).astype("float32")
43
+
44
+ # z-score normalization using training stats
45
+ arr = (arr - MEAN) / STD
46
+
47
+ # Add batch and channel dims: (1, 32, 32, 1)
48
+ arr = arr[None, ..., None]
49
+ return arr
50
+
51
+
52
+ # ---- Prediction function for Gradio ----
53
+ def predict(img: np.ndarray):
54
+ """
55
+ Returns a dict {class_name: probability} for gr.Label
56
+ """
57
+ x = preprocess(img)
58
+ probs = model.predict(x, verbose=0)[0, 0] # scalar prob for class 1 (MLBT)
59
+ p_mlbt = float(probs)
60
+ p_mmat = float(1.0 - p_mlbt)
61
+
62
+ return {
63
+ "MLBT": p_mlbt,
64
+ "MMAT": p_mmat,
65
+ }
66
+
67
+
68
+ # ---- Gradio interface ----
69
+ input_component = gr.Image(
70
+ type="numpy",
71
+ label="Jet image (grayscale or RGB, any size ≥ 32x32)"
72
+ )
73
+
74
+ output_component = gr.Label(
75
+ num_top_classes=2,
76
+ label="Predicted probabilities"
77
+ )
78
+
79
+ examples = [] # you can drop some example jet PNGs here later
80
+
81
+ demo = gr.Interface(
82
+ fn=predict,
83
+ inputs=input_component,
84
+ outputs=output_component,
85
+ title="MLBT vs MMAT Jet Classifier (Small 32×32 CNN)",
86
+ description=(
87
+ "Upload a jet image (32×32 heatmap or larger) and this model "
88
+ "predicts whether it came from the MLBT or MMAT energy-loss module."
89
+ ),
90
+ examples=examples
91
+ )
92
+
93
+ if __name__ == "__main__":
94
+ demo.launch()