nimitjalan commited on
Commit
9766b04
·
1 Parent(s): abf68a5
Files changed (1) hide show
  1. app.py +27 -10
app.py CHANGED
@@ -3,8 +3,11 @@ import gradio as gr
3
  import numpy as np
4
  from PIL import Image
5
  import tensorflow as tf
 
6
 
7
  print("Building model architecture...")
 
 
8
  base_model = tf.keras.applications.MobileNetV2(
9
  input_shape=(224, 224, 3),
10
  include_top=False,
@@ -12,16 +15,28 @@ base_model = tf.keras.applications.MobileNetV2(
12
  )
13
  base_model.trainable = False
14
 
15
- model = tf.keras.models.Sequential([
16
- base_model,
17
- tf.keras.layers.GlobalAveragePooling2D(),
18
- tf.keras.layers.Dense(128, activation="relu"),
19
- tf.keras.layers.Dropout(0.3),
20
- tf.keras.layers.Dense(2, activation="softmax")
21
- ])
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  print("Loading weights...")
24
- model.load_weights("model.keras")
 
25
  print("Model ready!")
26
 
27
  LABELS = ["preserved", "looted"]
@@ -34,6 +49,8 @@ def preprocess(img: Image.Image):
34
  return arr
35
 
36
  def predict(image: Image.Image):
 
 
37
  x = preprocess(image)
38
  probs = model.predict(x)[0]
39
 
@@ -48,8 +65,8 @@ iface = gr.Interface(
48
  inputs=gr.Image(type="pil", label="Upload photo"),
49
  outputs=[gr.Label(num_top_classes=2), gr.Textbox(label="Summary")],
50
  title="Preserved vs Looted Classifier",
51
- description="Upload a photo to classify."
52
  )
53
 
54
  if __name__ == "__main__":
55
- iface.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
 
3
  import numpy as np
4
  from PIL import Image
5
  import tensorflow as tf
6
+ from tensorflow.keras import layers, models
7
 
8
  print("Building model architecture...")
9
+
10
+ # 1. Rebuild the exact backbone used in training
11
  base_model = tf.keras.applications.MobileNetV2(
12
  input_shape=(224, 224, 3),
13
  include_top=False,
 
15
  )
16
  base_model.trainable = False
17
 
18
+ # 2. Rebuild the exact head used in training (Functional API)
19
+ inputs = tf.keras.Input(shape=(224, 224, 3))
20
+ x = base_model(inputs, training=False)
21
+ x = layers.GlobalAveragePooling2D()(x)
22
+
23
+ # --- The Missing Layers ---
24
+ x = layers.BatchNormalization()(x)
25
+ x = layers.Dense(256, activation="relu")(x) # This matches your trained (1280, 256) weights
26
+ x = layers.BatchNormalization()(x)
27
+ x = layers.Dropout(0.5)(x)
28
+ x = layers.Dense(128, activation="relu")(x)
29
+ x = layers.BatchNormalization()(x)
30
+ x = layers.Dropout(0.4)(x)
31
+ # --------------------------
32
+
33
+ outputs = layers.Dense(2, activation="softmax")(x)
34
+
35
+ model = models.Model(inputs, outputs)
36
 
37
  print("Loading weights...")
38
+ # Now the shapes match perfectly
39
+ model.load_weights("model.keras")
40
  print("Model ready!")
41
 
42
  LABELS = ["preserved", "looted"]
 
49
  return arr
50
 
51
  def predict(image: Image.Image):
52
+ if image is None:
53
+ return None, "No image uploaded."
54
  x = preprocess(image)
55
  probs = model.predict(x)[0]
56
 
 
65
  inputs=gr.Image(type="pil", label="Upload photo"),
66
  outputs=[gr.Label(num_top_classes=2), gr.Textbox(label="Summary")],
67
  title="Preserved vs Looted Classifier",
68
+ description="Upload a photo to classify archeological sites."
69
  )
70
 
71
  if __name__ == "__main__":
72
+ iface.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))