Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import keras | |
| from PIL import Image | |
| import gradio as gr | |
MODEL_PATH = "small32cnn_mlbt_mmat.keras"
STATS_PATH = "jet_image_scale_and_stats.npz"

# Load the trained network once at import time. compile=False: this app only
# runs inference via model.predict, so no optimizer/loss setup is needed.
model = keras.models.load_model(MODEL_PATH, compile=False)

# Preprocessing constants saved next to the model. np.load on an .npz returns
# an NpzFile that holds the underlying file open until closed, so read the
# scalars inside a `with` block, which closes the file on exit.
with np.load(STATS_PATH) as stats:
    SCALE = float(stats["SCALE"])
    MEAN = float(stats["MEAN"])
    STD = float(stats["STD"])
print("Loaded SCALE/MEAN/STD:", SCALE, MEAN, STD, flush=True)

# Class name for each integer label: index 0 -> "MMAT", index 1 -> "MLBT".
# predict() treats the model's single output as P(label 1) = P("MLBT").
CLASS_NAMES = ["MMAT", "MLBT"]
# ---- Preprocessing: Image -> (1, 32, 32, 1) normalized ----
def preprocess(img: np.ndarray) -> np.ndarray:
    """Turn one uploaded image into a normalized single-sample model batch.

    The image is reduced to one gray channel, resized to 32x32 with bilinear
    interpolation, divided by the global ``SCALE`` and standardized with
    ``MEAN`` / ``STD``.

    Args:
        img: Image array of shape (H, W), (H, W, 1), (H, W, 2) or (H, W, C)
            with C >= 3. For C >= 3 the first three channels are combined as
            RGB luminance (extra channels such as alpha are ignored); for
            C in {1, 2} the first channel is used directly. Pixel values are
            assumed to be on a 0-255 scale (NOTE(review): confirm for
            non-uint8 uploads).

    Returns:
        float32 array of shape (1, 32, 32, 1).

    Raises:
        ValueError: if ``img`` is not a 2-D or 3-D array with at least one
            channel (e.g. ``None`` was passed).
    """
    img = np.asarray(img)
    if img.ndim == 2:
        img_gray = img
    elif img.ndim == 3 and img.shape[2] >= 3:
        # Weighted RGB -> luminance; any alpha/extra channels are ignored
        # instead of falling back to the red channel alone.
        img_gray = np.dot(img[..., :3], [0.2989, 0.5870, 0.1140])
    elif img.ndim == 3 and img.shape[2] >= 1:
        img_gray = img[..., 0]
    else:
        raise ValueError(
            f"Expected a 2-D or 3-D image array, got shape {img.shape}"
        )

    # Round and clip before the uint8 cast: a bare astype truncates the
    # fractional luminance values and wraps anything outside [0, 255].
    img_u8 = np.clip(np.rint(img_gray), 0, 255).astype("uint8")
    pil_img = Image.fromarray(img_u8)
    pil_img = pil_img.resize((32, 32), Image.BILINEAR)
    arr = np.array(pil_img).astype("float32")

    # Invert the global scaling to approximate original X
    # (presumably the images were rendered as X * SCALE — confirm).
    arr_unscaled = arr / SCALE
    # Now apply the same normalization as during training; the epsilon
    # guards against a zero STD.
    arr_norm = (arr_unscaled - MEAN) / (STD + 1e-8)
    # Add batch and channel axes: (32, 32) -> (1, 32, 32, 1).
    return arr_norm[None, ..., None]
# ---- Prediction function for Gradio ----
def predict(img: np.ndarray):
    """Classify one jet image and return per-class probabilities for gr.Label.

    Args:
        img: Image array from the ``gr.Image(type="numpy")`` input, or
            ``None`` when the user submits without uploading an image.

    Returns:
        dict mapping each class name to its probability, in the form
        ``{"MLBT": p, "MMAT": 1 - p}``.

    Raises:
        gr.Error: if no image was provided; Gradio shows the message in the UI
            instead of a raw traceback.
    """
    if img is None:
        raise gr.Error("Please upload a jet image first.")
    x = preprocess(img)
    raw = float(model.predict(x, verbose=0)[0, 0])
    print("Raw model output:", raw, flush=True)
    # raw ≈ P(label 1) = P(CLASS_NAMES[1]) = P(MLBT), as in training; the
    # other class gets the complement. Names come from CLASS_NAMES so the
    # label order is defined in one place.
    neg_name, pos_name = CLASS_NAMES
    prob_pos = raw
    return {pos_name: prob_pos, neg_name: 1.0 - prob_pos}
# ---- Gradio interface ----
# Upload widget: the chosen image is handed to `predict` as a NumPy array.
input_component = gr.Image(
    type="numpy",
    label="Jet image (grayscale or RGB, any size ≥ 32x32)",
)

# Result widget: displays the class -> probability mapping from `predict`.
output_component = gr.Label(label="Predicted probabilities", num_top_classes=2)

# Example inputs listed under the form (none yet).
examples = []  # you can drop some example jet PNGs here later

_DESCRIPTION = (
    "Upload a jet image (32×32 heatmap or larger) and this model "
    "predicts whether it came from the MLBT or MMAT energy-loss module."
)

demo = gr.Interface(
    fn=predict,
    inputs=input_component,
    outputs=output_component,
    title="MLBT vs MMAT Jet Classifier (Small 32×32 CNN)",
    description=_DESCRIPTION,
    examples=examples,
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo.launch()