isana25 commited on
Commit
89a2c48
·
verified ·
1 Parent(s): 2f588ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -17
app.py CHANGED
@@ -1,28 +1,30 @@
1
  import gradio as gr
2
  import tensorflow as tf
3
  import numpy as np
4
- from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
5
  from PIL import Image
6
- import os
7
 
8
- # Load the trained model (make sure the file is in the same directory)
9
  model = tf.keras.models.load_model("animal_classifier.keras")
10
 
11
- # Define class names (same order as training)
12
- train_dir = "split_animals/train"
13
- class_names = sorted(os.listdir(train_dir))
14
 
15
- # Prediction function
16
- def predict_image(img):
17
- img = img.resize((224, 224))
18
- img_array = tf.keras.utils.img_to_array(img)
19
  img_array = np.expand_dims(img_array, axis=0)
20
- img_array = preprocess_input(img_array)
21
 
22
- prediction = model.predict(img_array)
23
- predicted_class = class_names[np.argmax(prediction)]
24
- return f"Predicted Animal: {predicted_class}"
25
 
26
- # Launch Gradio app
27
- app = gr.Interface(fn=predict_image, inputs=gr.Image(type="pil"), outputs="text", title="🧠🐾 TF-WildNet: Animal Classifier")
28
- app.launch()
 
 
 
 
 
 
1
  import gradio as gr
2
  import tensorflow as tf
3
  import numpy as np
4
+ import json
5
  from PIL import Image
 
6
 
7
+ # Load model and class names from local repo files
8
  model = tf.keras.models.load_model("animal_classifier.keras")
9
 
10
+ with open("class_names.json", "r") as f:
11
+ class_names = json.load(f)
 
12
 
13
+ def predict(image: Image.Image):
14
+ image = image.resize((224, 224))
15
+ img_array = np.array(image)
 
16
  img_array = np.expand_dims(img_array, axis=0)
17
+ img_array = img_array / 255.0 # normalize if used in training
18
 
19
+ preds = model.predict(img_array)
20
+ pred_class = np.argmax(preds, axis=1)[0]
21
+ return class_names[pred_class]
22
 
23
+ demo = gr.Interface(fn=predict,
24
+ inputs=gr.Image(type="pil"),
25
+ outputs="text",
26
+ title="MobileNetV2 Animal Classifier",
27
+ description="Upload an image to classify the animal.")
28
+
29
+ if __name__ == "__main__":
30
+ demo.launch()