# mnist-cids / app.py
# Kadi-IAM's picture
# Test
# 9aaef6e
import numpy as np
import cv2 as cv
from urllib.request import urlretrieve
import gradio as gr
import matplotlib.pyplot as plt
# urlretrieve("https://github.com/AyaanZaveri/mnist/raw/main/mnist-model.h5", "mnist-model.h5")
# model = tf.keras.models.load_model("mnist-model.h5")
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
import cids
import numpy as np
import tensorflow as tf
import seaborn as sns
from pathlib import Path
from tensorflow.keras import layers as klayers
from cids.tensorflow import layers as clayers
from cids.tensorflow.tuner import SearchResults
from cids.statistics import metrics
from kadi_ai import KadiAIProject
from matplotlib import pyplot as plt
from kerastuner import HyperParameters
################################################################################
# Controls
# Stage switches for the script. In the part of the script shown here only
# EVAL is read (it gates the inference pass further down); the remaining flags
# and the three sample/component counts are defined but not referenced.
CHECK = False
SEARCH = False
USE_BEST_SEARCH_CONFIG = False
TRAIN = False
EVAL = True
PLOT = False
ANALYZE = False
TRAIN_CONTINUE = False
num_check_samples = 100
num_plot_samples = 20
num_principal_components = 3
################################################################################
# Data paths
# The project root is resolved relative to the current working directory:
# <cwd>/DATA/ex-mnist
project_name = "ex-mnist"
project_dir = Path.cwd() / "DATA" / project_name
project = KadiAIProject(project_name, root=project_dir)
# Read paths
# The train/valid/test split is disabled; a single hard-coded TFRecord sample
# is used as the test set instead.
# train_samples, valid_samples, test_samples = project.get_split_datasets(
# shuffle=True, valid_split=0.15, test_split=0.15
# )
test_samples = ["./DATA/ex-mnist/INPUTS/tfrecord/sample29579.tfrecord"]
################################################################################
# Data definition
# Reuse the project's data definition and declare which features are fed to
# the model ("image") and which are predicted ("label").
data_definition = project.data_definition
data_definition.input_features = ["image"]
data_definition.output_features = ["label"]
################################################################################
# Neural network
# Model
def model_function(hp, data_definition):
    """Assemble the convolutional digit classifier as a Keras Sequential model.

    Layer stack, in order:
      * two 3x3 convolutions with ``num_kernels`` filters, each followed by
        ReLU, then 2x2 max-pooling and batch normalization;
      * the same pattern with ``2 * num_kernels`` filters;
      * one 3x3 convolution with ``4 * num_kernels`` filters and 2x2
        max-pooling (no activation, no batch normalization);
      * flatten, dropout, a 512-unit dense layer and a 10-way softmax.

    Args:
        hp: Hyperparameter container exposing ``Choice`` and ``Float``. It
            registers ``num_kernels`` (default 64) and ``dropout``
            (default 0.3).
        data_definition: Part of the builder signature; not read here.

    Returns:
        A ``tf.keras.Sequential`` wrapping the layers listed above.
    """
    # Architecture reference:
    # https://github.com/AyaanZaveri/mnist/blob/main/MNIST_Number.ipynb
    base_width = hp.Choice("num_kernels", [32, 64, 128, 256, 512], default=64)
    # The dropout rate is applied once, right before the dense head.
    drop_prob = hp.Float("dropout", 0.0, 0.7, default=0.3)

    def conv3x3(filters):
        # Every convolution in the network shares kernel size, stride and
        # padding; only the filter count varies.
        return klayers.Conv2D(filters, (3, 3), strides=(1, 1), padding="same")

    # Layers are instantiated in exactly the order they appear in the list.
    return tf.keras.Sequential(
        [
            # Stage 1
            conv3x3(base_width),
            klayers.ReLU(),
            conv3x3(base_width),
            klayers.ReLU(),
            klayers.MaxPooling2D(pool_size=(2, 2)),
            klayers.BatchNormalization(),
            # Stage 2
            conv3x3(base_width * 2),
            klayers.ReLU(),
            conv3x3(base_width * 2),
            klayers.ReLU(),
            klayers.MaxPooling2D(pool_size=(2, 2)),
            klayers.BatchNormalization(),
            # Stage 3
            conv3x3(base_width * 4),
            klayers.MaxPooling2D(pool_size=(2, 2)),
            # Classification head
            klayers.Flatten(),
            klayers.Dropout(drop_prob),
            klayers.Dense(512),
            klayers.Dense(10, activation="softmax"),
        ]
    )
