"""Gradio demo: multi-label image tagging with a fine-tuned ViT classifier.

At import time this script loads the label vocabulary from ``word2idx.json``,
loads the ``Inf009/food1024_vit_focal_mixup`` checkpoint, builds a Gradio
interface around :func:`multi_label_predict`, and launches it.
"""

import json
import os

import gradio as gr
import numpy as np
import torch
from torchvision import transforms
from transformers import ViTForImageClassification

# Threshold used for the function default, the UI field, and the examples.
DEFAULT_THRESHOLD = 0.5

# Load the JSON mapping and invert it so that the original values become the
# lookup keys (presumably word -> index on disk, giving index -> word here;
# verify against the file). The context manager closes the handle promptly.
with open("word2idx.json", "r", encoding="utf-8") as vocab_file:
    vocabulary = json.load(vocab_file)
vocabulary = {v: k for (k, v) in vocabulary.items()}

# One output logit per vocabulary entry.
model = ViTForImageClassification.from_pretrained(
    "Inf009/food1024_vit_focal_mixup",
    problem_type="multi_label_classification",
    num_labels=len(vocabulary),
)

# Resize to 256x256, centre-crop to 224x224, convert to a float tensor.
# NOTE(review): there is no Normalize step here -- confirm this matches the
# preprocessing used when the checkpoint was trained.
test_transforms = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
)


def multi_label_predict(img, threshold=DEFAULT_THRESHOLD):
    """Return the tags whose sigmoid score exceeds *threshold*.

    Args:
        img: A PIL image (the Gradio input below is configured with
            ``type="pil"``). ``None`` (no image supplied) yields an empty
            result instead of raising.
        threshold: Minimum score, exclusive, for a tag to be reported.
            ``None`` (e.g. a cleared number field) falls back to
            ``DEFAULT_THRESHOLD``.

    Returns:
        A list of vocabulary entries, ordered from highest to lowest score.
    """
    if img is None:
        return []
    if threshold is None:
        threshold = DEFAULT_THRESHOLD

    # Grayscale / RGBA inputs would produce 1- or 4-channel tensors; force
    # three channels before the tensor conversion.
    if img.mode != "RGB":
        img = img.convert("RGB")

    pixel_values = test_transforms(img).unsqueeze(0)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(pixel_values).logits
    scores = logits.squeeze(0).sigmoid().numpy()

    indices = np.where(scores > threshold)[0]
    indices = sorted(indices, key=lambda i: scores[i], reverse=True)
    return [vocabulary[int(idx)] for idx in indices]


# Use up to three .jpg files from the demo directory as clickable examples.
# Sorting makes the selection deterministic (os.listdir order is arbitrary);
# a missing directory simply means no examples rather than a startup crash.
demo_image_path = "images"
if os.path.isdir(demo_image_path):
    images = sorted(
        f for f in os.listdir(demo_image_path) if f.endswith(".jpg")
    )[:3]
else:
    images = []
images = [os.path.join(demo_image_path, file) for file in images]
examples = [[image, DEFAULT_THRESHOLD] for image in images]

# NOTE(review): `gr.inputs.*` is Gradio's legacy component namespace; newer
# Gradio releases expose `gr.Image` / `gr.Number(value=...)` instead. Left
# as-is to stay compatible with the Gradio version this script was written
# for -- migrate when upgrading.
iface = gr.Interface(
    fn=multi_label_predict,
    inputs=[
        gr.inputs.Image(type="pil"),
        gr.inputs.Number(default=DEFAULT_THRESHOLD),
    ],
    examples=examples or None,  # None (the default) when nothing was found
    outputs="text",
)
iface.launch()