Spaces:
Build error
Build error
| import gradio as gr | |
| import tensorflow as tf | |
| from tensorflow import keras | |
| from custom_model import ImageClassifier | |
| from resnet_model import ResNetClassifier | |
| from vgg16_model import VGG16Classifier | |
| from inception_v3_model import InceptionV3Classifier | |
| from mobilevet_v2 import MobileNetClassifier | |
| import os | |
# Human-readable labels for the custom CNN; make_prediction indexes this list
# with the class ids returned by the custom model.
CLASS_NAMES = [
    "Airplane",
    "Automobile",
    "Bird",
    "Cat",
    "Deer",
    "Dog",
    "Frog",
    "Horse",
    "Ship",
    "Truck",
]

# Models are instantiated once at import time and shared by every request.
# The custom CNN additionally loads its state from "image_classifier_model.h5".
custom_model = ImageClassifier()
custom_model.load_model("image_classifier_model.h5")

resnet_model = ResNetClassifier()
vgg16_model = VGG16Classifier()
inceptionV3_model = InceptionV3Classifier()
mobilenet_model = MobileNetClassifier()
def make_prediction(image, model_type="CNN (Custom)"):
    """Classify ``image`` with the model selected by ``model_type``.

    Args:
        image: Image passed unchanged to the selected model's
            ``classify_image`` method.
        model_type: One of "CNN (Custom)", "ResNet50", "VGG16",
            "Inception v3" or "Mobile Net v2".

    Returns:
        A dict mapping class name to confidence (as ``float``), or a plain
        message string when ``model_type`` is not one of the known models.
    """
    if model_type == "CNN (Custom)":
        # The custom model reports class ids, which are translated to names
        # through CLASS_NAMES.
        top_classes, top_probs = custom_model.classify_image(image, top_k=3)
        return {
            CLASS_NAMES[cls_id]: float(prob)
            for cls_id, prob in zip(top_classes, top_probs)
        }

    if model_type == "ResNet50":
        pretrained = resnet_model
    elif model_type == "VGG16":
        pretrained = vgg16_model
    elif model_type == "Inception v3":
        pretrained = inceptionV3_model
    elif model_type == "Mobile Net v2":
        pretrained = mobilenet_model
    else:
        # Return a plain string: the label output accepts a str or a dict of
        # confidences, but not a set.
        return "Select a model to classify image"

    # The pretrained wrappers yield 3-tuples; only the class name and the
    # score are used. Confidences are kept numeric so the label component
    # ranks them by value instead of comparing strings.
    predictions = pretrained.classify_image(image)
    return {class_name: float(prob) for _, class_name, prob in predictions}
def train_model(epochs, batch_size, validation_split):
    """Train a new custom CNN, save it, and make it the active custom model.

    Args:
        epochs: Number of training epochs; converted with ``int()``.
        batch_size: Training batch size; converted with ``int()``.
        validation_split: Validation fraction; converted with ``float()``.

    Raises:
        ValueError: If any argument cannot be converted to its numeric type.
    """
    # Without this declaration the final assignment would only bind a local
    # name and predictions would keep using the previously loaded model.
    global custom_model

    print("Training model")
    classifier = ImageClassifier()

    # Load the dataset
    (x_train, y_train), (x_test, y_test) = classifier.load_dataset()

    # Build and train the model
    classifier.build_model(x_train)
    classifier.train_model(
        x_train,
        y_train,
        batch_size=int(batch_size),
        epochs=int(epochs),
        validation_split=float(validation_split),
    )

    # Evaluate the model
    classifier.evaluate_model(x_test, y_test)

    # Save the trained model
    print("Saving model ...")
    classifier.save_model("image_classifier_model.h5")

    # Swap in the freshly trained classifier for subsequent predictions.
    custom_model = classifier
def update_train_param_display(model_type):
    """Return visibility updates for the two columns wired to this callback.

    The first update is visible only when ``model_type`` is "CNN (Custom)";
    the second update always has the opposite visibility.
    """
    is_custom = "CNN (Custom)" == model_type
    return [gr.update(visible=is_custom), gr.update(visible=not is_custom)]
if __name__ == "__main__":
    # NOTE(review): the nesting of the layout containers below was
    # reconstructed from a listing that had lost its indentation -- confirm
    # it against the intended UI layout.

    # Sample images live in the "assets" folder next to this file.
    base_dir = os.path.dirname(__file__)
    sample_images = [
        os.path.join(base_dir, rel_path)
        for rel_path in (
            "assets/dog_2.jpg",
            "assets/truck.jpg",
            "assets/car.jpg",
            "assets/car_32x32.jpg",
        )
    ]

    # gradio gui app
    with gr.Blocks() as my_app:
        gr.Markdown("<h1><center>Image Classification using TensorFlow</center></h1>")
        gr.Markdown("<h3><center>This model classifies image using different models.</center></h3>")

        with gr.Row():
            # Left column: image input, model selection and action buttons.
            with gr.Column(scale=1):
                img_input = gr.Image()
                model_type = gr.Dropdown(
                    [
                        "CNN (Custom)",
                        "ResNet50",
                        "VGG16",
                        "Inception v3",
                        "Mobile Net v2",
                    ],
                    label="Model Type",
                    value="CNN (Custom)",
                    info="Select the inference model before running predictions!",
                )

                # Training controls; toggled by update_train_param_display.
                with gr.Column() as train_col:
                    gr.Markdown("Train Parameters")
                    with gr.Row():
                        epochs_inp = gr.Textbox(label="Epochs", value="10")
                        validation_split = gr.Textbox(label="Validation Split", value="0.1")
                    with gr.Row():
                        batch_size = gr.Textbox(label="Batch Size", value="64")
                    with gr.Row():
                        train_btn = gr.Button(value="Train")
                        predict_btn_1 = gr.Button(value="Predict")

                # Predict-only column, initially hidden.
                with gr.Column(visible=False) as no_train_col:
                    predict_btn_2 = gr.Button(value="Predict")

            # Right column: prediction output.
            with gr.Column(scale=1):
                output_label = gr.Label()

        gr.Markdown("## Sample Images")
        gr.Examples(
            examples=sample_images,
            inputs=img_input,
            outputs=output_label,
            fn=make_prediction,
            cache_examples=True,
        )

        # app logic
        predict_inputs = [img_input, model_type]
        predict_btn_1.click(make_prediction, inputs=predict_inputs, outputs=[output_label])
        predict_btn_2.click(make_prediction, inputs=predict_inputs, outputs=[output_label])
        model_type.change(
            update_train_param_display,
            inputs=model_type,
            outputs=[train_col, no_train_col],
        )
        train_btn.click(
            train_model,
            inputs=[epochs_inp, batch_size, validation_split],
            outputs=[],
        )

    my_app.queue(concurrency_count=5, max_size=20).launch(debug=True)