Spaces:
Runtime error
Runtime error
| import io | |
| import gradio as gr | |
| import requests, validators | |
| import torch | |
| import pathlib | |
| from PIL import Image | |
| import datasets | |
| from transformers import AutoFeatureExtractor, AutoModelForImageClassification | |
| import os | |
| import IPython | |
# Let a second copy of the OpenMP runtime load instead of aborting the process.
# NOTE(review): presumably a workaround for the "libiomp already initialized"
# crash seen with some torch/MKL installs; confirm it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Fine-tuned checkpoint and its preprocessing config, both read once at import
# time from the local "saved_model_files" directory (extractor first, then model).
feature_extractor, model = (
    AutoFeatureExtractor.from_pretrained("saved_model_files"),
    AutoModelForImageClassification.from_pretrained("saved_model_files"),
)

# Human-readable class names. classify() pairs labels[i] with output logit i,
# so this order is assumed to match the checkpoint's id2label -- confirm.
labels = [
    'angular_leaf_spot',
    'bean_rust',
    'healthy',
]
def classify(im, extractor=None, clf_model=None, class_labels=None):
    '''Classify the plant-health status shown in a leaf image.

    Args:
        im: Image to classify (the app passes PIL images). If it has a
            ``mode`` other than 'RGB' (e.g. RGBA PNG, grayscale) it is
            converted to RGB before preprocessing.
        extractor: Preprocessor called as ``extractor(im, return_tensors='pt')``;
            defaults to the module-level ``feature_extractor``.
        clf_model: Model called with the extractor output as keyword
            arguments, returning an object with a ``.logits`` tensor;
            defaults to the module-level ``model``.
        class_labels: Label names where ``class_labels[i]`` names logit ``i``;
            defaults to the module-level ``labels``.

    Returns:
        dict mapping each label to its softmax probability (Python float)
        for the first image of the batch.

    Raises:
        ValueError: If ``im`` is None (no image was provided).
    '''
    if im is None:
        raise ValueError('No image to classify.')
    # Resolve defaults at call time so the module globals are looked up lazily.
    extractor = feature_extractor if extractor is None else extractor
    clf_model = model if clf_model is None else clf_model
    class_labels = labels if class_labels is None else class_labels

    # RGBA/grayscale images would not match the 3-channel (RGB) input the ViT
    # preprocessing presumably normalises with per-channel mean/std.
    if getattr(im, 'mode', 'RGB') != 'RGB':
        im = im.convert('RGB')

    features = extractor(im, return_tensors='pt')
    with torch.no_grad():
        logits = clf_model(**features).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)[0].numpy()
    return {label: float(probs[i]) for i, label in enumerate(class_labels)}
def get_original_image(url_input):
    '''Download the image at ``url_input`` for preview and classification.

    Args:
        url_input: Text from the URL textbox (may be empty or None).

    Returns:
        A PIL image opened from the downloaded bytes, or None when
        ``url_input`` is empty or not a valid URL (Gradio then clears the
        preview component).

    Raises:
        requests.RequestException: On connection errors, timeouts, or a
            non-2xx HTTP status.
        PIL.UnidentifiedImageError: If the response body is not an image
            (e.g. the URL points to an HTML page).
    '''
    if not (url_input and validators.url(url_input)):
        return None
    # Bounded wait so a slow or unresponsive host cannot hang the worker.
    response = requests.get(url_input, timeout=10)
    response.raise_for_status()
    # Unlike .raw, .content is fully read and content-decoded (gzip/deflate),
    # and reading it lets requests release the connection back to the pool.
    return Image.open(io.BytesIO(response.content))
def detect_plant_health(url_input, image_input, webcam_input):
    '''Classify the first image available from the URL, upload, or webcam input.

    Sources are checked in priority order: a valid URL in ``url_input``
    wins, then ``image_input``, then ``webcam_input``.

    Args:
        url_input: Text from the URL textbox (may be empty or None).
        image_input: Image from the upload widget, or None.
        webcam_input: Image from the webcam widget, or None.

    Returns:
        The label -> probability dict produced by ``classify``.

    Raises:
        ValueError: If none of the inputs supplies an image.
    '''
    if url_input and validators.url(url_input):
        # Same download path as the URL-tab preview.
        image = get_original_image(url_input)
    elif image_input is not None:
        image = image_input
    elif webcam_input is not None:
        image = webcam_input
    else:
        # Without this branch `image` would be unbound below.
        raise ValueError(
            'No image provided: enter a valid image URL, upload an image, '
            'or capture one with the webcam.'
        )
    # Make prediction
    return classify(image)
def set_example_image(example: list) -> dict:
    '''Build a gr.Image update that displays the first item of an example row.

    NOTE(review): this helper appears not to be wired to any event in this file.
    '''
    selected = example[0]
    update = gr.Image.update(value=selected)
    return update
