File size: 1,587 Bytes
05d6a45 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import functools

import gradio as gr
import torch
from torch.nn.functional import softmax
from transformers import MobileViTFeatureExtractor
from transformers import MobileViTForImageClassification
from transformers import ViTFeatureExtractor
from transformers import ViTForImageClassification
@functools.lru_cache(maxsize=None)
def _load_model(model_type):
    """Load the feature extractor and classifier for *model_type*, once.

    Loading checkpoints from disk is expensive, so the result is cached and
    reused across requests. Only two model types are accepted (unknown ones
    raise, and exceptions are not cached), so the cache stays bounded.

    Raises:
        ValueError: if *model_type* is not "ViT" or "MobileViT".
    """
    if model_type == "ViT":
        model_name_or_path = './models/vit-base-garbage/'
        feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)
        model = ViTForImageClassification.from_pretrained(model_name_or_path)
    elif model_type == "MobileViT":
        model_name_or_path = './models/apple/mobilevit-small-garbage/'
        feature_extractor = MobileViTFeatureExtractor.from_pretrained(model_name_or_path)
        model = MobileViTForImageClassification.from_pretrained(model_name_or_path)
    else:
        # Previously this fell through and crashed with UnboundLocalError.
        raise ValueError(f"Unknown model type: {model_type!r}")
    return feature_extractor, model


def _to_confidences(logits, id2label):
    """Convert the first row of *logits* into a {label: probability} dict.

    Labels are looked up by class index via *id2label*. Relying on the
    iteration order of label2id does not guarantee index alignment. Every
    class is included, whatever the model's class count.
    """
    probabilities = softmax(logits, dim=-1)[0]
    return {id2label[i]: float(p) for i, p in enumerate(probabilities)}


def predict(model_type, inp):
    """Classify *inp* with the selected model and return label confidences.

    Args:
        model_type: "ViT" or "MobileViT" (the value of the model dropdown).
        inp: the input image; the app's UI configures it as a PIL image.

    Returns:
        dict mapping each class label to its softmax probability, the
        format accepted by a Gradio Label output.

    Raises:
        ValueError: if *model_type* is not a known model.
    """
    feature_extractor, model = _load_model(model_type)
    inputs = feature_extractor(inp, return_tensors="pt")
    with torch.no_grad():  # inference only; no autograd bookkeeping needed
        logits = model(**inputs).logits
    return _to_confidences(logits, model.config.id2label)
# Gradio UI: pick a model, supply an image, and show the top-3 predicted
# classes. Example image paths are relative (assumed to be present in the
# working directory when the app runs).
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(["ViT", "MobileViT"], label="Model Name", value='ViT'),
        # gr.Image replaces the deprecated gr.inputs.Image (removed in Gradio 4).
        gr.Image(type="pil"),
    ],
    # gr.Label replaces the deprecated gr.outputs.Label (removed in Gradio 4).
    outputs=gr.Label(num_top_classes=3),
    examples=[
        ["ViT", "paper567.jpg"],
        ["ViT", "trash105.jpg"],
        ["ViT", "plastic202.jpg"],
        ["MobileViT", "metal382.jpg"],
    ],
)

# Launch only when run as a script, so importing this module has no server
# side effect.
if __name__ == "__main__":
    demo.launch()