fischjos commited on
Commit
fb48e9e
·
verified ·
1 Parent(s): 9ee833f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -19
app.py CHANGED
@@ -1,34 +1,39 @@
1
  import gradio as gr
2
- import tensorflow as tf
3
- from PIL import Image
4
  import numpy as np
5
- from tensorflow.keras.preprocessing.image import img_to_array
6
  from tensorflow.keras.models import load_model
 
7
 
8
- # Load the pre-trained model
9
  model = load_model('pokemon_classifier_model.keras')
10
- classes = ['Doduo', 'Geodude', 'Zubat'] # List of classes
11
 
12
  def classify_image(image):
13
- """Function to classify the image using the pre-trained model."""
14
- image = image.resize((150, 150)) # Resize image to match model's expected input
15
- image_array = img_to_array(image)
16
- image_array = image_array.reshape((1, 150, 150, 3)) # Reshape for model
17
- image_array /= 255.0 # Normalize the image
18
-
19
- prediction = model.predict(image_array)
20
- predicted_class = classes[np.argmax(prediction)]
21
- confidence = np.max(prediction)
22
- return predicted_class, f"{confidence * 100:.2f}% Confidence"
 
 
 
 
 
 
 
23
 
24
- # Create a Gradio interface
25
  iface = gr.Interface(
26
  fn=classify_image,
27
  inputs=gr.Image(),
28
- outputs=[gr.Label(num_top_classes=3), gr.Text()],
29
  title="Pokémon Image Classifier",
30
  description="Upload an image of a Pokémon to classify!"
31
  )
32
 
33
- # Launch the Gradio app
34
- iface.launch()
 
1
  import gradio as gr
 
 
2
  import numpy as np
3
+ from PIL import Image
4
  from tensorflow.keras.models import load_model
5
+ from tensorflow.keras.preprocessing.image import img_to_array
6
 
7
+ # Assuming your model and class names are loaded correctly
8
  model = load_model('pokemon_classifier_model.keras')
9
+ classes = ['Doduo', 'Geodude', 'Zubat']
10
 
11
  def classify_image(image):
12
+ try:
13
+ # Image preprocessing
14
+ image = image.resize((150, 150))
15
+ image_array = img_to_array(image)
16
+ image_array = image_array.reshape((1, 150, 150, 3))
17
+ image_array /= 255.0
18
+
19
+ # Model prediction
20
+ prediction = model.predict(image_array)
21
+ predicted_class = classes[np.argmax(prediction)]
22
+ confidence = np.max(prediction)
23
+
24
+ return predicted_class, f"{confidence * 100:.2f}% Confidence"
25
+ except Exception as e:
26
+ # Catch and print any error that occurs
27
+ print(f"Error during model prediction: {e}")
28
+ return "Error in prediction", "Error"
29
 
30
+ # Gradio app setup
31
  iface = gr.Interface(
32
  fn=classify_image,
33
  inputs=gr.Image(),
34
+ outputs=[gr.Text(), gr.Text()],
35
  title="Pokémon Image Classifier",
36
  description="Upload an image of a Pokémon to classify!"
37
  )
38
 
39
+ iface.launch()