# FLL-innovation / app.py
# Last commit 9766b04 ("fix bugs"), committer: nimitjalan
import os
import gradio as gr
import numpy as np
from PIL import Image
import tensorflow as tf
from tensorflow.keras import layers, models
# Build the classifier graph so that it mirrors the training-time model layer
# for layer. Trained weights are loaded into this freshly built graph via
# load_weights() further down, so the layer types, their order and their
# shapes here must reproduce the training model exactly; do not add, remove or
# reorder layers without re-exporting the checkpoint.
print("Building model architecture...")
# 1. Backbone: MobileNetV2 feature extractor without its ImageNet classifier
#    head (include_top=False), frozen so its weights are not trainable.
# NOTE(review): weights="imagenet" downloads pretrained weights at startup. If
# model.keras also stores the backbone weights, load_weights() overwrites them
# and weights=None would skip the download — confirm before changing.
base_model = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3),
    include_top=False,
    weights="imagenet"
)
base_model.trainable = False
# 2. Classification head on top of the backbone (Functional API).
inputs = tf.keras.Input(shape=(224, 224, 3))
# training=False runs the backbone in inference mode (BatchNormalization uses
# its stored moving statistics rather than per-batch statistics).
x = base_model(inputs, training=False)
# Collapse the backbone's spatial feature map into one feature vector per image.
x = layers.GlobalAveragePooling2D()(x)
# Head: BatchNorm on the pooled features, then Dense(256) and Dense(128)
# blocks, each followed by BatchNorm and Dropout.
x = layers.BatchNormalization()(x)
x = layers.Dense(256, activation="relu")(x)  # kernel shape (1280, 256) in the trained checkpoint
x = layers.BatchNormalization()(x)
x = layers.Dropout(0.5)(x)
x = layers.Dense(128, activation="relu")(x)
x = layers.BatchNormalization()(x)
x = layers.Dropout(0.4)(x)
# Two-way softmax; output index i is reported as LABELS[i] by predict().
outputs = layers.Dense(2, activation="softmax")(x)
model = models.Model(inputs, outputs)
print("Loading weights...")
# Resolve the checkpoint next to this script rather than in the process's
# current working directory, so the app also starts when launched from another
# directory. Set the MODEL_PATH environment variable to load a different file.
MODEL_PATH = os.environ.get(
    "MODEL_PATH",
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "model.keras"),
)
# Copy the trained weights into the graph built above. Loading is strict by
# default (skip_mismatch=False), so an architecture mismatch raises here.
model.load_weights(MODEL_PATH)
print("Model ready!")

# Class names indexed like the model's softmax outputs: probs[i] -> LABELS[i].
# NOTE(review): if training assigned class indices alphabetically (e.g. via
# image_dataset_from_directory), index 0 would be "looted", not "preserved" —
# confirm against the training pipeline.
LABELS = ["preserved", "looted"]
# (width, height) passed to PIL's resize(); must match the 224x224 model input.
IMG_SIZE = (224, 224)
def preprocess(img: Image.Image):
    """Convert a PIL image into a one-image batch for the classifier.

    The image is forced to 3-channel RGB (dropping alpha / expanding
    grayscale), resized to ``IMG_SIZE`` and its 0-255 pixel values are scaled
    into the 0-1 range.

    NOTE(review): Keras' MobileNetV2 ``preprocess_input`` maps pixels to
    [-1, 1]; this [0, 1] scaling is only correct if training used the same
    scaling — confirm against the training pipeline.

    Returns:
        A float32 NumPy array of shape ``(1, height, width, 3)``.
    """
    rgb = img.convert("RGB")
    resized = rgb.resize(IMG_SIZE)
    pixels = np.asarray(resized).astype(np.float32)
    scaled = pixels / 255.0
    # Prepend the batch axis expected by the model.
    return scaled[np.newaxis, ...]
def predict(image: Image.Image):
    """Classify one uploaded photo for the Gradio interface.

    Args:
        image: The uploaded photo, or None when the input is empty.

    Returns:
        A ``(scores, summary)`` pair. ``scores`` maps every label in
        ``LABELS`` to its softmax probability (for the ``gr.Label`` output).
        ``summary`` is a one-line verdict naming the top class and its
        probability (for the ``gr.Textbox`` output). When no image is given,
        returns ``(None, "No image uploaded.")``.
    """
    if image is None:
        return None, "No image uploaded."
    x = preprocess(image)
    # Call the model directly instead of model.predict(). predict() is meant
    # for large batched inputs, adds per-call overhead and logs a progress bar
    # on every request. training=False keeps Dropout/BatchNorm in inference
    # mode, exactly as predict() does.
    probs = np.asarray(model(x, training=False))[0]
    result = {label: float(p) for label, p in zip(LABELS, probs)}
    idx = int(np.argmax(probs))
    # Plain text: gr.Textbox does not render Markdown, so "**bold**" markers
    # would be shown to the user literally.
    summary = f"Prediction: {LABELS[idx]} ({float(probs[idx]):.1%} confidence)"
    return result, summary
# Gradio UI: one PIL image in; out come a label widget with the per-class
# scores and a textbox with the one-line summary, matching predict()'s
# (scores, summary) return pair.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload photo"),
    # Show every class; derived from LABELS so it stays in sync with the model.
    outputs=[gr.Label(num_top_classes=len(LABELS)), gr.Textbox(label="Summary")],
    title="Preserved vs Looted Classifier",
    description="Upload a photo to classify archeological sites.",
)

if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT when set, else 7860.
    iface.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))