"""Gradio demo: recognise a grocery item's food category, then its fruit species.

A first image classifier predicts the broad food category of the photo. Only
when that category is "fruit" is a second, fine-grained classifier run to
predict the fruit species. Both top-1 predictions are shown in a
``gr.Label`` output.
"""

from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from PIL import Image
import torch
import gradio as gr

# Broad category classifier (e.g. Food / Fruit / Vegetable / Rice).
CATEGORY_MODEL_ID = "Kaludi/food-category-classification-v2.0"
# Fine-grained fruit species classifier, consulted only for "fruit" images.
FRUIT_MODEL_ID = "walzsil1/vit-base-fruits-360"

cat_ex = AutoFeatureExtractor.from_pretrained(CATEGORY_MODEL_ID)
cat_model = AutoModelForImageClassification.from_pretrained(CATEGORY_MODEL_ID)

fruit_ex = AutoFeatureExtractor.from_pretrained(FRUIT_MODEL_ID)
fruit_model = AutoModelForImageClassification.from_pretrained(FRUIT_MODEL_ID)


@torch.no_grad()  # inference only: don't build an autograd graph per request
def _top1(extractor, model, img):
    """Return ``(label, probability)`` for the most likely class of *img*.

    Args:
        extractor: Preprocessor paired with *model*; called with
            ``images=img, return_tensors="pt"``.
        model: Image-classification model whose ``config.id2label`` maps
            class indices to label strings.
        img: A single PIL image.

    Returns:
        The top-1 label string and its softmax probability as a Python float.
    """
    inputs = extractor(images=img, return_tensors="pt")
    # Batch of one image -> take row 0 of the (1, num_labels) probabilities.
    probs = torch.softmax(model(**inputs).logits, dim=-1)[0]
    idx = int(probs.argmax())
    return model.config.id2label[idx], float(probs[idx])


def classify(img):
    """Classify *img* by food category and, for fruit, by species.

    Args:
        img: PIL image supplied by the Gradio ``Image(type="pil")`` input.

    Returns:
        A dict for ``gr.Label``. It always contains
        ``"category: <label>" -> probability``. When the category is "fruit",
        it also contains ``"fruit: <species>" -> probability``.

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable message
            instead of a preprocessing traceback.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    # Processors normalise with 3-channel statistics; RGBA/greyscale images
    # passed in directly would otherwise fail. A no-op for RGB input.
    img = img.convert("RGB")

    category, category_prob = _top1(cat_ex, cat_model, img)
    out = {f"category: {category}": category_prob}

    if category.strip().lower() == "fruit":
        species, species_prob = _top1(fruit_ex, fruit_model, img)
        out[f"fruit: {species}"] = species_prob
    return out


demo = gr.Interface(
    fn=classify,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
    title="Grocery Category + Fruit Species Recognizer",
    description="Classifies food category, then if it's fruit, fine-grained species.",
)

if __name__ == "__main__":
    # Launch only when run as a script, so importing this module (tests,
    # `gradio` reload mode) doesn't start a server as a side effect.
    demo.launch()