LuxExistentia commited on
Commit
a783b6e
·
1 Parent(s): 8185d89

Added app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import time
3
+ from PIL import Image
4
+ import torch
5
+ from torch import nn
6
+ import timm
7
+ from custom_torch_module import setup_utils
8
+
9
+
10
+
11
+ title = "Age Prediction model"
12
+ description = "ViT(medium clip) based model. transfer trained with custom dataset"
13
+ article = "Through bunch of fine tuning and experiments. REMEMBER! This model can be wrong."
14
+
15
+ MODEL_NAME = "vit_medium_patch16_clip_224.tinyclip_yfcc15m"
16
+ FILE_NAME = "pretrained_weight/vit_medium_patch16_clip_224.tinyclip_yfcc15m(trainable 0.00) (eval Score 0.9067, Loss 29.465482).pth"
17
+ DEVICE = "cpu"
18
+
19
+
20
+
21
+
22
+ torch.set_default_device(DEVICE)
23
+
24
+ model = timm.create_model(MODEL_NAME, pretrained=True, num_classes=0, drop_rate=0.7)
25
+ model_classifier = nn.Sequential(nn.Linear(512, 512),
26
+ nn.BatchNorm1d(512),
27
+ nn.GELU(),
28
+ nn.Linear(512, 1))
29
+ model = nn.Sequential(model, model_classifier)
30
+
31
+ test_transform = setup_utils.build_transform(img_size=224, is_data_aug=False)
32
+ model.load_state_dict(state_dict=torch.load(FILE_NAME, weights_only=True))
33
+
34
+
35
+
36
+
37
+ def predict(img):
38
+ start_time = time.time()
39
+ model.eval()
40
+ with torch.inference_mode():
41
+ img = test_transform(img).unsqueeze(dim=0).to(DEVICE)
42
+ pred_age = model(img).item()
43
+
44
+ end_time = time.time()
45
+
46
+ elapsed_time = end_time - start_time
47
+ fps = 1 / elapsed_time
48
+ return pred_age, fps
49
+
50
+
51
+
52
+ demo = gr.Interface(fn=predict,
53
+ inputs=gr.Image(type="pil"),
54
+ outputs=[gr.Number(label="Age Prediction"),
55
+ gr.Number(label="Prediction speed (fps)")],
56
+ title=title,
57
+ description=description,
58
+ article=article)
59
+ demo.launch()