def set_example_url(example: list) -> tuple:
    '''Build UI updates for a clicked URL example.

    NOTE(review): this helper appears not to be wired to any event in this file.

    Args:
        example: Example row; ``example[0]`` is the image URL.

    Returns:
        A 2-tuple ``(textbox_update, image_update)``: the textbox is set to
        the URL and the image preview to the image downloaded by
        ``get_original_image`` (None when the URL is not valid).
    '''
    url = example[0]
    return gr.Textbox.update(value=url), gr.Image.update(value=get_original_image(url))
# Page heading, rendered through gr.Markdown; the id lets the CSS below center it.
title = """<h1 id="title">Plant Health Classification with ViT</h1>"""

# Markdown shown under the banner image.
description = """
This Plant Health classifier app was built to detect the health of plants using images of leaves by fine-tuning a Vision Transformer (ViT) [google/vit-base-patch16-224](https://huggingface.co/google/vit-base-patch16-224) on the [Beans](https://huggingface.co/datasets/beans) dataset.
The finetuned model has an accuracy of 98.4% on the test (unseen) dataset and 100% on the validation dataset.
How to use the app:
- Upload an image via 3 options, uploading the image from local device, using a URL (image from the web) or a webcam
- The app will take a few seconds to generate a prediction with the following labels:
- *angular_leaf_spot*
- *bean_rust*
- *healthy*
- Feel free to click the image examples as well.
"""

# Example inputs for the "Image URL" tab.
# NOTE(review): the first entry looks like an HTML article page rather than a
# direct image link, so opening it as an image will likely fail -- confirm.
urls = [
    "https://www.healthbenefitstimes.com/green-beans/",
    "https://huggingface.co/nateraw/vit-base-beans/resolve/main/angular_leaf_spot.jpeg",
    "https://huggingface.co/nateraw/vit-base-beans/resolve/main/bean_rust.jpeg",
]

# Example rows for the "Image Upload" tab: one single-element row per file
# found recursively under ./images, in sorted path order. The '*.p*g' pattern
# matches e.g. .png but NOT .jpg/.jpeg.
images = [
    [image_path.as_posix()]
    for image_path in sorted(pathlib.Path('images').rglob('*.p*g'))
]

# Markdown link to the author's Twitter profile (the link text is empty).
twitter_link = """
[](https://twitter.com/nickmuchi)
"""

# Custom CSS passed to gr.Blocks: center the page title.
css = """
h1#title {
text-align: center;
}
"""
def _classify_from_url(url):
    '''Classify only the image behind ``url`` (the "Image URL" tab).'''
    return detect_plant_health(url, None, None)


def _classify_from_upload(image):
    '''Classify only the uploaded image (the "Image Upload" tab).'''
    # Empty URL so text left in the URL tab cannot take priority.
    return detect_plant_health('', image, None)


def _classify_from_webcam(frame):
    '''Classify only the webcam capture (the "WebCam" tab).'''
    return detect_plant_health('', None, frame)


demo = gr.Blocks(css=css)

with demo:
    gr.Markdown(title)
    gr.HTML('<center><img src="file/images/Healthy.png" width=350px height=350px></center>')
    gr.Markdown(description)
    gr.Markdown(twitter_link)

    with gr.Tabs():
        with gr.TabItem('Image Upload'):
            with gr.Row():
                with gr.Column():
                    img_input = gr.Image(type='pil', shape=(450, 450))
                    label_from_upload = gr.Label(num_top_classes=3)
            with gr.Row():
                example_images = gr.Examples(examples=images, inputs=[img_input])
            img_but = gr.Button('Classify')

        with gr.TabItem('Image URL'):
            with gr.Row():
                with gr.Column():
                    url_input = gr.Textbox(lines=2, label='Enter valid image URL here..')
                    original_image = gr.Image(shape=(450, 450))
                    # Refresh the preview whenever the URL text changes.
                    url_input.change(get_original_image, url_input, original_image)
                with gr.Column():
                    label_from_url = gr.Label(num_top_classes=3)
            with gr.Row():
                example_url = gr.Examples(examples=urls, inputs=[url_input])
            url_but = gr.Button('Classify')

        with gr.TabItem('WebCam'):
            with gr.Row():
                with gr.Column():
                    web_input = gr.Image(source='webcam', type='pil', shape=(450, 450), streaming=True)
                with gr.Column():
                    label_from_webcam = gr.Label(num_top_classes=3)
            cam_but = gr.Button('Classify')

    # Each Classify button sends only its own tab's input. Sending all three
    # inputs to every button let detect_plant_health's priority order pick a
    # stale URL or upload from a different tab.
    url_but.click(_classify_from_url, inputs=[url_input], outputs=[label_from_url], queue=True)
    img_but.click(_classify_from_upload, inputs=[img_input], outputs=[label_from_upload], queue=True)
    cam_but.click(_classify_from_webcam, inputs=[web_input], outputs=[label_from_webcam], queue=True)

    # NOTE(review): renders nothing; its original content may have been lost.
    gr.Markdown("")

demo.launch(debug=True, enable_queue=True)