Spaces:
Runtime error
Runtime error
File size: 1,888 Bytes
93aacbf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import torch
import timm
import gradio as gr
import cv2
import json
import numpy as np
from torchvision import transforms
from huggingface_hub import hf_hub_download
# ---------------- CONFIG ---------------- #
# Hugging Face Hub repository holding both the fine-tuned checkpoint and the
# class-index -> class-name mapping.
MODEL_REPO = "vijeshkp/vit_deit_finetune"
# File names inside MODEL_REPO: (weights file, label-map file).
MODEL_FILE, LABEL_FILE = "pytorch_model.bin", "labels.json"
# Side length, in pixels, of the square images fed to the model.
IMG_SIZE = 224
# Torch device string; passed as `map_location` when the checkpoint is loaded.
DEVICE = "cpu"
# ---------------- LOAD LABELS ---------------- #
# Fetch the label map from the Hub (hf_hub_download returns a local path).
labels_path = hf_hub_download(MODEL_REPO, LABEL_FILE)
with open(labels_path, "r", encoding="utf-8") as f:
    labels = json.load(f)

# class_names[i] must be the name of the model's i-th output logit.
# Two label-file layouts are accepted:
#   * a JSON array of names, where the position is the class index;
#   * a JSON object keyed by the stringified index ("0", "1", ...).
# For the object form, a missing index raises KeyError rather than silently
# shifting names onto the wrong logits.
if isinstance(labels, list):
    class_names = list(labels)
else:
    class_names = [labels[str(i)] for i in range(len(labels))]

# An empty map would make the classifier head 0-wide; fail with a clear
# message here instead of an opaque state-dict mismatch later.
if not class_names:
    raise ValueError(f"{LABEL_FILE} in {MODEL_REPO} defines no classes")
# ---------------- LOAD MODEL ---------------- #
model_path = hf_hub_download(MODEL_REPO, MODEL_FILE)

# DeiT-Base/16 architecture with a head sized to the label set. Weights come
# from the fine-tuned checkpoint below, not from timm's pretrained download.
model = timm.create_model(
    "deit_base_patch16_224",
    pretrained=False,
    num_classes=len(class_names),
)

# The checkpoint is a plain state dict, so restrict unpickling to tensors and
# primitive containers: a tampered file pulled from the Hub cannot then run
# arbitrary code at load time. (Requires torch >= 1.13.)
state_dict = torch.load(model_path, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)

# Place the module itself on DEVICE (map_location only affects the loaded
# tensors), then switch to inference mode (disables dropout etc.).
model.to(DEVICE)
model.eval()
# ---------------- TRANSFORM ---------------- #
# Preprocessing pipeline applied to each uploaded image, in order.
_preprocess_steps = [
    # Array -> PIL image so the PIL-based resize can run.
    transforms.ToPILImage(),
    # Force a fixed IMG_SIZE x IMG_SIZE square (aspect ratio is not kept).
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    # PIL -> C x H x W float tensor (uint8 pixels are scaled to [0, 1]).
    transforms.ToTensor(),
    # Per channel (x - 0.5) / 0.5, i.e. [0, 1] -> [-1, 1].
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
transform = transforms.Compose(_preprocess_steps)
# ---------------- PREDICTION FUNCTION ---------------- #
def predict(image):
    """Classify *image* and return the probability of every class.

    Args:
        image: Array delivered by ``gr.Image(type="numpy")``, which is an
            ``H x W x 3`` RGB array. Grayscale (``H x W``) and RGBA
            (``H x W x 4``) arrays are also accepted and reduced to RGB.
            ``None`` (Submit pressed with nothing uploaded) is tolerated.

    Returns:
        ``{class_name: probability}`` for every class, with softmax
        probabilities as Python floats (the format ``gr.Label`` consumes),
        or ``None`` when *image* is ``None``.
    """
    if image is None:
        return None

    img = np.asarray(image)
    if img.ndim == 2:
        # Grayscale: replicate the channel; Normalize expects exactly 3.
        img = np.stack([img] * 3, axis=-1)
    elif img.shape[-1] == 4:
        # RGBA: drop the alpha channel for the same reason.
        img = img[..., :3]
    # Gradio's numpy images are already RGB, so no channel swap is done.
    # A cv2 BGR->RGB conversion here would reverse the channels and hand
    # the model BGR input.
    img = np.ascontiguousarray(img)

    # Follow wherever the model's weights live, so input and model agree.
    device = next(model.parameters()).device
    tensor = transform(img).unsqueeze(0).to(device)  # add batch dim -> 1 x C x H x W

    with torch.no_grad():
        logits = model(tensor)
    probs = torch.softmax(logits, dim=1)[0].tolist()

    return {name: float(p) for name, p in zip(class_names, probs)}
# ---------------- GRADIO UI ---------------- #
# Single-image demo: the upload is passed to predict() as a numpy array, and
# the returned {class_name: probability} dict is shown by gr.Label (top 2).
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy", label="Upload Image"),
    outputs=gr.Label(num_top_classes=2, label="Prediction"),
    title="DeiT Sitting vs Standing Classifier",
    description="Upload a human image to classify posture using a fine-tuned DeiT model.",
)

# Start the server only when executed as a script (e.g. `python app.py`).
# Importing this module just builds `demo` without launching it.
if __name__ == "__main__":
    demo.launch()
|