food_tagger_vit / app.py
baixintech_zhangyiming_prod
fix file open bug
4e2b453
import json
import os

import gradio as gr
import numpy as np
import torch
from torchvision import transforms
from transformers import ViTForImageClassification
# Load the tag -> class-index mapping and invert it to class-index -> tag so
# predicted label indices can be turned back into tag strings.
# Use a context manager so the file handle is closed deterministically, and
# read as UTF-8 (the JSON standard encoding) rather than the locale default.
with open("word2idx.json", "r", encoding="utf-8") as vocab_file:
    vocabulary = json.load(vocab_file)
vocabulary = {idx: word for word, idx in vocabulary.items()}
# Fine-tuned ViT classifier loaded from the Hugging Face Hub, configured for
# multi-label output with one logit per vocabulary entry.
model = ViTForImageClassification.from_pretrained(
    "Inf009/food1024_vit_focal_mixup",
    problem_type="multi_label_classification",
    num_labels=len(vocabulary),
)
# Evaluation-time preprocessing for a PIL image:
#   1. resize to exactly 256x256 (aspect ratio is not preserved),
#   2. take the central 224x224 crop,
#   3. convert to a channels-first float tensor scaled to [0, 1].
# NOTE(review): no Normalize step is applied; presumably this matches the
# preprocessing used during fine-tuning — confirm before changing.
_eval_steps = [
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
]
test_transforms = transforms.Compose(_eval_steps)
def multi_label_predict(img, threshold=0.5):
    """Predict food tags for a single image.

    Args:
        img: Input PIL image (the Gradio input uses ``type="pil"``), or
            ``None`` when no image was supplied.
        threshold: Minimum sigmoid probability a label needs in order to be
            returned.

    Returns:
        List of tag strings whose probability exceeds ``threshold``, ordered
        from most to least confident. Empty list when ``img`` is ``None``.
    """
    if img is None:
        # Guard against an empty image input instead of crashing in the
        # transform pipeline.
        return []
    # ToTensor emits one channel per image band; the ViT expects 3-channel
    # input, so normalise greyscale/palette/RGBA images to RGB first.
    if img.mode != "RGB":
        img = img.convert("RGB")
    batch = test_transforms(img).unsqueeze(0)  # add the batch dimension
    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        logits = model(batch).logits.squeeze(0)
    # Multi-label head: each class gets an independent sigmoid probability.
    probs = logits.sigmoid().numpy()
    hits = np.where(probs > threshold)[0]
    ranked = sorted(hits, key=lambda i: probs[i], reverse=True)
    # Cast numpy integer indices to plain int for the vocabulary lookup.
    return [vocabulary[int(i)] for i in ranked]
# Build up to three example rows (image path, default threshold) for the UI
# from the bundled demo images. os.listdir returns entries in arbitrary
# order, so sort first to make the chosen examples deterministic.
demo_image_path = "images"
images = sorted(f for f in os.listdir(demo_image_path) if f.endswith(".jpg"))[:3]
images = [os.path.join(demo_image_path, file) for file in images]
examples = [[image, 0.5] for image in images]
# NOTE(review): gr.inputs.* is the legacy Gradio input API (deprecated in
# Gradio 3.x, removed in 4.x); switch to gr.Image / gr.Number(value=...) if
# the pinned Gradio version is upgraded.
iface = gr.Interface(
    fn=multi_label_predict,
    inputs=[gr.inputs.Image(type="pil"), gr.inputs.Number(default=0.5)],
    examples=examples,
    outputs="text",
)
iface.launch()