# app.py — created by vijeshkp ("Create app.py", Hugging Face commit 93aacbf)
import torch
import timm
import gradio as gr
import cv2
import json
import numpy as np
from torchvision import transforms
from huggingface_hub import hf_hub_download
# ---------------- CONFIG ---------------- #
# Hugging Face Hub repository that hosts both the fine-tuned weights and the
# index -> class-name map.
MODEL_REPO: str = "vijeshkp/vit_deit_finetune"

# Artifact names inside MODEL_REPO.
MODEL_FILE: str = "pytorch_model.bin"   # state_dict checkpoint
LABEL_FILE: str = "labels.json"         # class-index -> class-name mapping

# Square side length (pixels) images are resized to; matches the input size
# of the "deit_base_patch16_224" architecture.
IMG_SIZE: int = 224

# Device the checkpoint tensors are mapped onto when loaded.
DEVICE: str = "cpu"
# ---------------- LOAD LABELS ---------------- #
# labels.json maps a class index to its human-readable name. Two layouts are
# accepted:
#   * a JSON object keyed by stringified indices: {"0": "...", "1": "..."}
#     (every index 0..len-1 must be present, otherwise KeyError is raised);
#   * a JSON list whose position is the class index: ["...", "..."].
labels_path = hf_hub_download(MODEL_REPO, LABEL_FILE)
# JSON is UTF-8 by spec; don't depend on the host locale's default encoding.
with open(labels_path, "r", encoding="utf-8") as f:
    labels = json.load(f)

if isinstance(labels, list):
    class_names = list(labels)
else:
    class_names = [labels[str(i)] for i in range(len(labels))]
# ---------------- LOAD MODEL ---------------- #
# Download the fine-tuned checkpoint (a state_dict) from the Hub.
model_path = hf_hub_download(MODEL_REPO, MODEL_FILE)

# Build a DeiT-Base/16 skeleton whose classifier head has one output per
# label, then fill it with the fine-tuned weights below.
model = timm.create_model(
    "deit_base_patch16_224",
    pretrained=False,  # weights come from the checkpoint, not timm's release
    num_classes=len(class_names),
)

# The checkpoint is fetched over the network: weights_only=True restricts
# unpickling to tensors and plain containers, so a tampered file cannot run
# arbitrary code at load time. A state_dict loads fine under this mode.
model.load_state_dict(
    torch.load(model_path, map_location=DEVICE, weights_only=True)
)
model.eval()  # inference mode (e.g. dropout disabled)
# ---------------- TRANSFORM ---------------- #
# Per-channel statistics: (x - 0.5) / 0.5 maps [0, 1] pixels onto [-1, 1].
# NOTE(review): presumably the same normalization used during fine-tuning —
# confirm against the training script.
_NORM_MEAN = (0.5, 0.5, 0.5)
_NORM_STD = (0.5, 0.5, 0.5)

_preprocess_steps = [
    transforms.ToPILImage(),                   # ndarray image -> PIL image
    transforms.Resize((IMG_SIZE, IMG_SIZE)),   # exact square; aspect ratio not kept
    transforms.ToTensor(),                     # PIL -> CxHxW float tensor in [0, 1]
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
transform = transforms.Compose(_preprocess_steps)
# ---------------- PREDICTION FUNCTION ---------------- #
def predict(image):
    """Classify an image and return the probability of every class.

    Args:
        image: numpy array from ``gr.Image(type="numpy")``, i.e. RGB pixels
            (H x W x 3). 2-D grayscale, H x W x 1 and RGBA (H x W x 4) arrays
            are converted to RGB first. Gradio passes ``None`` when the user
            submits without uploading anything.

    Returns:
        dict mapping each class name to its softmax probability (float), the
        format ``gr.Label`` consumes.

    Raises:
        gr.Error: if no image was supplied.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # Gradio already supplies RGB. A BGR->RGB conversion (only correct for
    # cv2.imread output) would swap the red and blue channels here, so only
    # normalise the channel count.
    if image.ndim == 2 or image.shape[2] == 1:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

    batch = transform(image).unsqueeze(0)  # add a leading batch dimension

    with torch.no_grad():
        logits = model(batch)
    probs = torch.softmax(logits, dim=1)[0]  # probabilities for the single image

    return dict(zip(class_names, probs.tolist()))
# ---------------- GRADIO UI ---------------- #
# Input component: the uploaded picture is passed to `predict` as a numpy array.
image_input = gr.Image(type="numpy", label="Upload Image")
# Output component: shows the two highest-probability classes from `predict`.
label_output = gr.Label(num_top_classes=2, label="Prediction")

demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    title="DeiT Sitting vs Standing Classifier",
    description="Upload a human image to classify posture using a fine-tuned DeiT model.",
)

# Start the Gradio web app.
demo.launch()