faurielle commited on
Commit
9381dd3
·
verified ·
1 Parent(s): 7d6afe7

Delete visualize.py

Browse files
Files changed (1) hide show
  1. visualize.py +0 -40
visualize.py DELETED
@@ -1,40 +0,0 @@
1
- import tensorflow as tf
2
- import numpy as np
3
- import matplotlib.pyplot as plt
4
- from tensorflow.keras.preprocessing.image import ImageDataGenerator
5
- from tensorflow.keras.applications import ResNet50
6
- from tensorflow.keras.models import Model
7
- from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
8
- from tensorflow.keras.optimizers import Adam
9
-
10
- # Laden der Validierungsdaten
11
- validation_datagen = ImageDataGenerator(rescale=1./255)
12
- validation_generator = validation_datagen.flow_from_directory(
13
- r'C:\Coding\BlockImageClassification\validation',
14
- target_size=(224, 224),
15
- batch_size=8,
16
- class_mode='categorical',
17
- shuffle=True # Mischt die Daten vor jeder Epoche
18
- )
19
-
20
- # Laden des Modells
21
- base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
22
- x = base_model.output
23
- x = GlobalAveragePooling2D()(x)
24
- x = Dense(1024, activation='relu')(x)
25
- predictions = Dense(3, activation='softmax')(x)
26
- model = Model(inputs=base_model.input, outputs=predictions)
27
-
28
- # Vorhersagen auf den Validierungsdaten
29
- predictions_in_percentage = model.predict(validation_generator)
30
- predictions = np.argmax(predictions_in_percentage, axis=-1)
31
-
32
- # Darstellen der Vorhersagen
33
- class_names = ['Bulbasaur', 'Charmander', 'Squirtle'] # Aktualisierte Klassennamen
34
- for i in range(len(predictions)):
35
- image, label = validation_generator[i]
36
- plt.imshow(image[0])
37
- plt.title('pred. ' + class_names[predictions[i]] + ' war ' + class_names[np.argmax(label)] + ' ' + str(np.round(predictions_in_percentage[i], 2)), fontsize=8)
38
- plt.axis("off")
39
- plt.show()
40
-