chiraant commited on
Commit
90b23c7
·
verified ·
1 Parent(s): 8102324

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -39
app.py CHANGED
@@ -11,42 +11,29 @@ model = tf.keras.models.load_model(model_path)
11
 
12
  labels = ['Abra', 'Ditto', 'Gengar']
13
 
14
- def predict_pokemons(image):
15
- # Preprocess the image
16
- print(type(image))
17
- image = Image.fromarray(image.astype('uint8')) # Convert numpy array to PIL image
18
- image = image.resize((150, 150)) # Resize the image as per the model's input requirement
19
- image = np.array(image)
20
- image = np.expand_dims(image, axis=0) # Add batch dimension
21
-
22
- # Predict
23
- predictions = model.predict(image)
24
-
25
- # Convert the logits to probabilities using softmax
26
- probabilities = tf.nn.softmax(predictions[0]).numpy()
27
-
28
- # Create a dictionary to hold the probabilities for each class
29
- results = {class_names[i]: float(np.round(probabilities[i], 2)) for i in range(len(class_names))}
30
- return results
31
-
32
- # Define regression function
33
- def predict_regression(image):
34
- # Preprocess image
35
- image = Image.fromarray(image.astype('uint8')) # Convert numpy array to PIL image
36
- image = image.resize((150, 150))
37
- image = np.array(image)
38
- print(image.shape)
39
- # Predict
40
- prediction = model.predict(image[None, ...]) # Assuming single regression value
41
- print(prediction)
42
- confidences = str(prediction)
43
- return confidences
44
-
45
- # Create Gradio interface
46
- input_image = gr.Image()
47
- output_text = gr.Textbox(label="Predicted Value")
48
- interface = gr.Interface(fn=predict_pokemons,
49
- inputs=input_image,
50
- outputs=gr.Label(),
51
- description="A simple pokemon classification model based on Xception and Pokemon Images (https://www.kaggle.com/datasets/mikoajkolman/pokemon-images-first-generation17000-files).")
52
- interface.launch()
 
11
 
12
  labels = ['Abra', 'Ditto', 'Gengar']
13
 
14
+ def predict_pokemon_type(uploaded_file):
15
+ if uploaded_file is None:
16
+ return "No file uploaded.", None, "No prediction"
17
+
18
+ # Load the image from the file path
19
+ with Image.open(uploaded_file) as img:
20
+ img = img.resize((150, 150))
21
+ img_array = np.array(img)
22
+
23
+ prediction = model.predict(np.expand_dims(img_array, axis=0))
24
+
25
+ confidences = {labels[i]: np.round(float(prediction[0][i]), 2) for i in range(len(labels))}
26
+
27
+ return img, confidences
28
+
29
+ # Define the Gradio interface
30
+ iface = gr.Interface(
31
+ fn=predict_pokemon_type,
32
+ inputs=gr.File(label="Upload File"),
33
+ outputs=["image", "text"],
34
+ title="Pokemon Classifier",
35
+ description="Upload a picture of a Pokemon (preferably Cubone, Ditto, Psyduck, Snorlax, or Weedle) to see its type and confidence level. The trained model has a test accuracy of 99.17%!"
36
+ )
37
+
38
+ # Launch the interface
39
+ iface.launch()