Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import cv2 as cv | |
| from urllib.request import urlretrieve | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| # urlretrieve("https://github.com/AyaanZaveri/mnist/raw/main/mnist-model.h5", "mnist-model.h5") | |
| # model = tf.keras.models.load_model("mnist-model.h5") | |
| import os | |
| os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" | |
| os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" | |
| import cids | |
| import numpy as np | |
| import tensorflow as tf | |
| import seaborn as sns | |
| from pathlib import Path | |
| from tensorflow.keras import layers as klayers | |
| from cids.tensorflow import layers as clayers | |
| from cids.tensorflow.tuner import SearchResults | |
| from cids.statistics import metrics | |
| from kadi_ai import KadiAIProject | |
| from matplotlib import pyplot as plt | |
| from kerastuner import HyperParameters | |
################################################################################
# Controls
# Stage switches and sample counts for this script.
# In the code visible here only EVAL is read (it gates the inference run
# further down).
# NOTE(review): the remaining flags and the num_* values are not referenced in
# the visible code; they look like leftovers from a fuller training/search
# script — confirm before removing.
CHECK = False
SEARCH = False
USE_BEST_SEARCH_CONFIG = False
TRAIN = False
EVAL = True
PLOT = False
ANALYZE = False
TRAIN_CONTINUE = False
num_check_samples = 100
num_plot_samples = 20
num_principal_components = 3
################################################################################
# Project location: <current working directory>/DATA/<project_name>
project_name = "ex-mnist"
project_dir = Path.cwd().joinpath("DATA", project_name)
project = KadiAIProject(project_name, root=project_dir)

# Samples to run through the model.
# The shuffled train/valid/test split via project.get_split_datasets(...)
# (valid_split=0.15, test_split=0.15) is disabled; a single fixed tfrecord
# sample is listed instead.
test_samples = ["./DATA/ex-mnist/INPUTS/tfrecord/sample29579.tfrecord"]

################################################################################
# Feature selection: the model reads "image" and predicts "label".
data_definition = project.data_definition
data_definition.input_features = ["image"]
data_definition.output_features = ["label"]
| ################################################################################ | |
| # Neural network | |
| # Model | |
def model_function(hp, data_definition):
    """Assemble the convolutional digit classifier as a Keras Sequential model.

    Args:
        hp: Hyperparameter container. Registers/reads "num_kernels" (base
            number of convolution filters, default 64) and "dropout"
            (dropout rate applied after flattening, default 0.3).
        data_definition: Accepted for interface compatibility; not read by
            this function.

    Returns:
        A ``tf.keras.Sequential`` ending in a 10-way softmax layer.
    """
    # Ref: https://github.com/AyaanZaveri/mnist/blob/main/MNIST_Number.ipynb
    base_filters = hp.Choice("num_kernels", [32, 64, 128, 256, 512], default=64)
    drop_rate = hp.Float("dropout", 0.0, 0.7, default=0.3)

    def conv3x3(filters):
        # 3x3 convolution, stride 1, "same" padding.
        return klayers.Conv2D(filters, (3, 3), strides=(1, 1), padding="same")

    def conv_relu(filters):
        # A convolution followed by its ReLU activation.
        return [conv3x3(filters), klayers.ReLU()]

    def pool():
        return klayers.MaxPooling2D(pool_size=(2, 2))

    stack = [
        # Stage 1: two conv+ReLU pairs at the base width.
        *conv_relu(base_filters),
        *conv_relu(base_filters),
        pool(),
        klayers.BatchNormalization(),
        # Stage 2: two conv+ReLU pairs at twice the base width.
        *conv_relu(2 * base_filters),
        *conv_relu(2 * base_filters),
        pool(),
        klayers.BatchNormalization(),
        # Stage 3: a single convolution (no activation) at four times the width.
        conv3x3(4 * base_filters),
        pool(),
        # Classification head.
        klayers.Flatten(),
        klayers.Dropout(drop_rate),
        klayers.Dense(512),
        klayers.Dense(10, activation="softmax"),
    ]
    return tf.keras.Sequential(stack)
# Model name: "mnist--<inputs>--<outputs>--onehot", with multiple features
# joined by "+".
input_tag = "+".join(data_definition.input_features)
output_tag = "+".join(data_definition.output_features)
model_name = f"mnist--{input_tag}--{output_tag}--onehot"

