ZbindChi commited on
Commit
5369008
·
verified ·
1 Parent(s): 2852ca0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -29
app.py CHANGED
@@ -3,42 +3,40 @@ import tensorflow as tf
3
  import numpy as np
4
  from PIL import Image
5
 
6
- #!pip install tensorflow tensorflow-datasets gradio pillow matplotlib
7
-
8
  model_path = "pokemon-model_transferlearning.keras"
9
  model = tf.keras.models.load_model(model_path)
10
 
11
- from PIL import Image
12
- import numpy as np
13
- import tensorflow as tf
14
 
15
- # Define the core prediction function
16
- def predict_pokemon(image):
17
- # Preprocess image
18
- image = Image.fromarray(image.astype('uint8')) # Convert numpy array to PIL image
19
- image = image.resize((150, 150)) # Resize the image to 150x150
20
- image = np.array(image)
21
- image = np.expand_dims(image, axis=0) # Add batch dimension
22
 
23
- # Predict
24
- prediction = model.predict(image)
 
 
25
 
26
- # Apply softmax to get probabilities for each class
27
- probabilities = tf.nn.softmax(prediction)
 
 
28
 
29
- # Map probabilities to Pokemon classes
30
- pokemon_classes = ['Articuno', 'Bulbasaur', 'Charmander']
31
- probabilities_dict = {pokemon_class: round(float(probability), 2) for pokemon_class, probability in zip(pokemon_classes, probabilities[0])}
32
 
33
- return probabilities_dict
34
 
35
- # Create the Gradio interface
36
- input_image = gr.Image()
37
  iface = gr.Interface(
38
- fn=predict_pokemon,
39
- inputs=input_image,
40
- outputs=gr.Label(),
41
- live=True,
42
- examples=["images/01.jpg", "images/02.png", "images/03.png", "images/04.jpg", "images/06.png", "images/06.png"],
43
- description="A simple mlp classification model for image classification using the mnist dataset.")
44
- iface.launch()
 
 
3
  import numpy as np
4
  from PIL import Image
5
 
6
+ # Pfad zum gespeicherten Modell
 
7
  model_path = "pokemon-model_transferlearning.keras"
8
  model = tf.keras.models.load_model(model_path)
9
 
10
+ # Definieren der Klassennamen
11
+ labels = ['Articuno', 'Bulbasaur', 'Charmander']
 
12
 
13
+ # Funktion zur Klassifizierung
14
+ def classify_pokemon(image):
15
+ if image is None:
16
+ return {"Error": "No image uploaded"}
 
 
 
17
 
18
+ # Bildvorverarbeitung
19
+ image = Image.fromarray(image).resize((150, 150))
20
+ image = np.array(image) / 255.0
21
+ image = np.expand_dims(image, axis=0)
22
 
23
+ # Vorhersage
24
+ prediction = model.predict(image)
25
+ predicted_class = np.argmax(prediction[0])
26
+ confidence = np.max(prediction[0])
27
 
28
+ # Konfidenzwerte
29
+ confidences = {labels[i]: np.round(float(prediction[0][i]), 2) for i in range(len(labels))}
 
30
 
31
+ return confidences, f"Predicted: {labels[predicted_class]}, Confidence: {confidence:.2f}"
32
 
33
+ # Erstellen einer Gradio-Schnittstelle
 
34
  iface = gr.Interface(
35
+ fn=classify_pokemon,
36
+ inputs=gr.Image(),
37
+ outputs=["label", "text"],
38
+ live=True
39
+ )
40
+
41
+ # Starten der Schnittstelle
42
+ iface.launch()