vijeshkp commited on
Commit
93aacbf
·
verified ·
1 Parent(s): ada209e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -0
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import timm
3
+ import gradio as gr
4
+ import cv2
5
+ import json
6
+ import numpy as np
7
+ from torchvision import transforms
8
+ from huggingface_hub import hf_hub_download
9
+
10
+ # ---------------- CONFIG ---------------- #
11
+ MODEL_REPO = "vijeshkp/vit_deit_finetune"
12
+ MODEL_FILE = "pytorch_model.bin"
13
+ LABEL_FILE = "labels.json"
14
+ IMG_SIZE = 224
15
+ DEVICE = "cpu"
16
+
17
+ # ---------------- LOAD LABELS ---------------- #
18
+ labels_path = hf_hub_download(MODEL_REPO, LABEL_FILE)
19
+ with open(labels_path, "r") as f:
20
+ labels = json.load(f)
21
+ class_names = [labels[str(i)] for i in range(len(labels))]
22
+
23
+ # ---------------- LOAD MODEL ---------------- #
24
+ model_path = hf_hub_download(MODEL_REPO, MODEL_FILE)
25
+
26
+ model = timm.create_model(
27
+ "deit_base_patch16_224",
28
+ pretrained=False,
29
+ num_classes=len(class_names)
30
+ )
31
+ model.load_state_dict(torch.load(model_path, map_location=DEVICE))
32
+ model.eval()
33
+
34
+ # ---------------- TRANSFORM ---------------- #
35
+ transform = transforms.Compose([
36
+ transforms.ToPILImage(),
37
+ transforms.Resize((IMG_SIZE, IMG_SIZE)),
38
+ transforms.ToTensor(),
39
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
40
+ ])
41
+
42
+ # ---------------- PREDICTION FUNCTION ---------------- #
43
+ def predict(image):
44
+ img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
45
+ tensor = transform(img).unsqueeze(0)
46
+
47
+ with torch.no_grad():
48
+ logits = model(tensor)
49
+ probs = torch.softmax(logits, dim=1)[0]
50
+
51
+ pred_idx = torch.argmax(probs).item()
52
+ return {
53
+ class_names[i]: float(probs[i])
54
+ for i in range(len(class_names))
55
+ }
56
+
57
+ # ---------------- GRADIO UI ---------------- #
58
+ demo = gr.Interface(
59
+ fn=predict,
60
+ inputs=gr.Image(type="numpy", label="Upload Image"),
61
+ outputs=gr.Label(num_top_classes=2, label="Prediction"),
62
+ title="DeiT Sitting vs Standing Classifier",
63
+ description="Upload a human image to classify posture using a fine-tuned DeiT model."
64
+ )
65
+
66
+ demo.launch()