|
|
import gradio as gr |
|
|
from torch.nn.functional import softmax |
|
|
import torch |
|
|
from transformers import ViTFeatureExtractor |
|
|
from transformers import MobileViTFeatureExtractor |
|
|
from transformers import MobileViTForImageClassification |
|
|
from transformers import ViTForImageClassification |
|
|
|
|
|
|
|
|
def predict(model_type, inp):
    """Classify a garbage image with the selected fine-tuned model.

    Args:
        model_type: Which model to use — "ViT" or "MobileViT".
        inp: Input image (PIL image, as provided by the Gradio Image component).

    Returns:
        dict mapping each class label to its softmax probability, suitable
        for a Gradio Label output.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    if model_type == "ViT":
        model_name_or_path = './models/vit-base-garbage/'
        feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)
        model = ViTForImageClassification.from_pretrained(model_name_or_path)
    elif model_type == "MobileViT":
        model_name_or_path = './models/apple/mobilevit-small-garbage/'
        feature_extractor = MobileViTFeatureExtractor.from_pretrained(model_name_or_path)
        model = MobileViTForImageClassification.from_pretrained(model_name_or_path)
    else:
        # Previously an unknown choice crashed with UnboundLocalError;
        # fail explicitly with a clear message instead.
        raise ValueError(f"Unsupported model_type: {model_type!r}")

    inputs = feature_extractor(inp, return_tensors="pt")
    labels = list(model.config.label2id.keys())

    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # outputs[0] is the logits tensor of shape (1, num_labels) — presumably;
    # confirm against the transformers ImageClassifierOutput contract.
    probabilities = softmax(outputs[0], dim=-1)

    # Derive the class count from the model config instead of hard-coding 6,
    # so the function works for any number of labels.
    return {label: float(probabilities[0][i]) for i, label in enumerate(labels)}
|
|
|
|
|
|
|
|
# Build the Gradio UI: a model selector plus an image input, producing the
# top-3 predicted classes. Uses the modern component API (gr.Image / gr.Label);
# the legacy gr.inputs / gr.outputs namespaces were removed in current Gradio
# releases and crash at import time.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(["ViT", "MobileViT"], label="Model Name", value='ViT'),
        gr.Image(type="pil"),
    ],
    outputs=gr.Label(num_top_classes=3),
    examples=[
        ["ViT", "paper567.jpg"],
        ["ViT", "trash105.jpg"],
        ["ViT", "plastic202.jpg"],
        ["MobileViT", "metal382.jpg"],
    ],
)

demo.launch()