# Hugging Face Spaces status banner captured with this file: "Runtime error"
import json

import cv2
import gradio as gr
import numpy as np
import timm
import torch
from huggingface_hub import hf_hub_download
from torchvision import transforms
# ---------------- CONFIG ---------------- #
# Hugging Face Hub repository that hosts both the weights and the label map.
MODEL_REPO = "vijeshkp/vit_deit_finetune"
# File names inside MODEL_REPO.
MODEL_FILE = "pytorch_model.bin"
LABEL_FILE = "labels.json"
# Side length (pixels) of the square image fed to the model.
IMG_SIZE = 224
# Device string used as map_location when the checkpoint is loaded.
DEVICE = "cpu"
# ---------------- LOAD LABELS ---------------- #
# Download the label map from the Hub and build `class_names`, a list whose
# position i holds the display name of class index i.
labels_path = hf_hub_download(MODEL_REPO, LABEL_FILE)
# JSON is UTF-8 by specification; do not depend on the platform's locale.
with open(labels_path, "r", encoding="utf-8") as f:
    labels = json.load(f)

if isinstance(labels, list):
    # A plain JSON array is already ordered by class index.
    class_names = list(labels)
else:
    # Mapping keyed by the stringified class index ("0", "1", ...).
    # A missing index raises KeyError rather than silently shifting labels.
    class_names = [labels[str(i)] for i in range(len(labels))]
# ---------------- LOAD MODEL ---------------- #
# Build an untrained DeiT-base backbone with one output per label, then load
# the fine-tuned weights downloaded from the Hub.
model_path = hf_hub_download(MODEL_REPO, MODEL_FILE)
model = timm.create_model(
    "deit_base_patch16_224",
    pretrained=False,
    num_classes=len(class_names),
)

# NOTE(security): torch.load unpickles the file. Only point MODEL_REPO at a
# repository you trust, or pass weights_only=True where the installed torch
# version supports it and the checkpoint contains tensors only.
state_dict = torch.load(model_path, map_location=DEVICE)

# Some training scripts save {"state_dict": {...}, ...} instead of the bare
# parameter mapping; unwrap that form. A bare DeiT state dict has no such key.
if isinstance(state_dict, dict) and isinstance(state_dict.get("state_dict"), dict):
    state_dict = state_dict["state_dict"]

model.load_state_dict(state_dict)
# Keep the module on the same device the weights were mapped to.
model.to(DEVICE)
# Inference mode: disables dropout and other training-only behaviour.
model.eval()
# ---------------- TRANSFORM ---------------- #
# Preprocessing applied to every input array before it reaches the model:
# array -> PIL image -> IMG_SIZE x IMG_SIZE -> tensor -> per-channel normalise.
# NOTE(review): mean/std of 0.5 are presumably the values used during
# fine-tuning -- confirm against the training pipeline.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# ---------------- PREDICTION FUNCTION ---------------- #
def predict(image):
    """Classify one image and return a probability for every class.

    Args:
        image: Image as a NumPy array, as delivered by the
            ``gr.Image(type="numpy")`` input, or ``None`` when the user
            submits without uploading anything.

    Returns:
        dict mapping each entry of ``class_names`` to its softmax probability
        as a Python float. An empty dict is returned when ``image`` is None.
    """
    if image is None:
        # Nothing uploaded: return an empty result instead of crashing.
        return {}

    img = np.asarray(image)
    # Gradio's numpy image input is already RGB, so no BGR->RGB swap is done
    # here; swapping would exchange the red and blue channels.
    if img.ndim == 2:
        # Grayscale: replicate the single channel so Normalize sees 3 channels.
        img = np.stack([img] * 3, axis=-1)
    elif img.ndim == 3 and img.shape[2] == 4:
        # RGBA: drop the alpha channel.
        img = np.ascontiguousarray(img[:, :, :3])

    # Add the batch dimension and place the input on the model's device.
    tensor = transform(img).unsqueeze(0)
    tensor = tensor.to(next(model.parameters()).device)

    with torch.no_grad():
        logits = model(tensor)

    probs = torch.softmax(logits, dim=1)[0].tolist()
    # The model is built with one output per class name, so the two sequences
    # line up index for index.
    return dict(zip(class_names, probs))
# ---------------- GRADIO UI ---------------- #
# Single-image classifier UI: an image upload wired to `predict`, with the
# two highest-probability classes shown in a label widget.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy", label="Upload Image"),
    outputs=gr.Label(num_top_classes=2, label="Prediction"),
    title="DeiT Sitting vs Standing Classifier",
    description="Upload a human image to classify posture using a fine-tuned DeiT model.",
)

# Start the web server when this file is executed.
demo.launch()