# Hugging Face Spaces page residue (originally: "Spaces: Sleeping Sleeping") — not part of the app code.
| import os | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| import tensorflow as tf | |
| from tensorflow.keras import layers, models | |
print("Building model architecture...")

# 1. Rebuild the exact backbone used in training:
#    MobileNetV2 pretrained on ImageNet, without its classification head,
#    frozen so only the custom head below is trainable.
base_model = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3),
    include_top=False,
    weights="imagenet"
)
base_model.trainable = False

# 2. Rebuild the exact head used in training (Functional API).
#    NOTE(review): the layer order and sizes must match the checkpoint
#    exactly or load_weights() below fails with a shape mismatch — do not
#    reorder or resize these layers.
inputs = tf.keras.Input(shape=(224, 224, 3))
x = base_model(inputs, training=False)  # training=False keeps the frozen BatchNorms in inference mode
x = layers.GlobalAveragePooling2D()(x)  # (batch, 7, 7, 1280) -> (batch, 1280)
# --- The Missing Layers ---
x = layers.BatchNormalization()(x)
x = layers.Dense(256, activation="relu")(x)  # This matches your trained (1280, 256) weights
x = layers.BatchNormalization()(x)
x = layers.Dropout(0.5)(x)
x = layers.Dense(128, activation="relu")(x)
x = layers.BatchNormalization()(x)
x = layers.Dropout(0.4)(x)
# --------------------------
outputs = layers.Dense(2, activation="softmax")(x)  # two classes; order given by LABELS
model = models.Model(inputs, outputs)

print("Loading weights...")
# Now the shapes match perfectly.
# NOTE(review): load_weights() on a ".keras" archive only works on recent
# Keras versions; older versions expect a ".weights.h5" file here — confirm
# the checkpoint format matches the installed Keras.
model.load_weights("model.keras")
print("Model ready!")

# Class names in the order of the softmax outputs.
# NOTE(review): assumes index 0 = "preserved", 1 = "looted" matched the
# training label encoding — verify against the training script.
LABELS = ["preserved", "looted"]
# Input resolution the backbone was built with (must stay (224, 224)).
IMG_SIZE = (224, 224)
def preprocess(img: Image.Image):
    """Turn a PIL image into a (1, 224, 224, 3) float32 batch scaled to [0, 1].

    The image is forced to RGB and resized to IMG_SIZE before normalization.
    NOTE(review): plain /255 scaling — assumes the model was trained with the
    same normalization (not MobileNetV2's preprocess_input); confirm.
    """
    rgb = img.convert("RGB").resize(IMG_SIZE)
    scaled = np.array(rgb) / 255.0            # uint8 -> float64 in [0, 1]
    return scaled[np.newaxis, ...].astype(np.float32)
def predict(image: Image.Image):
    """Classify an uploaded photo.

    Returns a ({label: probability}, markdown-summary) pair for the Gradio
    outputs, or (None, message) when no image was supplied.
    """
    # Guard: Gradio passes None when the user submits without an upload.
    if image is None:
        return None, "No image uploaded."

    batch = preprocess(image)
    scores = model.predict(batch)[0]

    confidences = {label: float(score) for label, score in zip(LABELS, scores)}
    best = int(np.argmax(scores))
    summary = f"Prediction: **{LABELS[best]}** ({float(scores[best]):.1%} confidence)"
    return confidences, summary
# Gradio UI wiring: one image input; a ranked-label widget plus a text summary.
image_input = gr.Image(type="pil", label="Upload photo")
label_output = gr.Label(num_top_classes=2)
summary_output = gr.Textbox(label="Summary")

iface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=[label_output, summary_output],
    title="Preserved vs Looted Classifier",
    description="Upload a photo to classify archeological sites.",
)

if __name__ == "__main__":
    # Bind to all interfaces; honor a platform-provided PORT, defaulting to 7860.
    iface.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))