# Directory name of the stored model whose hyperparameters are loaded below.
saved_model_name = "mnist--image--label--onehot-default-64C3-64C3-MP2-128C3-128C3-MP2-256C3-MP2-512-10"

# 10-class categorical classifier built from model_function.
model = cids.CIDSModel.categorical_classification(
    10,
    data_definition,
    model_function,
    name=model_name,
    identifier="default",  # or "best"
    result_dir=project.result_dir,
)

# Model configuration: encode the outputs categorically, track accuracy and
# monitor the validation accuracy, no online normalization, cached prefetch.
model.encode_categorical = "outputs"
model.metrics.append("accuracy")
model.monitor = "val_accuracy"
model.online_normalize = False
model.data_reader.prefetch = "cache"
# Load hp
import json

# The hyperparameter configuration is read from "hp.json" inside the saved
# model's directory and turned back into a HyperParameters object.
hp_path = os.path.join(model.base_model_dir, saved_model_name, "hp.json")
print("hp path", hp_path)
# Decode explicitly as UTF-8 instead of relying on the platform's default
# locale encoding, so the file is read the same way on every system.
with open(hp_path, "r", encoding="utf-8") as f:
    saved_config = json.load(f)
hp = HyperParameters.from_config(saved_config)
# Evaluation stage: run inference on the listed test samples using the
# hyperparameters loaded above and the "last" checkpoint.
if EVAL:
    project.log(">> Evaluating...")
    project.log(">>> Metrics...")
    # Compute predictions
    # Disabled alternative that evaluates a loss instead of running inference:
    # test_loss = model.eval_data(
    #     test_samples, batch_size=4, checkpoint="last", hp=hp, submodels="generator")
    print("hp", hp.values)
    # The three returned values are not used afterwards in the visible code.
    # NOTE(review): presumably this call is kept because it builds the network
    # and restores the checkpoint weights — confirm against the CIDS API.
    X, Y, Y_ = model.infer_data(
        test_samples,
        batch_size=4,
        checkpoint="last",
        hp=hp,
    )
# From here on `model` refers to the wrapped core model, which is what the
# prediction handler calls `.predict` on.
# NOTE(review): the original indentation was lost; this rebinding is placed at
# module level — confirm it was not meant to sit inside the `if EVAL:` branch.
model = model.core_model
def recognize_digit(image):
    """Classify a sketchpad drawing as a digit.

    Args:
        image: Value delivered by the sketchpad input. It is indexed as a
            mapping whose "composite" entry is a three-axis array
            (rows, columns, channels). ``None``, or a ``None`` composite,
            is treated as "nothing drawn yet".

    Returns:
        A dict mapping the labels "0".."9" to the model's output for that
        class, or ``None`` when there is nothing to classify.
    """
    # With live updates the handler can fire before anything has been drawn
    # (or after the canvas is cleared); skip inference in that case.
    if image is None or image["composite"] is None:
        return None

    # Only the last channel of the composite is used.
    # NOTE(review): presumably the alpha channel carrying the strokes of an
    # RGBA canvas — confirm against the sketchpad output format.
    digit = cv.resize(image["composite"][:, :, -1], (28, 28))

    # Rescale to [-1, 1]; assumes 8-bit pixel values in [0, 255].
    digit = (digit - 127.5) / 127.5
    # Add a leading batch axis of size one.
    digit = digit.reshape((1, 28, 28))

    # Run the network once and take the single row of class scores.
    scores = model.predict(digit).tolist()[0]
    return {str(i): scores[i] for i in range(10)}
# Web UI: a live sketchpad whose drawing is sent to recognize_digit, with the
# three highest-scoring classes shown as a label. The title is left unset
# (a "MNIST Sketchpad" title was considered) and the page footer is hidden.
demo = gr.Interface(
    fn=recognize_digit,
    inputs="sketchpad",
    outputs=gr.Label(num_top_classes=3),
    live=True,
    css=".footer {display:none !important}",
    description=(
        "A simple model trained on MNIST dataset using CIDS framework.\n"
        "Draw a single digit (0-9) in the center of the canvas."
    ),
)
# share=True requests a publicly reachable link in addition to the local one.
demo.launch(share=True)