BlockImageClassification / visualize.py
faurielle's picture
Upload 48 files
4384a86 verified
raw
history blame
1.61 kB
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
# Load the validation data.
# shuffle=False is required here: model.predict() below consumes the generator
# once, and the visualisation loop indexes it again afterwards. With shuffling,
# the order of the two passes is not guaranteed to agree, so the i-th prediction
# could be paired with a different image/label. Without shuffling, batch b holds
# samples b*batch_size ... b*batch_size + batch_size - 1 in both passes.
# NOTE(review): ResNet50's ImageNet weights normally expect
# tf.keras.applications.resnet50.preprocess_input rather than 1/255 rescaling;
# keep whatever preprocessing the training script used — confirm.
validation_datagen = ImageDataGenerator(rescale=1./255)
validation_generator = validation_datagen.flow_from_directory(
    r'C:\Coding\BlockImageClassification\validation',
    target_size=(224, 224),
    batch_size=8,
    class_mode='categorical',
    shuffle=False,  # keep file order fixed so predictions line up with images
)

# Class names, indexed by the generator's label index. flow_from_directory
# assigns label indices to the sub-directories in alphanumeric order, so this
# list must follow that order.
class_names = ['Bulbasaur', 'Charmander', 'Squirtle']
if validation_generator.num_classes != len(class_names):
    raise ValueError(
        f"Validation directory has {validation_generator.num_classes} classes, "
        f"but {len(class_names)} class names are defined: {class_names}"
    )

# Build the model: ResNet50 backbone with ImageNet weights plus a small
# classification head with one softmax output per class.
# NOTE(review): the two Dense layers are freshly (randomly) initialised here, so
# the predictions are meaningless until trained weights are loaded. Set
# WEIGHTS_PATH to the weights file saved by the training script — TODO confirm
# the path; None keeps the untrained model.
WEIGHTS_PATH = None
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
outputs = Dense(len(class_names), activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=outputs)
if WEIGHTS_PATH is not None:
    model.load_weights(WEIGHTS_PATH)

# Predict on the whole validation set: one row of class probabilities per image
# (in generator order), reduced to the index of the most likely class.
predictions_in_percentage = model.predict(validation_generator)
predictions = np.argmax(predictions_in_percentage, axis=-1)

# Show every validation image with its predicted and true class.
# validation_generator[b] returns the whole batch b, not image b, so walk each
# batch and map (batch, offset) back to the flat prediction index.
batch_size = validation_generator.batch_size
for batch_index in range(len(validation_generator)):
    images, labels = validation_generator[batch_index]
    for offset, (image, label) in enumerate(zip(images, labels)):
        i = batch_index * batch_size + offset
        plt.imshow(image)
        plt.title(
            f"pred. {class_names[predictions[i]]} war {class_names[np.argmax(label)]} "
            f"{np.round(predictions_in_percentage[i], 2)}",
            fontsize=8,
        )
        plt.axis("off")
        plt.show()