"""Gradio demo: classify a person's posture with a fine-tuned DeiT model.

The fine-tuned weights and the class-index -> label mapping are downloaded
from the Hugging Face Hub when this module is imported.
"""

import json

import cv2
import gradio as gr
import numpy as np
import timm
import torch
from huggingface_hub import hf_hub_download
from torchvision import transforms

# ---------------- CONFIG ---------------- #
MODEL_REPO = "vijeshkp/vit_deit_finetune"
MODEL_FILE = "pytorch_model.bin"
LABEL_FILE = "labels.json"
IMG_SIZE = 224
DEVICE = "cpu"

# ---------------- LOAD LABELS ---------------- #
labels_path = hf_hub_download(MODEL_REPO, LABEL_FILE)
with open(labels_path, "r", encoding="utf-8") as f:
    labels = json.load(f)

# Accept either an index-keyed object ({"0": "...", "1": "..."}) or a plain
# JSON list; either way class_names[i] is the label for model output i.
if isinstance(labels, list):
    class_names = list(labels)
else:
    class_names = [labels[str(i)] for i in range(len(labels))]

# ---------------- LOAD MODEL ---------------- #
model_path = hf_hub_download(MODEL_REPO, MODEL_FILE)
model = timm.create_model(
    "deit_base_patch16_224",
    pretrained=False,  # weights come from the fine-tuned checkpoint below
    num_classes=len(class_names),
)
# NOTE(security): torch.load unpickles the downloaded file. MODEL_REPO is a
# fixed source here; if it can ever be user-controlled, pass
# weights_only=True (torch >= 1.13) so arbitrary objects are refused.
state_dict = torch.load(model_path, map_location=DEVICE)
model.load_state_dict(state_dict)
model.to(DEVICE)  # keep parameters on DEVICE so changing it actually moves compute
model.eval()

# ---------------- TRANSFORM ---------------- #
# Resize to the square input size and scale pixels from [0, 1] to [-1, 1]
# (mean = std = 0.5). Presumably this matches the fine-tuning preprocessing —
# TODO confirm against the training script.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


# ---------------- PREDICTION FUNCTION ---------------- #
def _to_rgb(image: np.ndarray) -> np.ndarray:
    """Return *image* with exactly three RGB channels (grayscale/RGBA are converted)."""
    if image.ndim == 2:  # single-channel grayscale
        return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    if image.shape[-1] == 4:  # RGBA: drop alpha, Normalize expects 3 channels
        return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    return image


def predict(image):
    """Classify the posture shown in *image*.

    Args:
        image: Array delivered by ``gr.Image(type="numpy")`` (RGB channel
            order, H x W x C), or None when nothing was uploaded.

    Returns:
        Dict mapping every class name to its softmax probability, the format
        ``gr.Label`` expects.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # Gradio's numpy images are already RGB, so no BGR->RGB swap is applied:
    # swapping here would hand the model red/blue-inverted input.
    rgb = _to_rgb(image)
    tensor = transform(rgb).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logits = model(tensor)
    probs = torch.softmax(logits, dim=1)[0]

    return {name: float(p) for name, p in zip(class_names, probs.tolist())}


# ---------------- GRADIO UI ---------------- #
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy", image_mode="RGB", label="Upload Image"),
    outputs=gr.Label(num_top_classes=2, label="Prediction"),
    title="DeiT Sitting vs Standing Classifier",
    description="Upload a human image to classify posture using a fine-tuned DeiT model.",
)

if __name__ == "__main__":
    demo.launch()