LuxExistentia commited on
Commit
8185d89
·
1 Parent(s): 359cd8e

Check point

Browse files
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ .ipynb_checkpoints
2
+ __pycache__
3
+ .pyc
4
+ .DS_Store
dataset.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import Dataset
2
+ from PIL import Image
3
+ import os
4
+ import re
5
+ import torch
6
+
7
+ class AgeDataset(Dataset):
8
+ def __init__(self, target_dir, transform=None):
9
+ input_file_format = ('png', 'jpg', 'jpeg')
10
+
11
+ self.target_dir = target_dir
12
+ self.paths = [f for f in os.listdir(self.target_dir) if f.lower().endswith(input_file_format)]
13
+ self.transform = transform
14
+
15
+ def load_img(self, idx):
16
+ img_path = os.path.join(self.target_dir, self.paths[idx])
17
+ target_value = re.search(r"(\d+)_", self.paths[idx])
18
+ target_value = int(target_value.group(1))
19
+
20
+ return Image.open(img_path).convert("RGB"), torch.tensor(target_value, dtype=torch.float).unsqueeze(dim=0)
21
+
22
+ def __len__(self):
23
+ return len(self.paths)
24
+
25
+ def __getitem__(self, idx):
26
+ img, target_value = self.load_img(idx)
27
+
28
+ if self.transform:
29
+ img = self.transform(img)
30
+
31
+ return img, target_value
pretrained_weight/vit_medium_patch16_clip_224.tinyclip_yfcc15m(trainable 0.00) (eval Score 0.9067, Loss 29.465482).pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:395d85bd8ff805c07120fd64407dbc47e1352d3ddc6876a63b309ce9072a0100
3
+ size 154448566
test_gradio.ipynb ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "9782f4d6-2dc8-44a7-bb91-6636a71415e8",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import gradio as gr\n",
11
+ "import time\n",
12
+ "from PIL import Image\n",
13
+ "import torch\n",
14
+ "from torch import nn\n",
15
+ "import timm\n",
16
+ "from custom_torch_module import setup_utils"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 2,
22
+ "id": "7cfc4e19-e2c5-41d0-b105-9e4a661001c4",
23
+ "metadata": {},
24
+ "outputs": [],
25
+ "source": [
26
+ "title = \"Age Prediction model\"\n",
27
+ "description = \"ViT(medium clip) based model. transfer trained with custom dataset\"\n",
28
+ "article = \"Through bunch of fine tuning and experiments. REMEMBER! This model can be wrong.\""
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": 3,
34
+ "id": "504cd630-6e4e-43b8-b444-d49775bbcf5d",
35
+ "metadata": {},
36
+ "outputs": [
37
+ {
38
+ "data": {
39
+ "text/plain": [
40
+ "<All keys matched successfully>"
41
+ ]
42
+ },
43
+ "execution_count": 3,
44
+ "metadata": {},
45
+ "output_type": "execute_result"
46
+ }
47
+ ],
48
+ "source": [
49
+ "MODEL_NAME = \"vit_medium_patch16_clip_224.tinyclip_yfcc15m\"\n",
50
+ "FILE_NAME = \"pretrained_weight/vit_medium_patch16_clip_224.tinyclip_yfcc15m(trainable 0.00) (eval Score 0.9067, Loss 29.465482).pth\"\n",
51
+ "DEVICE = \"cpu\"\n",
52
+ "\n",
53
+ "torch.set_default_device(DEVICE)\n",
54
+ "\n",
55
+ "model = timm.create_model(MODEL_NAME, pretrained=True, num_classes=0, drop_rate=0.7)\n",
56
+ "\n",
57
+ "model_classifier = nn.Sequential(nn.Linear(512, 512),\n",
58
+ " nn.BatchNorm1d(512),\n",
59
+ " nn.GELU(),\n",
60
+ " nn.Linear(512, 1))\n",
61
+ "\n",
62
+ "model = nn.Sequential(model, model_classifier)\n",
63
+ "\n",
64
+ "test_transform = setup_utils.build_transform(img_size=224, is_data_aug=False)\n",
65
+ "\n",
66
+ "model.load_state_dict(state_dict=torch.load(FILE_NAME, weights_only=True))"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": 4,
72
+ "id": "1dba4088-961a-4f14-8d1e-2f9044a652e3",
73
+ "metadata": {},
74
+ "outputs": [],
75
+ "source": [
76
+ "def predict(img):\n",
77
+ " start_time = time.time()\n",
78
+ " model.eval()\n",
79
+ " with torch.inference_mode():\n",
80
+ " img = test_transform(img).unsqueeze(dim=0).to(DEVICE)\n",
81
+ " pred_age = model(img).item()\n",
82
+ " \n",
83
+ " end_time = time.time()\n",
84
+ " \n",
85
+ " elapsed_time = end_time - start_time\n",
86
+ " fps = 1 / elapsed_time\n",
87
+ " return pred_age, fps\n",
88
+ "\n",
89
+ "# img = Image.open(img_path[0]).convert(\"RGB\")\n",
90
+ "# pred_label_and_probs, elapsed_time = predict(img)"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": null,
96
+ "id": "63313e6f-f86f-4d42-a8e1-986fa394e51a",
97
+ "metadata": {},
98
+ "outputs": [
99
+ {
100
+ "name": "stdout",
101
+ "output_type": "stream",
102
+ "text": [
103
+ "Running on local URL: http://127.0.0.1:7860\n",
104
+ "Running on public URL: https://3b12d13cbee982d501.gradio.live\n",
105
+ "\n",
106
+ "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
107
+ ]
108
+ },
109
+ {
110
+ "data": {
111
+ "text/html": [
112
+ "<div><iframe src=\"https://3b12d13cbee982d501.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
113
+ ],
114
+ "text/plain": [
115
+ "<IPython.core.display.HTML object>"
116
+ ]
117
+ },
118
+ "metadata": {},
119
+ "output_type": "display_data"
120
+ }
121
+ ],
122
+ "source": [
123
+ "# Create the Gradio demo\n",
124
+ "demo = gr.Interface(fn=predict, \n",
125
+ " inputs=gr.Image(type=\"pil\"),\n",
126
+ " outputs=[gr.Number(label=\"Age Prediction\"),\n",
127
+ " gr.Number(label=\"Prediction speed (fps)\")], \n",
128
+ " title=title,\n",
129
+ " description=description,\n",
130
+ " article=article)\n",
131
+ "\n",
132
+ "# Launch the demo!\n",
133
+ "demo.launch(debug=True, # print errors locally?\n",
134
+ " share=True) # generate a publically shareable URL?"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": null,
140
+ "id": "9a17b040-b662-45ca-92e7-517c52bc5950",
141
+ "metadata": {},
142
+ "outputs": [],
143
+ "source": []
144
+ },
145
+ {
146
+ "cell_type": "code",
147
+ "execution_count": null,
148
+ "id": "f58e0cc1-22b1-4dcc-935a-4d9c98961c9e",
149
+ "metadata": {},
150
+ "outputs": [],
151
+ "source": []
152
+ }
153
+ ],
154
+ "metadata": {
155
+ "kernelspec": {
156
+ "display_name": "Python 3 (ipykernel)",
157
+ "language": "python",
158
+ "name": "python3"
159
+ },
160
+ "language_info": {
161
+ "codemirror_mode": {
162
+ "name": "ipython",
163
+ "version": 3
164
+ },
165
+ "file_extension": ".py",
166
+ "mimetype": "text/x-python",
167
+ "name": "python",
168
+ "nbconvert_exporter": "python",
169
+ "pygments_lexer": "ipython3",
170
+ "version": "3.12.3"
171
+ }
172
+ },
173
+ "nbformat": 4,
174
+ "nbformat_minor": 5
175
+ }