|
|
|
|
|
import tensorflow as tf |
|
|
import matplotlib.pyplot as plt |
|
|
import os |
|
|
import pathlib |
|
|
import numpy as np |
|
|
data_dir = pathlib.Path("/Users/rosh/Downloads/Train_data")

# Class labels are the sorted names of the sub-directories of the training
# folder. Only non-hidden directories are kept, so stray files (e.g. a macOS
# ``.DS_Store``) or hidden folders cannot be mistaken for a class and shift
# every label index.
# NOTE(review): assumes the model was trained with classes in this same
# sorted-sub-directory order (as Keras directory datasets infer them) — confirm.
class_names = sorted(
    item.name
    for item in data_dir.glob("*")
    if item.is_dir() and not item.name.startswith(".")
)

loaded_model = tf.keras.models.load_model('model_4_improved_8.h5')
|
|
def load_and_prep_image(filename, img_shape=224):
    """Read an image file and turn it into a model-ready tensor.

    Args:
        filename: Path to the image file. A ``str``, an ``os.PathLike``
            (e.g. ``pathlib.Path``) or a scalar string tensor.
        img_shape: Target height and width in pixels.

    Returns:
        A float32 tensor of shape ``(img_shape, img_shape, 3)`` with values
        rescaled from ``[0, 255]`` to ``[0, 1]``.
    """
    # tf.io.read_file needs a string, so convert path objects first.
    if isinstance(filename, os.PathLike):
        filename = os.fspath(filename)
    img = tf.io.read_file(filename)

    # channels=3 forces an RGB result (alpha dropped, grayscale expanded).
    # expand_animations=False makes animated GIFs decode to their first frame
    # as a 3-D tensor instead of a 4-D frame stack, so the rank is always 3.
    img = tf.image.decode_image(img, channels=3, expand_animations=False)

    img = tf.image.resize(img, size=[img_shape, img_shape])

    # NOTE(review): assumes the model was trained on inputs scaled to [0, 1]
    # (i.e. it has no built-in rescaling layer) — confirm.
    img = img / 255.
    return img
|
|
|
|
|
|
|
|
def pred_and_plot(model, filename, class_names):
    """Predict the class of an image and plot it with the prediction as title.

    Args:
        model: A trained Keras model that accepts a batch of images shaped
            like the output of ``load_and_prep_image``.
        filename: Path to the image file.
        class_names: Sequence of class labels indexed by the model's output.

    Returns:
        The predicted class label.
    """
    img = load_and_prep_image(filename)

    # The model expects a batch dimension, so predict on a batch of one.
    pred = model.predict(tf.expand_dims(img, axis=0))
    probs = np.ravel(pred[0])

    if probs.size > 1:
        # Multi-class output: pick the most probable class.
        class_idx = int(probs.argmax())
    else:
        # Single-unit (binary) output: argmax would always be 0, so threshold
        # the score at 0.5 instead.
        class_idx = int(probs[0] > 0.5)
    pred_class = class_names[class_idx]

    plt.imshow(img)
    plt.title(f"Prediction: {pred_class}")
    plt.axis(False)
    plt.show()

    return pred_class
|
|
|
|
|
# Run the demo prediction only when executed as a script, not on import.
if __name__ == "__main__":
    pred_and_plot(loaded_model, "/Users/rosh/Downloads/egret.jpg", class_names)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|