fischjos commited on
Commit
9c1a58a
·
verified ·
1 Parent(s): fb48e9e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -27
app.py CHANGED
@@ -1,39 +1,40 @@
1
  import gradio as gr
2
- import numpy as np
3
  from PIL import Image
4
- from tensorflow.keras.models import load_model
5
- from tensorflow.keras.preprocessing.image import img_to_array
 
 
 
6
 
7
- # Assuming your model and class names are loaded correctly
8
- model = load_model('pokemon_classifier_model.keras')
9
- classes = ['Doduo', 'Geodude', 'Zubat']
10
 
 
11
  def classify_image(image):
12
  try:
13
- # Image preprocessing
14
- image = image.resize((150, 150))
15
- image_array = img_to_array(image)
16
- image_array = image_array.reshape((1, 150, 150, 3))
17
- image_array /= 255.0
18
-
19
- # Model prediction
20
  prediction = model.predict(image_array)
21
  predicted_class = classes[np.argmax(prediction)]
22
  confidence = np.max(prediction)
23
-
24
- return predicted_class, f"{confidence * 100:.2f}% Confidence"
25
  except Exception as e:
26
- # Catch and print any error that occurs
27
- print(f"Error during model prediction: {e}")
28
- return "Error in prediction", "Error"
29
 
30
- # Gradio app setup
31
- iface = gr.Interface(
32
- fn=classify_image,
33
- inputs=gr.Image(),
34
- outputs=[gr.Text(), gr.Text()],
35
- title="Pokémon Image Classifier",
36
- description="Upload an image of a Pokémon to classify!"
37
- )
38
 
39
- iface.launch()
 
1
  import gradio as gr
2
+ import tensorflow as tf
3
  from PIL import Image
4
+ import numpy as np
5
+
6
+ # Load the pre-trained Pokémon model
7
+ model_path = "pokemon_classifier_model.keras"
8
+ model = tf.keras.models.load_model(model_path)
9
 
10
+ # Define the Pokémon classes
11
+ classes = ['Doduo', 'Geodude', 'Zubat'] # Adjust classes based on what your model was trained on
 
12
 
13
+ # Define the image classification function
14
  def classify_image(image):
15
  try:
16
+ # Preprocess the image to match the model's input expectations
17
+ image = Image.fromarray(image.astype('uint8'), 'RGB') # Ensure image is in RGB
18
+ image = image.resize((150, 150)) # Resize to the input size your model expects
19
+ image_array = np.array(image) / 255.0 # Convert to array and normalize
20
+ image_array = np.expand_dims(image_array, axis=0) # Add batch dimension
21
+
22
+ # Predict using the model
23
  prediction = model.predict(image_array)
24
  predicted_class = classes[np.argmax(prediction)]
25
  confidence = np.max(prediction)
26
+
27
+ return f"Predicted Class: {predicted_class}, Confidence: {np.round(confidence * 100, 2)}%"
28
  except Exception as e:
29
+ return str(e) # Return the error message if failure
 
 
30
 
31
+ # Create Gradio interface
32
+ input_image = gr.inputs.Image(shape=(150, 150))
33
+ output_label = gr.outputs.Label(num_top_classes=3)
34
+ interface = gr.Interface(fn=classify_image,
35
+ inputs=input_image,
36
+ outputs=output_label,
37
+ examples=["path/to/example1.jpg", "path/to/example2.jpg"], # Update with real paths if needed
38
+ description="Upload an image of a Pokémon to classify!")
39
 
40
+ interface.launch()