Spaces:
Runtime error
Runtime error
Create new file
Browse files
app.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import requests, validators
|
| 5 |
+
import torch
|
| 6 |
+
import pathlib
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import datasets
|
| 9 |
+
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
|
| 14 |
+
|
| 15 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained("/content/drive/MyDrive/Week 1 Project/saved_model_files")
|
| 16 |
+
model = AutoModelForImageClassification.from_pretrained("/content/drive/MyDrive/Week 1 Project/saved_model_files")
|
| 17 |
+
|
| 18 |
+
labels = ['angular_leaf_spot', 'bean_rust', 'healthy']
|
| 19 |
+
|
| 20 |
+
def classify(im):
|
| 21 |
+
'''FUnction for classifying plant health status'''
|
| 22 |
+
|
| 23 |
+
features = feature_extractor(im, return_tensors='pt')
|
| 24 |
+
with torch.no_grad():
|
| 25 |
+
logits = model(**features).logits
|
| 26 |
+
probability = torch.nn.functional.softmax(logits, dim=-1)
|
| 27 |
+
probs = probability[0].detach().numpy()
|
| 28 |
+
confidences = {label: float(probs[i]) for i, label in enumerate(labels)}
|
| 29 |
+
|
| 30 |
+
return confidences
|
| 31 |
+
|
| 32 |
+
def get_original_image(url_input):
|
| 33 |
+
'''Get image from URL'''
|
| 34 |
+
if validators.url(url_input):
|
| 35 |
+
|
| 36 |
+
image = Image.open(requests.get(url_input, stream=True).raw)
|
| 37 |
+
|
| 38 |
+
return image
|
| 39 |
+
|
| 40 |
+
def detect_plant_health(url_input,image_input,webcam_input):
|
| 41 |
+
|
| 42 |
+
if validators.url(url_input):
|
| 43 |
+
image = Image.open(requests.get(url_input, stream=True).raw)
|
| 44 |
+
|
| 45 |
+
elif image_input:
|
| 46 |
+
image = image_input
|
| 47 |
+
|
| 48 |
+
elif webcam_input:
|
| 49 |
+
image = webcam_input
|
| 50 |
+
|
| 51 |
+
#Make prediction
|
| 52 |
+
label_probs = classify(image)
|
| 53 |
+
|
| 54 |
+
return label_probs
|
| 55 |
+
|
| 56 |
+
def set_example_image(example: list) -> dict:
|
| 57 |
+
return gr.Image.update(value=example[0])
|
| 58 |
+
|
| 59 |
+
def set_example_url(example: list) -> dict:
|
| 60 |
+
return gr.Textbox.update(value=example[0]), gr.Image.update(value=get_original_image(example[0]))
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
title = """<h1 id="title">Plant Health Classification with ViT</h1>"""
|
| 64 |
+
|
| 65 |
+
description = """
|
| 66 |
+
This Plant Health classifier app was built to detect the health of plants using images of leaves by fine-tuning a Vision Transformer (ViT) [google/vit-base-patch16-224](https://huggingface.co/google/vit-base-patch16-224) on the [Beans](https://huggingface.co/datasets/beans) dataset.
|
| 67 |
+
The finetuned model has an accuracy of 98.4% on the test (unseen) dataset and 100% on the validation dataset.
|
| 68 |
+
|
| 69 |
+
How to use the app:
|
| 70 |
+
- Upload an image via 3 options, uploading the image from local device, using a URL (image from the web) or a webcam
|
| 71 |
+
- The app will take a few seconds to generate a prediction with the following labels:
|
| 72 |
+
- *'angular_leaf_spot'*
|
| 73 |
+
- *'bean_rust'*
|
| 74 |
+
- *'healthy'*
|
| 75 |
+
- Feel free to click the image examples as well.
|
| 76 |
+
"""
|
| 77 |
+
urls = ["https://www.healthbenefitstimes.com/green-beans/","https://huggingface.co/nateraw/vit-base-beans/resolve/main/angular_leaf_spot.jpeg", "https://huggingface.co/nateraw/vit-base-beans/resolve/main/bean_rust.jpeg"]
|
| 78 |
+
images = [[path.as_posix()] for path in sorted(pathlib.Path('images').rglob('*.j*g'))]
|
| 79 |
+
|
| 80 |
+
twitter_link = """
|
| 81 |
+
[](https://twitter.com/nickmuchi)
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
css = '''
|
| 85 |
+
h1#title {
|
| 86 |
+
text-align: center;
|
| 87 |
+
}
|
| 88 |
+
'''
|
| 89 |
+
demo = gr.Blocks(css=css)
|
| 90 |
+
|
| 91 |
+
with demo:
|
| 92 |
+
gr.Markdown(title)
|
| 93 |
+
gr.Markdown(description)
|
| 94 |
+
gr.Markdown(twitter_link)
|
| 95 |
+
|
| 96 |
+
with gr.Tabs():
|
| 97 |
+
with gr.TabItem('Image Upload'):
|
| 98 |
+
with gr.Row():
|
| 99 |
+
with gr.Column():
|
| 100 |
+
img_input = gr.Image(type='pil',shape=(750,750))
|
| 101 |
+
label_from_upload= gr.Label()
|
| 102 |
+
|
| 103 |
+
with gr.Row():
|
| 104 |
+
example_images = gr.Examples(examples=images,inputs=[img_input])
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
img_but = gr.Button('Classify')
|
| 108 |
+
|
| 109 |
+
with gr.TabItem('Image URL'):
|
| 110 |
+
with gr.Row():
|
| 111 |
+
with gr.Column():
|
| 112 |
+
url_input = gr.Textbox(lines=2,label='Enter valid image URL here..')
|
| 113 |
+
original_image = gr.Image(shape=(750,750))
|
| 114 |
+
url_input.change(get_original_image, url_input, original_image)
|
| 115 |
+
with gr.Column():
|
| 116 |
+
label_from_url = gr.Label()
|
| 117 |
+
|
| 118 |
+
with gr.Row():
|
| 119 |
+
example_url = gr.Examples(examples=urls,inputs=[url_input])
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
url_but = gr.Button('Classify')
|
| 123 |
+
|
| 124 |
+
with gr.TabItem('WebCam'):
|
| 125 |
+
with gr.Row():
|
| 126 |
+
with gr.Column():
|
| 127 |
+
web_input = gr.Image(source='webcam',type='pil',shape=(750,750),streaming=True)
|
| 128 |
+
with gr.Column():
|
| 129 |
+
label_from_webcam= gr.Label()
|
| 130 |
+
|
| 131 |
+
cam_but = gr.Button('Classify')
|
| 132 |
+
|
| 133 |
+
url_but.click(detect_plant_health,inputs=[url_input,img_input,web_input],outputs=[label_from_url],queue=True)
|
| 134 |
+
img_but.click(detect_plant_health,inputs=[url_input,img_input,web_input],outputs=[label_from_upload],queue=True)
|
| 135 |
+
cam_but.click(detect_plant_health,inputs=[url_input,img_input,web_input],outputs=[label_from_webcam],queue=True)
|
| 136 |
+
|
| 137 |
+
gr.Markdown("")
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
demo.launch(debug=True,enable_queue=True)
|