isana25 commited on
Commit
fbb8092
·
verified ·
1 Parent(s): c0df4cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -17
app.py CHANGED
@@ -1,30 +1,29 @@
1
  import gradio as gr
2
  import tensorflow as tf
3
  import numpy as np
4
- import json
5
  from PIL import Image
6
 
7
- # Load model and class names from local repo files
8
  model = tf.keras.models.load_model("animal_classifier.keras")
9
 
10
- with open("class_names.json", "r") as f:
11
- class_names = json.load(f)
12
 
13
- def predict(image: Image.Image):
14
- image = image.resize((224, 224))
15
- img_array = np.array(image)
16
  img_array = np.expand_dims(img_array, axis=0)
17
- img_array = img_array / 255.0 # normalize if used in training
18
 
19
  preds = model.predict(img_array)
20
- pred_class = np.argmax(preds, axis=1)[0]
21
- return class_names[pred_class]
22
 
23
- demo = gr.Interface(fn=predict,
24
- inputs=gr.Image(type="pil"),
25
- outputs="text",
26
- title="MobileNetV2 Animal Classifier",
27
- description="Upload an image to classify the animal.")
28
 
29
- if __name__ == "__main__":
30
- demo.launch()
 
 
1
  import gradio as gr
2
  import tensorflow as tf
3
  import numpy as np
 
4
  from PIL import Image
5
 
6
+ # Load your model from Hugging Face repo (or local file if testing)
7
  model = tf.keras.models.load_model("animal_classifier.keras")
8
 
9
+ # Dynamically get class names from model output layer
10
+ class_names = list(model.class_names) if hasattr(model, 'class_names') else ["dog", "lion", "tiger"] # fallback if not stored in model
11
 
12
+ def predict_image(image):
13
+ img = image.resize((224, 224))
14
+ img_array = np.array(img) / 255.0
15
  img_array = np.expand_dims(img_array, axis=0)
 
16
 
17
  preds = model.predict(img_array)
18
+ confidence = np.max(preds)
19
+ predicted_index = np.argmax(preds)
20
 
21
+ threshold = 0.5
22
+ if confidence < threshold:
23
+ return "Image not recognized as any animal in the dataset"
24
+ else:
25
+ return class_names[predicted_index]
26
 
27
+ demo = gr.Interface(fn=predict_image, inputs=gr.Image(type="pil"), outputs="text", title="MobileNetV2 Animal Classifier")
28
+
29
+ demo.launch()