faurielle commited on
Commit
53ab968
·
verified ·
1 Parent(s): 9381dd3

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -0
app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ from tensorflow.keras.preprocessing.image import ImageDataGenerator
5
+ from tensorflow.keras.applications import ResNet50
6
+ from tensorflow.keras.models import Model
7
+ from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
8
+ from tensorflow.keras.optimizers import Adam
9
+
10
+ # Laden der Validierungsdaten
11
+ validation_datagen = ImageDataGenerator(rescale=1./255)
12
+ validation_generator = validation_datagen.flow_from_directory(
13
+ r'C:\Coding\BlockImageClassification\validation',
14
+ target_size=(224, 224),
15
+ batch_size=8,
16
+ class_mode='categorical',
17
+ shuffle=True # Mischt die Daten vor jeder Epoche
18
+ )
19
+
20
+ # Laden des Modells
21
+ base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
22
+ x = base_model.output
23
+ x = GlobalAveragePooling2D()(x)
24
+ x = Dense(1024, activation='relu')(x)
25
+ predictions = Dense(3, activation='softmax')(x)
26
+ model = Model(inputs=base_model.input, outputs=predictions)
27
+
28
+ # Vorhersagen auf den Validierungsdaten
29
+ predictions_in_percentage = model.predict(validation_generator)
30
+ predictions = np.argmax(predictions_in_percentage, axis=-1)
31
+
32
+ # Darstellen der Vorhersagen
33
+ class_names = ['Bulbasaur', 'Charmander', 'Squirtle'] # Aktualisierte Klassennamen
34
+ for i in range(len(predictions)):
35
+ image, label = validation_generator[i]
36
+ plt.imshow(image[0])
37
+ plt.title('pred. ' + class_names[predictions[i]] + ' war ' + class_names[np.argmax(label)] + ' ' + str(np.round(predictions_in_percentage[i], 2)), fontsize=8)
38
+ plt.axis("off")
39
+ plt.show()
40
+