fischjos commited on
Commit
9c488ae
·
verified ·
1 Parent(s): 5c2b803

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -44
app.py CHANGED
@@ -1,51 +1,32 @@
1
- import os
2
- import numpy as np
3
- import matplotlib.pyplot as plt
4
  import tensorflow as tf
5
- from tensorflow.keras.preprocessing.image import ImageDataGenerator
6
  from tensorflow.keras.models import load_model
7
 
8
- # Suppress TensorFlow logging and warnings
9
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Suppress TensorFlow logging (1 = INFO, 2 = WARNING)
10
- tf.get_logger().setLevel('ERROR')
11
- tf.autograph.set_verbosity(2)
12
-
13
- # Function to load the pre-trained model
14
- def load_pretrained_model(model_path):
15
- return load_model(model_path)
16
-
17
  # Load the pre-trained model
18
- model_path = 'pokemon_classifier_model.keras' # Adjust the path if your model is located elsewhere
19
- model = load_pretrained_model(model_path)
20
 
21
- # Define the Pokémon classes to be classified
22
- classes = ['Doduo', 'Geodude', 'Zubat']
 
 
 
 
 
 
 
 
 
23
 
24
- # Data Augmentation and Data Generators for the predictions
25
- datagen = ImageDataGenerator(rescale=1./255)
26
- generator = datagen.flow_from_directory(
27
- 'pokemon_dataset', # Adjust this path to the directory containing your data
28
- target_size=(150, 150),
29
- batch_size=1, # For prediction, typically we use batch size of 1
30
- classes=classes,
31
- class_mode='categorical',
32
- shuffle=False)
33
 
34
- # Predict and visualize results
35
- for images, labels in generator:
36
- # Use only the first 5 images for visualization
37
- for i in range(5):
38
- img = images[i]
39
- label = labels[i]
40
- plt.imshow(img)
41
- # Create prediction
42
- img_array = np.expand_dims(img, axis=0) # Model expects a batch
43
- predictions = model.predict(img_array)
44
- predicted_class_index = np.argmax(predictions, axis=1)[0]
45
- predicted_class_name = classes[predicted_class_index]
46
- # Determine actual class
47
- true_class_index = np.argmax(label)
48
- true_class_name = classes[true_class_index]
49
- plt.title(f'Predicted: {predicted_class_name}, True: {true_class_name}')
50
- plt.show()
51
- break # Use only the first batch of images
 
1
+ import gradio as gr
 
 
2
  import tensorflow as tf
3
+ from tensorflow.keras.preprocessing.image import load_img, img_to_array
4
  from tensorflow.keras.models import load_model
5
 
 
 
 
 
 
 
 
 
 
6
  # Load the pre-trained model
7
+ model = load_model('pokemon_classifier_model.keras')
8
+ classes = ['Doduo', 'Geodude', 'Zubat'] # List of classes
9
 
10
+ def classify_image(image):
11
+ """Function to classify the image using the pre-trained model."""
12
+ image = image.resize((150, 150)) # Resize image to match model's expected input
13
+ image_array = img_to_array(image)
14
+ image_array = image_array.reshape((1, 150, 150, 3)) # Reshape for model
15
+ image_array /= 255.0 # Normalize the image
16
+
17
+ prediction = model.predict(image_array)
18
+ predicted_class = classes[np.argmax(prediction)]
19
+ confidence = np.max(prediction)
20
+ return predicted_class, f"{confidence * 100:.2f}% Confidence"
21
 
22
+ # Create a Gradio interface
23
+ iface = gr.Interface(
24
+ classify_image,
25
+ inputs=gr.inputs.Image(shape=(150, 150)),
26
+ outputs=[gr.outputs.Label(num_top_classes=3), gr.outputs.Text()],
27
+ title="Pokémon Image Classifier",
28
+ description="Upload an image of a Pokémon to classify!"
29
+ )
 
30
 
31
+ # Launch the Gradio app
32
+ iface.launch()