# Model name: mnist--<input features>--<output features>--onehot, where the
# features of each group are joined with "+".
feature_tags = [
    "+".join(data_definition.input_features),
    "+".join(data_definition.output_features),
]
model_name = "--".join(["mnist", *feature_tags]) + "--onehot"

# Directory name of the stored run whose hyperparameters are loaded below.
saved_model_name = "mnist--image--label--onehot-default-64C3-64C3-MP2-128C3-128C3-MP2-256C3-MP2-512-10"

# 10-class categorical classifier built around model_function.
model = cids.CIDSModel.categorical_classification(
    10,
    data_definition,
    model_function,
    name=model_name,
    identifier="default",  # or "best"
    result_dir=project.result_dir,
)
model.encode_categorical = "outputs"
model.metrics.append("accuracy")
model.monitor = "val_accuracy"
model.online_normalize = False
model.data_reader.prefetch = "cache"

# Restore the hyperparameter values stored alongside the saved run.
import json

hp_path = os.path.join(model.base_model_dir, saved_model_name, "hp.json")
print("hp path", hp_path)
with open(hp_path, "r") as f:
    saved_config = json.load(f)
hp = HyperParameters.from_config(saved_config)

if EVAL:
    project.log(">> Evaluating...")
    project.log(">>> Metrics...")
    # Run inference on the test samples using the "last" checkpoint.
    # (A loss evaluation via model.eval_data(test_samples, batch_size=4,
    # checkpoint="last", hp=hp, submodels="generator") is disabled.)
    print("hp", hp.values)
    X, Y, Y_ = model.infer_data(
        test_samples,
        batch_size=4,
        checkpoint="last",
        hp=hp,
    )

# From here on `model` names the wrapper's core model, which is what
# recognize_digit calls .predict on.
# NOTE(review): presumably core_model is only populated by the infer_data
# call above, so this relies on EVAL being True -- confirm.
model = model.core_model
def recognize_digit(image):
    """Classify the digit currently drawn on the sketchpad.

    Args:
        image: Sketchpad payload; a mapping whose ``"composite"`` entry is
            indexed as a 3-D array (rows, columns, channels). ``None``, or a
            payload whose composite is ``None``, is treated as an empty
            canvas.

    Returns:
        A dict mapping the labels ``"0"`` .. ``"9"`` to the model's output
        for that class, or ``None`` when there is nothing to classify.
    """
    # A live interface may invoke the callback before anything is drawn or
    # after the canvas is cleared; skip inference in that case.
    composite = None if image is None else image["composite"]
    if composite is None:
        return None

    # Only the last channel of the composite is used (presumably the alpha
    # channel carrying the strokes -- confirm against the sketchpad output).
    # To debug channel selection, dump each channel of "composite" and
    # "background" to an image file and inspect them.
    strokes = cv.resize(composite[:, :, -1], (28, 28))

    # Rescale to [-1, 1], assuming 8-bit pixel values in 0..255.
    strokes = (strokes - 127.5) / 127.5
    batch = strokes.reshape((1, 28, 28))

    # Run the network once per request; take the scores of the single
    # sample in the batch.
    scores = model.predict(batch).tolist()[0]
    return {str(i): scores[i] for i in range(10)}
# Live sketchpad demo: every change on the canvas is sent to recognize_digit
# and the three highest-scoring classes are displayed.
demo = gr.Interface(
    fn=recognize_digit,
    inputs="sketchpad",
    outputs=gr.Label(num_top_classes=3),
    live=True,
    css=".footer {display:none !important}",
    # title="MNIST Sketchpad",
    description=(
        "A simple model trained on MNIST dataset using CIDS framework.\n"
        "Draw a single digit (0-9) in the center of the canvas."
    ),
)
# share=True additionally requests a public share link.
demo.launch(share=True)