Spaces:
Sleeping
Sleeping
# %%
import os
import subprocess

import cv2
import gradio as gr
import tensorflow as tf

model_folder = 'model'
repo_url = "https://huggingface.co/RandomCatLover/plants_disease"


def _clone_repo(url, destination):
    """Clone the git repository at *url* into *destination* unless it already exists.

    Failures are reported on stdout instead of raised, matching the original
    best-effort behaviour: a missing checkout surfaces later, when the model
    file is loaded or the explainer package is imported.

    Args:
        url: Repository URL to clone. A falsy value (e.g. ``None`` because the
            ``GIT_CORE`` environment variable is unset) is reported and skipped.
        destination: Directory to clone into.

    Returns:
        True if *destination* exists after the call, False otherwise.
    """
    if os.path.exists(destination):
        return True
    if not url:
        print(f'Error cloning repository: no URL given for {destination!r}')
        return False
    try:
        # Argument list with no shell: the URL (possibly taken from the
        # environment) is passed verbatim to git and never parsed by a shell.
        subprocess.check_output(['git', 'clone', url, destination],
                                stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as e:
        print(f'Error cloning repository: {e.output.decode(errors="replace")}')
        return False
    except OSError as e:  # e.g. the git executable is not installed
        print(f'Error cloning repository: {e}')
        return False
    print('Repository cloned successfully.')
    return True


_clone_repo(repo_url, model_folder)
# The explainer package is cloned from the URL in GIT_CORE. The destination
# is given explicitly so the checkout always matches the import path below.
_clone_repo(os.getenv("GIT_CORE"), 'explainer_tf_mobilenetv2')
from explainer_tf_mobilenetv2.explainer import explainer
# %%
# labels.txt holds one class name per line, and classify_image pairs
# labels[i] with output score i. Blank lines (e.g. a trailing newline) are
# dropped so the list does not gain a spurious '' entry. Stripping also
# removes stray '\r' from Windows line endings.
with open(f'{model_folder}/labels.txt', 'r', encoding='utf-8') as f:
    labels = [line.strip() for line in f if line.strip()]
model = tf.keras.models.load_model(f'{model_folder}/last_layer.hdf5')
# %%
def classify_image(inp):
    """Classify a plant-leaf image and return a confidence per label.

    Args:
        inp: Image array from the Gradio ``Image`` input, or ``None`` when no
            image has been provided. It must have exactly 3 channels, because
            the reshape below assumes 224x224x3 (presumably RGB — confirm
            against the input component's image_mode).

    Returns:
        dict mapping each label to the model's score for it, as a float, for
        use with ``gr.Label``. Empty when *inp* is ``None``.
    """
    if inp is None:
        # Nothing uploaded yet: show an empty result instead of erroring in cv2.
        return {}
    inp = cv2.resize(inp, (224, 224))
    # Add the leading batch dimension the model expects.
    inp = inp.reshape((-1, 224, 224, 3))
    inp = tf.keras.applications.mobilenet_v2.preprocess_input(inp)
    prediction = model.predict(inp).flatten()
    # zip stops at the shorter sequence, so a labels/scores length mismatch
    # cannot raise IndexError.
    return {label: float(score) for label, score in zip(labels, prediction)}
def explainer_wrapper(inp):
    """Run the imported ``explainer`` on *inp* using the module's loaded model.

    Adapts the two-argument ``explainer(image, model)`` call to the
    single-input handler signature used by the "Interpret" button.
    """
    explanation = explainer(inp, model)
    return explanation
# UI layout: image input and action buttons on the left, prediction and
# explanation plot on the right, example images underneath.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                # 224x224 matches the input size classify_image feeds the model.
                image = gr.Image(shape=(224, 224))
                with gr.Row():
                    classify = gr.Button("Classify")
                    interpret = gr.Button("Interpret")
            with gr.Column():
                label = gr.Label(num_top_classes=3)
                interpretation = gr.Plot(label="Interpretation")
        gr.Examples(["TomatoHealthy2.jpg", "TomatoYellowCurlVirus3.jpg", "AppleCedarRust3.jpg"],
                    inputs=[image],)
    # Event handlers are attached inside the Blocks context; both go through
    # the queue configured below.
    classify.click(classify_image, image, label, queue=True)
    interpret.click(explainer_wrapper, image, interpretation, queue=True)
demo.queue(concurrency_count=3).launch()
#%%
# Alternative single-component version of the app above, kept for reference:
# gr.Interface(fn=classify_image,
#              inputs=gr.Image(shape=(224, 224)),
#              outputs=gr.Label(num_top_classes=3),
#              examples=["TomatoHealthy2.jpg", "TomatoYellowCurlVirus3.jpg", "AppleCedarRust3.jpg"]).launch()