File size: 1,657 Bytes
6eb8d8f
5c4be61
 
486fc30
 
5c4be61
 
486fc30
 
5c4be61
 
 
486fc30
5c4be61
 
 
 
 
 
 
 
 
 
 
 
 
486fc30
 
 
5c4be61
 
486fc30
5c4be61
 
 
 
 
486fc30
 
 
 
5c4be61
 
486fc30
57b9e42
 
76a9f88
5c4be61
486fc30
 
5c4be61
05cb4f3
486fc30
57b9e42
 
9231b08
486fc30
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import gradio as gr
from PIL import Image, ImageOps
from transformers import CLIPProcessor, CLIPModel
import torch


# Repository id handed to `from_pretrained` for both the processor and the model.
MODEL_ID = "EduFalcao/CropVision-CLIP"

# Loaded once at import time; `predict` reuses these module-level objects on
# every request instead of reloading the checkpoint.
processor, model = (
    CLIPProcessor.from_pretrained(MODEL_ID),
    CLIPModel.from_pretrained(MODEL_ID),
)


# Text prompt sent to CLIP -> short class name displayed in the UI.
# Insertion order is significant: HF_LABELS is derived from it, and that order
# fixes which logit column belongs to which prompt.
MAP = {
    "Grape leaf with Black rot": "Black Rot",
    "Grape leaf with Esca (Black Measles)": "ESCA",
    "Grape leaf with Leaf blight (Isariopsis Leaf Spot)": "Leaf Blight",
    "Healthy Grape leaf": "Healthy",
}

# Prompts in the exact order they are passed to the processor (dict keys keep
# insertion order), so each prompt string is written only once.
HF_LABELS = list(MAP)

def predict(image: Image.Image):
    """Classify a grape-leaf photo by scoring it against the CLIP text prompts.

    The image is compared with every prompt in ``HF_LABELS``; the softmax over
    those prompts is reported under the short class names from ``MAP``.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the form is submitted without an upload.

    Returns:
        A ``(best, prob_lines)`` tuple: the short name of the most probable
        class, and one ``"<class>: <prob>"`` line per class, in the fixed
        order Healthy, Leaf Blight, Black Rot, ESCA.

    Raises:
        gr.Error: if no image was provided (shown to the user by Gradio).
    """
    # Without this guard, a submit with no upload fails inside exif_transpose
    # with an opaque AttributeError on None.
    if image is None:
        raise gr.Error("Nenhuma imagem carregada.")

    # Apply the EXIF orientation tag, then force 3-channel RGB.
    img = ImageOps.exif_transpose(image).convert("RGB")
    img = img.resize((224, 224))

    inputs = processor(text=HF_LABELS, images=img, return_tensors="pt", padding=True)
    # Pure inference: disable autograd so no gradient graph is built per request.
    with torch.no_grad():
        outputs = model(**inputs)
    # A single image is passed, so row 0 holds its scores against each prompt;
    # softmax over dim=1 normalises across the prompts.
    probs = outputs.logits_per_image.softmax(dim=1)[0].tolist()

    # probs[i] belongs to HF_LABELS[i]; re-key by the short display names.
    mapping = {MAP[label]: prob for label, prob in zip(HF_LABELS, probs)}
    best = max(mapping, key=mapping.get)
    prob_lines = "\n".join(
        f"{cls}: {mapping[cls]:.2f}"
        for cls in ["Healthy", "Leaf Blight", "Black Rot", "ESCA"]
    )
    return best, prob_lines


# Gradio UI: one PIL image in, two text boxes out (predicted class and the
# per-class probability listing produced by `predict`).
_leaf_input = gr.Image(type="pil", label="Carrega uma folha")
_result_outputs = [
    gr.Textbox(label="Classe prevista"),
    gr.Textbox(label="Probabilidades entre Classes"),
]

demo = gr.Interface(
    fn=predict,
    inputs=_leaf_input,
    outputs=_result_outputs,
    title="Modelo CropVision",
    description="Neste modelo vamos classificar folhas de vinhas em Healthy, Leaf Blight, Black Rot ou ESCA",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()