# Spaces: Sleeping
import json
import os

import gradio as gr
import numpy as np
import torch
from torchvision import transforms
from transformers import ViTForImageClassification
# Load the label mapping and invert it so lookups go from the file's values to
# its keys (presumably class index -> tag word, given the "word2idx" file name).
# A context manager is used so the file handle is closed deterministically.
with open("word2idx.json", "r", encoding="utf-8") as vocab_file:
    vocabulary = json.load(vocab_file)
vocabulary = {v: k for (k, v) in vocabulary.items()}

# Multi-label ViT classifier configured with one output per vocabulary entry.
model = ViTForImageClassification.from_pretrained(
    "Inf009/food1024_vit_focal_mixup",
    problem_type="multi_label_classification",
    num_labels=len(vocabulary),
)
# Inference-time preprocessing applied to every input image: resize to
# 256x256, keep the central 224x224 crop, then convert to a tensor.
test_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
def multi_label_predict(img, threshold=0.5):
    """Return the tags whose predicted probability exceeds ``threshold``.

    Args:
        img: Image to classify (the UI passes a PIL image). ``None`` -- e.g.
            the form was submitted without an image -- yields an empty list.
        threshold: Cut-off applied to the per-label sigmoid outputs; a label
            is kept when its score is strictly greater. ``None`` (a cleared
            number field) falls back to 0.5.

    Returns:
        A list of tag strings ordered from highest to lowest score.
    """
    if img is None:
        return []
    if threshold is None:
        threshold = 0.5

    # Grayscale / RGBA / palette images would otherwise reach the model with
    # the wrong channel count. Objects without ``convert`` are passed through
    # untouched, exactly as before.
    if hasattr(img, "convert"):
        img = img.convert("RGB")

    img_transformed = test_transforms(img)
    # Pure inference: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = model(img_transformed.unsqueeze(0)).logits
    outputs = logits.squeeze(0).sigmoid().numpy()

    indices = np.where(outputs > threshold)[0]
    indices = sorted(indices, key=lambda x: outputs[x], reverse=True)
    return [vocabulary[idx] for idx in indices]
demo_image_path = "images"
# os.listdir returns entries in arbitrary order; sort first so the same three
# example images are selected on every start-up.
images = sorted(f for f in os.listdir(demo_image_path) if f.endswith(".jpg"))[:3]
images = [os.path.join(demo_image_path, file) for file in images]
# Each example row supplies both inputs: the image path and the threshold.
examples = [[image, 0.5] for image in images]

# NOTE(review): gr.inputs.* is the legacy Gradio input namespace; newer Gradio
# releases expose gr.Image / gr.Number(value=...) instead. Left unchanged here
# because the pinned Gradio version is not visible -- confirm before upgrading.
iface = gr.Interface(
    fn=multi_label_predict,
    inputs=[gr.inputs.Image(type="pil"), gr.inputs.Number(default=0.5)],
    examples=examples,
    outputs="text",
)
iface.launch()