kapil committed on
Commit
9874885
·
0 Parent(s):

Initial commit

Browse files
Files changed (16) hide show
  1. .gitattributes +5 -0
  2. .gitignore +27 -0
  3. Cargo.lock +0 -0
  4. Cargo.toml +15 -0
  5. dataset/annotate.py +410 -0
  6. dataset/labels.json +0 -0
  7. dataset/labels.pkl +3 -0
  8. src/data.rs +118 -0
  9. src/inference.rs +49 -0
  10. src/loss.rs +44 -0
  11. src/main.rs +88 -0
  12. src/model.rs +88 -0
  13. src/scoring.rs +89 -0
  14. src/server.rs +185 -0
  15. src/train.rs +106 -0
  16. static/index.html +350 -0
.gitattributes ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Auto detect text files and perform LF normalization
2
+ * text=auto
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.png filter=lfs diff=lfs merge=lfs -text
5
+ *.pkl filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Rust
2
+ target/
3
+ debug/
4
+ release/
5
+
6
+ # IDEs
7
+ .vscode/
8
+ .idea/
9
+ *.swp
10
+ *.swo
11
+ *~
12
+ .DS_Store
13
+
14
+ # Data & Model Weights
15
+ model_weights.bin
16
+ model_weights/
17
+ dataset/images/
18
+ dataset/cropped_images/
19
+ dataset/800/
20
+ dataset/__pycache__/
21
+
22
+ # Logs & Temp
23
+ *.log
24
+ logs/
25
+ tmp/
26
+ temp/
27
+ *.tmp
Cargo.lock ADDED
The diff for this file is too large to render. See raw diff
 
Cargo.toml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [package]
2
+ name = "rust_auto_score_engine"
3
+ version = "0.1.0"
4
+ edition = "2021"
5
+
6
+ [dependencies]
7
+ burn = { version = "0.16.0", features = ["train", "wgpu"] }
8
+ serde = { version = "1.0", features = ["derive"] }
9
+ serde_json = "1.0"
10
+ image = "0.25"
11
+ ndarray = "0.16"
12
+ axum = { version = "0.7", features = ["multipart"] }
13
+ tower-http = { version = "0.5", features = ["fs", "cors"] }
14
+ tokio = { version = "1.0", features = ["full"] }
15
+ base64 = "0.22"
dataset/annotate.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as osp
3
+ import cv2
4
+ import pandas as pd
5
+ import numpy as np
6
+ from yacs.config import CfgNode as CN
7
+ import argparse
8
+
9
# Maps an 18-degree angular sector index (0-19, counter-clockwise) to the
# dartboard number printed on that sector; used to convert a dart angle
# into a board number.
_SECTOR_NUMBERS = ('13', '4', '18', '1', '20', '5', '12', '9', '14', '11',
                   '8', '16', '7', '19', '3', '17', '2', '15', '10', '6')
BOARD_DICT = dict(enumerate(_SECTOR_NUMBERS))
14
+
15
+
16
def crop_board(img_path, bbox=None, crop_info=(0, 0, 0), crop_pad=1.1):
    """Crop the dartboard region out of an image.

    When *bbox* is None it is derived from *crop_info* (center x, center y,
    radius), with the radius inflated by *crop_pad*.  Returns the cropped
    image and the bbox actually used, as [y1, y2, x1, x2].
    """
    img = cv2.imread(img_path)
    if bbox is None:
        cx, cy, radius = crop_info
        radius = int(radius * crop_pad)
        bbox = [cy - radius, cy + radius, cx - radius, cx + radius]
    y1, y2, x1, x2 = bbox
    crop = img[y1:y2, x1:x2]
    return crop, bbox
24
+
25
+
26
def on_click(event, x, y, flags, param):
    """Mouse callback: record up to 7 annotation points, normalized to [0, 1]."""
    global xy, img_copy
    h, w = img_copy.shape[:2]
    if event != cv2.EVENT_LBUTTONDOWN:
        return
    if len(xy) >= 7:
        print('Already annotated 7 points.')
        return
    xy.append([x / w, y / h])
    print_xy()
35
+
36
+
37
def print_xy():
    """Print the label and coordinates of the most recently added point."""
    global xy
    labels = ('cal_1', 'cal_2', 'cal_3', 'cal_4', 'dart_1', 'dart_2', 'dart_3')
    print('{}: {}'.format(labels[len(xy) - 1], xy[-1]))
43
+
44
+
45
def get_ellipses(xy, r_double=0.17, r_treble=0.1074):
    """Fit the double and treble rings as ellipses from 4 calibration points.

    Parameters
    ----------
    xy : array-like, shape (>=4, 2)
        First four rows are the calibration points on the double ring
        (opposing pairs: rows 0/1 and rows 2/3).
    r_double, r_treble : float
        Physical ring radii; only their ratio is used to scale the treble
        ellipse from the double ellipse.

    Returns the center, [double semi-axes], [treble semi-axes], and the
    ring rotation angle in degrees.

    Fix vs. original: the original mixed list-style indexing (xy[2][0]) with
    ndarray-only indexing (xy[3, 1]) and crashed on plain Python lists; the
    input is now normalized to an ndarray and indexed consistently.
    """
    xy = np.asarray(xy, dtype=float)
    c = np.mean(xy[:4], axis=0)
    # Semi-axes: half the distance between each pair of opposing points.
    a1_double = np.hypot(xy[2, 0] - xy[3, 0], xy[2, 1] - xy[3, 1]) / 2
    a2_double = np.hypot(xy[0, 0] - xy[1, 0], xy[0, 1] - xy[1, 1]) / 2
    scale = r_treble / r_double
    a1_treble = a1_double * scale
    a2_treble = a2_double * scale
    # Board rotation, taken from the 4th calibration point relative to center.
    angle = np.arctan((xy[3, 1] - c[1]) / (xy[3, 0] - c[0])) / np.pi * 180
    return c, [a1_double, a2_double], [a1_treble, a2_treble], angle
53
+
54
+
55
def draw_ellipses(img, xy, num_pts=7):
    """Draw the double and treble rings on *img* (must be uint8; drawn in place).

    *xy* may be normalized (mean < 1) or pixel coordinates; a flat array with
    more than *num_pts* entries is reshaped to (-1, 2) first.

    Fix vs. original: the original recomputed `angle` right after
    get_ellipses had already returned the identical value (dead overwrite);
    the redundant line is removed.
    """
    xy = np.array(xy)
    if xy.shape[0] > num_pts:
        xy = xy.reshape((-1, 2))
    if np.mean(xy) < 1:  # normalized coords -> scale to pixels
        h, w = img.shape[:2]
        xy[:, 0] *= w
        xy[:, 1] *= h
    c, a_double, a_treble, angle = get_ellipses(xy)
    center = (int(round(c[0])), int(round(c[1])))
    cv2.ellipse(img, center,
                (int(round(a_double[0])), int(round(a_double[1]))),
                int(round(angle)), 0, 360, (255, 255, 255))
    cv2.ellipse(img, center,
                (int(round(a_treble[0])), int(round(a_treble[1]))),
                int(round(angle)), 0, 360, (255, 255, 255))
    return img
73
+
74
+
75
def get_circle(xy):
    """Return the center of the 4 calibration points and their mean radius."""
    center = np.mean(xy[:4], axis=0)
    radius = np.linalg.norm(xy[:4] - center, axis=-1).mean()
    return center, radius
79
+
80
+
81
def board_radii(r_d, cfg):
    """Scale the configured board measurements to pixels.

    *r_d* is the double-ring radius in pixels; cfg.board holds the physical
    measurements.  Returns (treble radius, outer-bull radius, inner-bull
    radius, double/treble band width), all in pixels.
    """
    board = cfg.board
    px_per_unit = r_d / board.r_double  # pixels per physical unit
    r_t = board.r_treble * px_per_unit          # treble radius, in px
    r_ib = board.r_inner_bull * px_per_unit     # inner bull radius, in px
    r_ob = board.r_outer_bull * px_per_unit     # outer bull radius, in px
    w_dt = board.w_double_treble * px_per_unit  # width of double and treble
    return r_t, r_ob, r_ib, w_dt
87
+
88
+
89
def draw_circles(img, xy, cfg, color=(255, 255, 255)):
    """Overlay all scoring-ring circles on *img* (keypoints in pixel space)."""
    center, r_d = get_circle(xy)  # double-ring radius, px
    r_t, r_ob, r_ib, w_dt = board_radii(r_d, cfg)
    cx, cy = round(center[0]), round(center[1])
    for radius in (r_d, r_d - w_dt, r_t, r_t - w_dt, r_ib, r_ob):
        cv2.circle(img, (cx, cy), round(radius), color)
    return img
95
+
96
+
97
def transform(xy, img=None, angle=9, M=None):
    """Warp keypoints (and optionally the image) so the 4 calibration points
    land on an upright circle rotated by *angle* degrees.

    Parameters
    ----------
    xy : ndarray, shape (N, 2) or (N, 3)
        Keypoints; an optional third column is a visibility flag passed
        through untouched.  Rows 0-3 are the calibration points.
    img : ndarray, optional
        If given, it is warped with the same homography, and normalized
        coordinates are converted to pixels before / back after the warp.
    angle : float
        Target rotation of the calibration square, in degrees.
    M : ndarray, optional
        Precomputed 3x3 homography; derived from the points when None.

    Returns (transformed xy, warped img or None, homography M).

    Fix vs. original: `w`/`h` were bound only when the coordinates needed
    normalizing, but were used later whenever `img` was given, raising
    NameError for pixel-space inputs with an image.
    """
    if xy.shape[-1] == 3:
        has_vis = True
        vis = xy[:, 2:]
        xy = xy[:, :2]
    else:
        has_vis = False

    if img is not None:
        # Bind the image size up front: it is needed again after the warp
        # even when the incoming coordinates were already in pixels.
        h, w = img.shape[:2]
        if np.mean(xy[:4]) < 1:  # normalized coords -> pixels
            xy *= [[w, h]]

    if M is None:
        # c is the center of the 4 calibration points, r their mean distance
        # from that center (not necessarily a true circle).
        c, r = get_circle(xy)
        src_pts = xy[:4].astype(np.float32)
        dst_pts = np.array([
            [c[0] - r * np.sin(np.deg2rad(angle)), c[1] - r * np.cos(np.deg2rad(angle))],
            [c[0] + r * np.sin(np.deg2rad(angle)), c[1] + r * np.cos(np.deg2rad(angle))],
            [c[0] - r * np.cos(np.deg2rad(angle)), c[1] + r * np.sin(np.deg2rad(angle))],
            [c[0] + r * np.cos(np.deg2rad(angle)), c[1] - r * np.sin(np.deg2rad(angle))]
        ]).astype(np.float32)
        M = cv2.getPerspectiveTransform(src_pts, dst_pts)

    # Apply the homography to the points in homogeneous coordinates.
    xyz = np.concatenate((xy, np.ones((xy.shape[0], 1))), axis=-1).astype(np.float32)
    xyz_dst = np.matmul(M, xyz.T).T
    xy_dst = xyz_dst[:, :2] / xyz_dst[:, 2:]

    if img is not None:
        img = cv2.warpPerspective(img.copy(), M, (img.shape[1], img.shape[0]))
        xy_dst /= [[w, h]]  # back to normalized coordinates

    if has_vis:
        xy_dst = np.concatenate([xy_dst, vis], axis=-1)

    return xy_dst, img, M
135
+
136
+
137
def get_dart_scores(xy, cfg, numeric=False):
    """Score each dart in *xy* against the board geometry in *cfg*.

    Rows 0-3 of *xy* are the calibration points; rows 4+ are dart tips.
    Returns a list of score strings ('T20', 'D5', 'B', 'DB', '0', ...) or,
    with numeric=True, the equivalent integer point values.  Returns []
    when any calibration point is missing.
    """
    # A calibration point with any non-positive coordinate counts as
    # missing; scoring is impossible without all four.
    valid_cal_pts = xy[:4][(xy[:4, 0] > 0) & (xy[:4, 1] > 0)]
    if xy.shape[0] <= 4 or valid_cal_pts.shape[0] < 4:  # missing calibration point
        return []
    # Undo the camera perspective so ring membership is a pure radius test.
    xy, _, _ = transform(xy.copy(), angle=0)
    c, r_d = get_circle(xy)
    r_t, r_ob, r_ib, w_dt = board_radii(r_d, cfg)
    xy -= c  # center the board at the origin
    # Dart angles measured CCW from 3 o'clock (image y axis points down,
    # hence the negated y).
    angles = np.arctan2(-xy[4:, 1], xy[4:, 0]) / np.pi * 180
    angles = [a + 360 if a < 0 else a for a in angles]  # map to 0-360
    distances = np.linalg.norm(xy[4:], axis=-1)
    scores = []
    for angle, dist in zip(angles, distances):
        if dist > r_d:
            scores.append('0')  # outside the double ring: no score
        elif dist <= r_ib:
            scores.append('DB')  # inner (double) bull
        elif dist <= r_ob:
            scores.append('B')  # outer bull
        else:
            # Each number occupies an 18-degree sector.
            number = BOARD_DICT[int(angle / 18)]
            if dist <= r_d and dist > r_d - w_dt:
                scores.append('D' + number)  # inside the double band
            elif dist <= r_t and dist > r_t - w_dt:
                scores.append('T' + number)  # inside the treble band
            else:
                scores.append(number)
    if numeric:
        # Convert the string labels to integer point values in place.
        for i, s in enumerate(scores):
            if 'B' in s:
                if 'D' in s:
                    scores[i] = 50
                else:
                    scores[i] = 25
            else:
                if 'D' in s or 'T' in s:
                    scores[i] = int(s[1:])
                    scores[i] = scores[i] * 2 if 'D' in s else scores[i] * 3
                else:
                    scores[i] = int(s)
    return scores
178
+
179
+
180
def draw(img, xy, cfg, circles, score, color=(255, 255, 0)):
    """Render keypoints (and optionally rings / dart scores) onto *img*.

    Calibration points (first 4) are drawn green with their index; dart
    points are drawn in *color* with either their score (score truthy) or
    their index.

    Fix vs. original: `font`, `font_scale` and `line_type` were bound only
    inside the `score` branch but used for every labeled point, raising
    NameError when score was falsy or fewer than 5 points existed; they are
    now bound unconditionally.
    """
    xy = np.array(xy)
    if xy.shape[0] > 7:
        xy = xy.reshape((-1, 2))
    if np.mean(xy) < 1:  # normalized -> pixel coordinates
        h, w = img.shape[:2]
        xy[:, 0] *= w
        xy[:, 1] *= h
    if xy.shape[0] >= 4 and circles:
        img = draw_circles(img, xy, cfg)
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    line_type = 1
    scores = []
    if xy.shape[0] > 4 and score:
        scores = get_dart_scores(xy, cfg)
    for i, [x, y] in enumerate(xy):
        if i < 4:
            c = (0, 255, 0)  # green for calibration points
        else:
            c = color  # cyan for darts
        x = int(round(x))
        y = int(round(y))
        if i >= 4:
            cv2.circle(img, (x, y), 1, c, 1)
            if score:
                txt = str(scores[i - 4])
            else:
                txt = str(i + 1)
            cv2.putText(img, txt, (x + 8, y), font,
                        font_scale, c, line_type)
        else:
            cv2.circle(img, (x, y), 1, c, 1)
            cv2.putText(img, str(i + 1), (x + 8, y), font,
                        font_scale, c, line_type)
    return img
215
+
216
+
217
def adjust_xy(idx):
    """Nudge one keypoint (or all, when idx == -1) by one pixel.

    Blocks on cv2.waitKey: numpad keys 4/8/6/2 move left/up/right/down.
    Operates in place on the global normalized point list *xy*, converting
    to pixel space so a nudge is exactly one pixel.
    """
    global xy, img_copy
    key = cv2.waitKey(0) & 0xFF
    xy = np.array(xy)
    h, w = img_copy.shape[:2]
    xy[:, 0] *= w; xy[:, 1] *= h
    # ASCII codes: 52='4' (left), 56='8' (up), 54='6' (right), 50='2' (down)
    deltas = {52: (-1, 0), 56: (0, -1), 54: (1, 0), 50: (0, 1)}
    if key in deltas:
        dx, dy = deltas[key]
        if idx == -1:
            xy[:, 0] += dx
            xy[:, 1] += dy
        else:
            xy[idx, 0] += dx
            xy[idx, 1] += dy
    xy[:, 0] /= w; xy[:, 1] /= h
    xy = xy.tolist()
245
+
246
+
247
def add_last_dart(annot, data_path, folder):
    """Attach per-image 'last dart' labels from <folder>.csv, if present.

    The CSV cells are flattened row-major into a list of strings and stored
    under annot['last_dart'].  *annot* is returned unchanged when no CSV
    exists for this folder.
    """
    csv_path = osp.join(data_path, 'annotations', folder + '.csv')
    if osp.isfile(csv_path):
        table = pd.read_csv(csv_path)
        labels = [str(table.loc[row, col])
                  for row in table.index.values
                  for col in table.columns]
        annot['last_dart'] = labels
    return annot
257
+
258
+
259
def get_bounding_box(img_path, scale=0.2):
    """Interactively pick a bounding box by clicking two corners.

    The image is shown downscaled by *scale*; clicks are mapped back to
    full-resolution pixel coordinates.  Returns [y1, y2, x1, x2].
    Raises AssertionError if the window is quit ('q') before 2 clicks.
    """
    img = cv2.imread(img_path)
    img_resized = cv2.resize(img, None, fx=scale, fy=scale)
    h, w = img_resized.shape[:2]
    xy_bbox = []

    def on_click_bbox(event, x, y, flags, param):
        # Map the click from the resized preview back to original pixels.
        if event == cv2.EVENT_LBUTTONDOWN:
            if len(xy_bbox) < 2:
                xy_bbox.append([
                    round((x / w) * img.shape[1]),
                    round((y / h) * img.shape[0])])

    window = 'get bbox'
    cv2.namedWindow(window)
    cv2.setMouseCallback(window, on_click_bbox)
    while len(xy_bbox) < 2:
        # print(xy_bbox)
        cv2.imshow(window, img_resized)
        key = cv2.waitKey(100)
        if key == ord('q'):  # quit
            cv2.destroyAllWindows()
            break
    cv2.destroyAllWindows()
    assert len(xy_bbox) == 2, 'click 2 points to get bounding box'
    xy_bbox = np.array(xy_bbox)
    # bbox = [y1 y2 x1 x2] -- min/max so click order does not matter
    bbox = [min(xy_bbox[:, 1]), max(xy_bbox[:, 1]), min(xy_bbox[:, 0]), max(xy_bbox[:, 0])]
    return bbox
288
+
289
+
290
def main(cfg, folder, scale, draw_circles, dart_score=True):
    """Interactive annotation loop over all images in one capture folder.

    Key bindings inside the window:
      q quit - b redraw bbox - . next - , previous - z undo point -
      x reset annotation - d delete image (confirm y/n) - a accept -
      1-7 nudge that point, 0 nudge all points.
    Annotations are persisted to <data.path>/annotations/<folder>.pkl.

    Fix vs. original: the reset branch wrote `annot.at[idx, 'xy'] = None,`
    (trailing comma), storing the tuple (None,) instead of None; the later
    `is not None` checks then passed and `.copy()` on the tuple raised
    AttributeError.  The trailing comma is removed.
    """
    global xy, img_copy
    img_dir = osp.join(cfg.data.path, 'images', folder)
    imgs = sorted(os.listdir(img_dir))
    annot_path = osp.join(cfg.data.path, 'annotations', folder + '.pkl')
    if osp.isfile(annot_path):
        annot = pd.read_pickle(annot_path)
    else:
        annot = pd.DataFrame(columns=['img_name', 'bbox', 'xy'])
        annot['img_name'] = imgs
        annot['bbox'] = None
        annot['xy'] = None
        annot = add_last_dart(annot, cfg.data.path, folder)

    # Resume from the last image that already has a bbox annotated.
    i = 0
    for j in range(len(annot)):
        a = annot.iloc[j, :]
        if a['bbox'] is not None:
            i = j

    while i < len(imgs):
        xy = []
        a = annot.iloc[i, :]
        print('Annotating {}'.format(a['img_name']))
        if a['bbox'] is None:
            if i == 0:
                bbox = get_bounding_box(osp.join(img_dir, a['img_name']))
            if i > 0:
                # Seed from the previous image: the camera rarely moves
                # between consecutive shots.
                last_a = annot.iloc[i - 1, :]
                if last_a['xy'] is not None:
                    xy = last_a['xy'].copy()
                else:
                    xy = []
        else:
            bbox, xy = a['bbox'], a['xy']

        crop, _ = crop_board(osp.join(img_dir, a['img_name']), bbox=bbox)
        crop = cv2.resize(crop, (int(crop.shape[1] * scale), int(crop.shape[0] * scale)))
        cv2.putText(crop, '{}/{} {}'.format(i+1, len(annot), a['img_name']), (0, 12), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
        img_copy = crop.copy()

        cv2.namedWindow(folder)
        cv2.setMouseCallback(folder, on_click)
        while True:
            img_copy = draw(img_copy, xy, cfg, draw_circles, dart_score)
            cv2.imshow(folder, img_copy)
            key = cv2.waitKey(100) & 0xFF  # update every 100 ms

            if key == ord('q'):  # quit
                cv2.destroyAllWindows()
                i = len(imgs)
                break

            if key == ord('b'):  # draw new bounding box
                idx = annot[(annot['img_name'] == a['img_name'])].index.values[0]
                annot.at[idx, 'bbox'] = get_bounding_box(osp.join(img_dir, a['img_name']), scale)
                break

            if key == ord('.'):  # next image
                i += 1
                img_copy = crop.copy()
                break

            if key == ord(','):  # previous image
                if i > 0:
                    i += -1
                img_copy = crop.copy()
                break

            if key == ord('z'):  # undo keypoint
                xy = xy[:-1]
                img_copy = crop.copy()

            if key == ord('x'):  # reset annotation
                idx = annot[(annot['img_name'] == a['img_name'])].index.values[0]
                annot.at[idx, 'xy'] = None
                annot.at[idx, 'bbox'] = None
                annot.to_pickle(annot_path)
                break

            if key == ord('d'):  # delete img
                print('Are you sure you want to delete this image? (y/n)')
                key = cv2.waitKey(0) & 0xFF
                if key == ord('y'):
                    idx = annot[(annot['img_name'] == a['img_name'])].index.values[0]
                    annot = annot.drop([idx])
                    annot.to_pickle(annot_path)
                    os.remove(osp.join(img_dir, a['img_name']))
                    print('Deleted image {}'.format(a['img_name']))
                    break
                else:
                    print('Image not deleted.')
                    continue

            if key == ord('a'):  # accept keypoints
                idx = annot[(annot['img_name'] == a['img_name'])].index.values[0]
                annot.at[idx, 'xy'] = xy
                annot.at[idx, 'bbox'] = bbox
                annot.to_pickle(annot_path)
                i += 1
                break

            if key in [ord('1'), ord('2'), ord('3'), ord('4'), ord('5'), ord('6'), ord('7'), ord('0')]:
                adjust_xy(idx=key - 49)  # ord('1') = 49; '0' maps to -1 (all points)
                img_copy = crop.copy()
                continue
396
+
397
+
398
if __name__ == '__main__':
    import sys
    sys.path.append('../../')  # make the project root importable
    parser = argparse.ArgumentParser()
    # Capture-session folder under <data.path>/images/ to annotate.
    parser.add_argument('-f', '--img-folder', default='d2_04_05_2020')
    # Display scale for the annotation window.
    parser.add_argument('-s', '--scale', type=float, default=0.5)
    parser.add_argument('-d', '--draw-circles', action='store_true')
    args = parser.parse_args()

    # Board geometry (ring radii etc.) comes from the training config file.
    cfg = CN(new_allowed=True)
    cfg.merge_from_file('../configs/tiny480_20e.yaml')

    main(cfg, args.img_folder, args.scale, args.draw_circles)
dataset/labels.json ADDED
The diff for this file is too large to render. See raw diff
 
dataset/labels.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1bbde9f5cbfa1d623884c86210154867f99d3589309cc062476884952ac4c935
3
+ size 2791670
src/data.rs ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use burn::prelude::*;
2
+ use serde::{Deserialize, Serialize};
3
+ use std::collections::HashMap;
4
+ use std::fs::File;
5
+ use std::io::BufReader;
6
+
7
/// One labeled image: 4 calibration points followed by the dart tips.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Annotation {
    /// Capture-session folder the image lives in (under `dataset/800/`).
    pub img_folder: String,
    /// Image file name within `img_folder`.
    pub img_name: String,
    /// Crop box from the annotation tool (presumably [y1, y2, x1, x2] to
    /// match dataset/annotate.py -- confirm before relying on the order).
    pub bbox: Vec<i32>,
    /// Normalized keypoints; rows 0-3 are calibration points, the rest darts.
    pub xy: Vec<Vec<f32>>,
}
14
+
15
/// In-memory dataset of all annotations, sorted by image name.
pub struct DartDataset {
    pub annotations: Vec<Annotation>,
    pub base_path: String,
}

impl DartDataset {
    /// Load `labels.json` (a map whose keys are discarded) and sort the
    /// annotations by image name for a deterministic ordering.
    ///
    /// Panics if the file is missing or is not valid JSON.
    pub fn load(json_path: &str, base_path: &str) -> Self {
        let file = File::open(json_path).expect("Labels JSON not found");
        let reader = BufReader::new(file);
        let raw_data: HashMap<String, Annotation> = serde_json::from_reader(reader).expect("JSON parse error");

        let mut annotations: Vec<Annotation> = raw_data.into_values().collect();
        annotations.sort_by(|a, b| a.img_name.cmp(&b.img_name));

        Self {
            annotations,
            base_path: base_path.to_string(),
        }
    }
}
35
+
36
/// Burn `Dataset` adapter: index-based access over the sorted annotations.
impl burn::data::dataset::Dataset<Annotation> for DartDataset {
    fn get(&self, index: usize) -> Option<Annotation> {
        self.annotations.get(index).cloned()
    }

    fn len(&self) -> usize {
        self.annotations.len()
    }
}
45
+
46
/// One training batch: images [B, 3, H, W] and YOLO-style targets
/// [B, 30, grid, grid] (3 anchors x (x, y, w, h, obj, 5 class probs)).
#[derive(Clone, Debug)]
pub struct DartBatch<B: Backend> {
    pub images: Tensor<B, 4>,
    pub targets: Tensor<B, 4>,
}

/// Builds `DartBatch`es on a given device from raw annotations.
#[derive(Clone, Debug)]
pub struct DartBatcher<B: Backend> {
    device: Device<B>,
}
56
+
57
+ use burn::data::dataloader::batcher::Batcher;
58
+
59
/// Burn `Batcher` adapter -- defers to `batch_manual` for the real work.
impl<B: Backend> Batcher<Annotation, DartBatch<B>> for DartBatcher<B> {
    fn batch(&self, items: Vec<Annotation>) -> DartBatch<B> {
        self.batch_manual(items)
    }
}
64
+
65
impl<B: Backend> DartBatcher<B> {
    pub fn new(device: Device<B>) -> Self {
        Self { device }
    }

    /// Assemble a batch: load + resize each image to 416x416 and rasterize
    /// the keypoints into a dense [B, 30, 26, 26] YOLO target tensor.
    ///
    /// NOTE(review): the image path is hard-coded to `dataset/800/...` and
    /// ignores `DartDataset::base_path` -- confirm this is intentional.
    pub fn batch_manual(&self, items: Vec<Annotation>) -> DartBatch<B> {
        let batch_size = items.len();
        let input_res: usize = 416; // Standard YOLO 416 resolution for GPU stability
        let grid_size: usize = 26; // 416 / 16 (stride accumulation) = 26
        let num_channels: usize = 30; // 3 anchors * (x,y,w,h,obj,p0...p4)

        let mut images_list = Vec::with_capacity(batch_size);
        // Flat target buffer laid out as [batch, channel, gy, gx].
        let mut target_raw = vec![0.0f32; batch_size * num_channels * grid_size * grid_size];

        for (b_idx, item) in items.iter().enumerate() {
            // 1. Process Image
            // Missing files fall back to a black image instead of panicking.
            let path = format!("dataset/800/{}/{}", item.img_folder, item.img_name);
            let img = image::open(&path).unwrap_or_else(|_| image::DynamicImage::new_rgb8(input_res as u32, input_res as u32));
            let resized = img.resize_exact(input_res as u32, input_res as u32, image::imageops::FilterType::Triangle);
            let pixels: Vec<f32> = resized.to_rgb8().pixels()
                .flat_map(|p| vec![p[0] as f32 / 255.0, p[1] as f32 / 255.0, p[2] as f32 / 255.0])
                .collect();
            // Stored HWC here; permuted to CHW after stacking below.
            images_list.push(TensorData::new(pixels, [input_res, input_res, 3]));

            // 2. Rasterize each keypoint into the grid cell it falls in.
            for (i, p) in item.xy.iter().enumerate() {
                let gx = (p[0] * grid_size as f32).floor().clamp(0.0, (grid_size - 1) as f32) as usize;
                let gy = (p[1] * grid_size as f32).floor().clamp(0.0, (grid_size - 1) as f32) as usize;

                // Classes 1-4 are the calibration points; class 0 is a dart.
                let cls = if i < 4 { i + 1 } else { 0 };
                let base_idx = (b_idx * num_channels * grid_size * grid_size) + (gy * grid_size) + gx;

                // TF order: [x,y,w,h,obj,p0..p4]; each channel is one
                // grid_size*grid_size plane offset from base_idx.
                target_raw[base_idx + 0 * grid_size * grid_size] = p[0]; // X
                target_raw[base_idx + 1 * grid_size * grid_size] = p[1]; // Y
                target_raw[base_idx + 2 * grid_size * grid_size] = 0.05; // W (fixed box size)
                target_raw[base_idx + 3 * grid_size * grid_size] = 0.05; // H (fixed box size)
                target_raw[base_idx + 4 * grid_size * grid_size] = 1.0; // Objectness (conf)
                target_raw[base_idx + (5 + cls) * grid_size * grid_size] = 1.0; // Class prob
            }
        }

        let images = Tensor::stack(
            images_list.into_iter().map(|d| Tensor::<B, 3>::from_data(d, &self.device)).collect(),
            0
        ).permute([0, 3, 1, 2]); // NHWC -> NCHW

        let targets = Tensor::from_data(
            TensorData::new(target_raw, [batch_size, num_channels, grid_size, grid_size]),
            &self.device
        );

        DartBatch { images, targets }
    }
}
src/inference.rs ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use burn::module::Module;
2
+ use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
3
+ use burn::tensor::backend::Backend;
4
+ use burn::tensor::{Tensor, TensorData};
5
+ use crate::model::DartVisionModel;
6
+ use image::{GenericImageView, DynamicImage};
7
+
8
/// Load the trained weights and print the max objectness score for one image.
///
/// NOTE(review): this resizes to 800x800 and reshapes the head output as a
/// 100x100 grid, while training (src/data.rs / src/main.rs) uses 416x416
/// with a 26x26 grid -- confirm which resolution the saved weights expect.
pub fn run_inference<B: Backend>(device: &B::Device, image_path: &str) {
    println!("🔍 Loading model for inference...");
    let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
    let model: DartVisionModel<B> = DartVisionModel::new(device);

    // Load weights
    let record = Recorder::load(&recorder, "model_weights".into(), device)
        .expect("Failed to load weights. Make sure model_weights.bin exists.");
    let model = model.load_record(record);

    println!("🖼️ Processing image: {}...", image_path);
    let img = image::open(image_path).expect("Failed to open image");
    let resized = img.resize_exact(800, 800, image::imageops::FilterType::Triangle);
    let pixels: Vec<f32> = resized.to_rgb8().pixels()
        .flat_map(|p| vec![p[0] as f32 / 255.0, p[1] as f32 / 255.0, p[2] as f32 / 255.0])
        .collect();

    // HWC buffer -> NCHW tensor for the convolutional stack.
    let data = TensorData::new(pixels, [800, 800, 3]);
    let input = Tensor::<B, 3>::from_data(data, device).unsqueeze::<4>().permute([0, 3, 1, 2]);

    println!("🚀 Running MODEL Prediction...");
    let (out16, _out32) = model.forward(input);

    // Post-process out16 (size [1, 30, 100, 100])
    // Decode objectness part (Channel 4 for Anchor 0)
    let obj = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 4, 1));

    // Find highest confidence cell
    let (max_val, _) = obj.reshape([1, 10000]).max_dim_with_indices(1);
    let confidence: f32 = max_val.to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];

    println!("--------------------------------------------------");
    println!("📊 RESULTS FOR: {}", image_path);
    println!("✨ Max Objectness: {:.2}%", confidence * 100.0);

    if confidence > 0.05 {
        println!("✅ Model found something! Confidence Score: {:.4}", confidence);
    } else {
        println!("⚠️ Model confidence is too low. Training incomplete?");
    }
    println!("--------------------------------------------------");
}
src/loss.rs ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use burn::tensor::backend::Backend;
2
+ use burn::tensor::Tensor;
3
+
4
/// Combined YOLO-style detection loss over a [B, 30, H, W] prediction.
///
/// NOTE(review): despite the name, this is not a DIoU loss -- it is
/// hand-written weighted BCE on the objectness and class channels plus MSE
/// on the box channels; renamed documentation only, callers keep the name.
pub fn diou_loss<B: Backend>(
    bboxes_pred: Tensor<B, 4>,
    target: Tensor<B, 4>,
) -> Tensor<B, 1> {
    // 1. Reshape to separate anchors: [Batch, 3, 10, H, W]
    let [batch, _channels, h, w] = bboxes_pred.dims();
    let bp = bboxes_pred.reshape([batch, 3, 10, h, w]);
    let t = target.reshape([batch, 3, 10, h, w]);

    // 2. Objectness (Channel 4) -- binary cross-entropy written out by hand.
    let obj_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 4, 1));
    let obj_target = t.clone().narrow(2, 4, 1);

    let eps = 1e-7; // keeps log() away from zero
    // Positive loss (where an object exists)
    let pos_loss = obj_target.clone().mul(obj_pred.clone().add_scalar(eps).log()).neg();
    // Negative loss (where no object exists)
    let neg_loss = obj_target.clone().neg().add_scalar(1.0).mul(obj_pred.clone().neg().add_scalar(1.0 + eps).log()).neg();

    // Weight positive samples 20x more to fight imbalance (typical YOLO trick)
    let obj_loss = pos_loss.mul_scalar(20.0).add(neg_loss).mean();

    // 3. Class (Channels 5-9) - Only learn when object exists
    let cls_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 5, 5));
    let cls_target = t.clone().narrow(2, 5, 5);
    let class_loss = cls_target.clone().mul(cls_pred.clone().add_scalar(eps).log()).neg()
        .mul(obj_target.clone()) // Only count where object exists
        .mean()
        .mul_scalar(5.0); // Boost class learning

    // 4. Coordinates (Channels 0-3) - Only learn when object exists
    let b_xy_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 0, 2));
    let b_xy_target = t.clone().narrow(2, 0, 2);
    let xy_loss = b_xy_pred.sub(b_xy_target).powf_scalar(2.0).mul(obj_target.clone()).mean().mul_scalar(5.0);

    let b_wh_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 2, 2));
    let b_wh_target = t.clone().narrow(2, 2, 2);
    let wh_loss = b_wh_pred.sub(b_wh_target).powf_scalar(2.0).mul(obj_target).mean().mul_scalar(5.0);

    // Total: objectness + class + box-center + box-size terms.
    obj_loss.add(class_loss).add(xy_loss).add(wh_loss)
}
src/main.rs ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use crate::model::DartVisionModel;
2
+ use crate::server::start_gui;
3
+ use crate::train::{train, TrainingConfig};
4
+ use burn::backend::wgpu::WgpuDevice;
5
+ use burn::backend::Wgpu;
6
+ use burn::prelude::*;
7
+ use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
8
+
9
+ pub mod data;
10
+ pub mod loss;
11
+ pub mod model;
12
+ pub mod scoring;
13
+ pub mod server;
14
+ pub mod train;
15
+
16
/// Entry point. Dispatches on the first CLI argument: `gui` starts the web
/// dashboard, `test [img]` runs a quick inference check, anything else
/// launches training.
fn main() {
    let device = WgpuDevice::default();
    let args: Vec<String> = std::env::args().collect();

    if args.len() > 1 && args[1] == "gui" {
        println!("🌐 [Burn-DartVision] Starting Professional Dashboard...");
        // The server is async; build a runtime explicitly since main is not
        // annotated with #[tokio::main].
        tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap()
            .block_on(start_gui(device));
    } else if args.len() > 1 && args[1] == "test" {
        let img_path = if args.len() > 2 { &args[2] } else { "test.jpg" };
        test_model(device, img_path);
    } else {
        println!("🚀 [Burn-DartVision] Starting Full Project Training...");
        let dataset_path = "dataset/labels.json";

        let config = TrainingConfig {
            num_epochs: 10,
            batch_size: 1,
            lr: 1e-3,
        };

        // The Autodiff wrapper is required for backprop on the Wgpu backend.
        train::<burn::backend::Autodiff<Wgpu>>(device, dataset_path, config);
    }
}
43
+
44
/// Smoke-test the saved weights: run one 416x416 image through the model
/// and print the maximum sigmoid objectness over the 26x26 grid.
fn test_model(device: WgpuDevice, img_path: &str) {
    println!("🔍 Testing model on: {}", img_path);
    let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
    let model = DartVisionModel::<Wgpu>::new(&device);

    // Fall back to the untrained weights so the command works pre-training.
    let record = match recorder.load("model_weights".into(), &device) {
        Ok(r) => r,
        Err(_) => {
            println!("⚠️ Weights not found, using initial model.");
            model.clone().into_record()
        }
    };
    let model = model.load_record(record);

    let img = image::open(img_path).unwrap_or_else(|_| {
        println!("❌ Image not found at {}. Using random tensor.", img_path);
        image::DynamicImage::new_rgb8(416, 416)
    });
    let resized = img.resize_exact(416, 416, image::imageops::FilterType::Triangle);
    let pixels: Vec<f32> = resized
        .to_rgb8()
        .pixels()
        .flat_map(|p| {
            vec![
                p[0] as f32 / 255.0,
                p[1] as f32 / 255.0,
                p[2] as f32 / 255.0,
            ]
        })
        .collect();

    // NHWC buffer -> NCHW tensor.
    let tensor_data = TensorData::new(pixels, [1, 416, 416, 3]);
    let input = Tensor::<Wgpu, 4>::from_data(tensor_data, &device).permute([0, 3, 1, 2]);
    let (out, _): (Tensor<Wgpu, 4>, _) = model.forward(input);

    // Channel 4 is the objectness logit; 676 = 26 * 26 grid cells.
    let obj = burn::tensor::activation::sigmoid(out.clone().narrow(1, 4, 1));
    let (max_val, _) = obj.reshape([1, 676]).max_dim_with_indices(1);

    let score = max_val
        .to_data()
        .convert::<f32>()
        .as_slice::<f32>()
        .unwrap()[0];
    println!("📊 Max Objectness Score: {:.6}", score);
}
src/model.rs ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use burn::nn::conv::{Conv2d, Conv2dConfig};
2
+ use burn::nn::{BatchNorm, BatchNormConfig};
3
+ use burn::module::Module;
4
+ use burn::tensor::backend::Backend;
5
+ use burn::tensor::Tensor;
6
+ use burn::nn::PaddingConfig2d;
7
+ use burn::nn::pool::{MaxPool2d, MaxPool2dConfig};
8
+
9
/// Conv2d -> BatchNorm -> LeakyReLU(0.1): the basic backbone unit.
#[derive(Module, Debug)]
pub struct ConvBlock<B: Backend> {
    conv: Conv2d<B>,
    bn: BatchNorm<B, 2>,
}

impl<B: Backend> ConvBlock<B> {
    /// "Same" padding keeps the spatial size unchanged; downsampling is
    /// done by the pooling layers in the parent model instead.
    pub fn new(in_channels: usize, out_channels: usize, kernel_size: [usize; 2], device: &B::Device) -> Self {
        let config = Conv2dConfig::new([in_channels, out_channels], kernel_size)
            .with_padding(PaddingConfig2d::Same);
        let conv = config.init(device);
        let bn = BatchNormConfig::new(out_channels).init(device);
        Self { conv, bn }
    }

    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let x = self.conv.forward(x);
        let x = self.bn.forward(x);
        burn::tensor::activation::leaky_relu(x, 0.1)
    }
}
30
+
31
/// Tiny YOLO-style single-head detector (4 maxpools => overall stride 16).
#[derive(Module, Debug)]
pub struct DartVisionModel<B: Backend> {
    // Lean architecture: low channel count to keep GPU memory in check at
    // high input resolution.
    l1: ConvBlock<B>, // 3 -> 16
    p1: MaxPool2d,
    l2: ConvBlock<B>, // 16 -> 16
    p2: MaxPool2d,
    l3: ConvBlock<B>, // 16 -> 32
    p3: MaxPool2d,
    l4: ConvBlock<B>, // 32 -> 32
    p4: MaxPool2d,
    l5: ConvBlock<B>, // 32 -> 64
    l6: ConvBlock<B>, // 64 -> 64

    head_32: Conv2d<B>, // Final detection head
}
47
+
48
impl<B: Backend> DartVisionModel<B> {
    pub fn new(device: &B::Device) -> Self {
        let l1 = ConvBlock::new(3, 16, [3, 3], device);
        let p1 = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init();

        let l2 = ConvBlock::new(16, 16, [3, 3], device);
        let p2 = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init();

        let l3 = ConvBlock::new(16, 32, [3, 3], device);
        let p3 = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init();

        let l4 = ConvBlock::new(32, 32, [3, 3], device);
        let p4 = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init();

        let l5 = ConvBlock::new(32, 64, [3, 3], device);
        let l6 = ConvBlock::new(64, 64, [3, 3], device);

        // 30 channels = 3 anchors * (x,y,w,h,obj,p0...p4)
        let head_32 = Conv2dConfig::new([64, 30], [1, 1]).init(device);

        Self { l1, p1, l2, p2, l3, p3, l4, p4, l5, l6, head_32 }
    }

    /// Forward pass; the spatial sizes in the comments assume an 800x800
    /// input (a 416 input yields 26x26 instead -- stride is 16 either way).
    pub fn forward(&self, x: Tensor<B, 4>) -> (Tensor<B, 4>, Tensor<B, 4>) {
        let x = self.l1.forward(x); // 800
        let x = self.p1.forward(x); // 400
        let x = self.l2.forward(x); // 400
        let x = self.p2.forward(x); // 200
        let x = self.l3.forward(x); // 200
        let x = self.p3.forward(x); // 100
        let x = self.l4.forward(x); // 100
        let x = self.p4.forward(x); // 50

        let x50 = self.l5.forward(x); // 50
        let x50 = self.l6.forward(x50); // 50

        let out50 = self.head_32.forward(x50);

        // Single-scale model: both tuple slots carry the same head output
        // so callers written for a two-scale API keep working.
        (out50.clone(), out50)
    }
}
src/scoring.rs ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use std::collections::HashMap;
2
+
3
/// Ring radii of a regulation dartboard, expressed as distances from the
/// board centre. The values below match the standard 170 mm centre-to-double
/// measurement, so they are presumably in metres — scoring only ever uses
/// their ratios, so the unit cancels out.
pub struct ScoringConfig {
    pub r_double: f32,        // outer edge of the doubles ring
    pub r_treble: f32,        // outer edge of the trebles ring
    pub r_outer_bull: f32,    // outer bull (25) radius
    pub r_inner_bull: f32,    // inner bull (50) radius
    pub w_double_treble: f32, // radial width of the doubles/trebles bands
}

impl Default for ScoringConfig {
    fn default() -> Self {
        Self {
            r_double: 0.170,
            r_treble: 0.1074,
            r_outer_bull: 0.0159,
            r_inner_bull: 0.00635,
            w_double_treble: 0.01,
        }
    }
}
22
+
23
/// Returns the mapping from 18-degree angular sector index (0..20, measured
/// counter-clockwise from 3 o'clock) to the printed dartboard segment number.
pub fn get_board_dict() -> HashMap<i32, &'static str> {
    // Standard BDO segment layout, enumerated by sector index.
    const SEGMENTS: [&str; 20] = [
        "13", "4", "18", "1", "20", "5", "12", "9", "14", "11",
        "8", "16", "7", "19", "3", "17", "2", "15", "10", "6",
    ];
    SEGMENTS
        .iter()
        .enumerate()
        .map(|(sector, &segment)| (sector as i32, segment))
        .collect()
}
32
+
33
+ pub fn calculate_dart_score(cal_pts: &[[f32; 2]], dart_pt: &[f32; 2], config: &ScoringConfig) -> (i32, String) {
34
+ // 1. Calculate Center (Average of 4 calibration points)
35
+ let cx = cal_pts.iter().map(|p| p[0]).sum::<f32>() / 4.0;
36
+ let cy = cal_pts.iter().map(|p| p[1]).sum::<f32>() / 4.0;
37
+
38
+ // 2. Calculate average radius to boundary (doubles wire)
39
+ let avg_r_px = cal_pts.iter()
40
+ .map(|p| ((p[0] - cx).powi(2) + (p[1] - cy).powi(2)).sqrt())
41
+ .sum::<f32>() / 4.0;
42
+
43
+ // 3. Relative distance of dart from center
44
+ let dx = dart_pt[0] - cx;
45
+ let dy = dart_pt[1] - cy;
46
+ let dist_px = (dx.powi(2) + dy.powi(2)).sqrt();
47
+
48
+ // Scale distance relative to BDO double radius
49
+ let dist_scaled = (dist_px / avg_r_px) * config.r_double;
50
+
51
+ // 4. Calculate Angle (0 is 3 o'clock, CCW)
52
+ let mut angle_deg = (-dy).atan2(dx).to_degrees();
53
+ if angle_deg < 0.0 { angle_deg += 360.0; }
54
+
55
+ // Board is rotated such that 20 is at top (90 deg)
56
+ // Sector width is 18 deg. Sector 20 is centered at 90 deg.
57
+ // 90 deg is index 4 in slices (13, 4, 18, 1, 20...)
58
+ // Each index is 18 deg. Offset = 4 * 18 = 72? No.
59
+ // Let's use the standard mapping: (angle / 18)
60
+ // Wait, the BOARD_DICT in Python uses int(angle / 18) where angle is 0-360.
61
+ // We need to match the slice orientation.
62
+ let board_dict = get_board_dict();
63
+ let sector_idx = ((angle_deg / 18.0).floor() as i32) % 20;
64
+ let sector_num = board_dict.get(&sector_idx).unwrap_or(&"0");
65
+
66
+ // 5. Determine multipliers based on scaled distance
67
+ let r_t = config.r_treble;
68
+ let r_d = config.r_double;
69
+ let w = config.w_double_treble;
70
+ let r_ib = config.r_inner_bull;
71
+ let r_ob = config.r_outer_bull;
72
+
73
+ if dist_scaled > r_d {
74
+ (0, "Miss".to_string())
75
+ } else if dist_scaled <= r_ib {
76
+ (50, "DB".to_string())
77
+ } else if dist_scaled <= r_ob {
78
+ (25, "B".to_string())
79
+ } else if dist_scaled <= r_d && dist_scaled > (r_d - w) {
80
+ let val = sector_num.parse::<i32>().unwrap_or(0);
81
+ (val * 2, format!("D{}", sector_num))
82
+ } else if dist_scaled <= r_t && dist_scaled > (r_t - w) {
83
+ let val = sector_num.parse::<i32>().unwrap_or(0);
84
+ (val * 3, format!("T{}", sector_num))
85
+ } else {
86
+ let val = sector_num.parse::<i32>().unwrap_or(0);
87
+ (val, sector_num.to_string())
88
+ }
89
+ }
src/server.rs ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use axum::{
2
+ extract::{DefaultBodyLimit, Multipart, State},
3
+ response::{Html, Json},
4
+ routing::{get, post},
5
+ Router,
6
+ };
7
+ use burn::prelude::*;
8
+ use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
9
+ use burn::backend::Wgpu;
10
+ use burn::backend::wgpu::WgpuDevice;
11
+ use crate::model::DartVisionModel;
12
+ use serde_json::json;
13
+ use std::net::SocketAddr;
14
+ use std::sync::Arc;
15
+ use tokio::sync::{mpsc, oneshot};
16
+ use tower_http::cors::CorsLayer;
17
+
18
/// Inference output sent back from the GPU worker thread to a request handler.
#[derive(Debug)]
struct PredictResult {
    confidence: f32,     // best objectness*class score observed (0..1)
    keypoints: Vec<f32>, // flat [x0, y0, x1, y1, ...]; 4 calibration corners, then darts
    scores: Vec<String>, // one human-readable score label per detected dart
}

/// A single prediction job queued to the GPU worker thread.
struct PredictRequest {
    image_bytes: Vec<u8>,                        // raw uploaded image bytes
    response_tx: oneshot::Sender<PredictResult>, // channel to deliver the result on
}
29
+
30
/// Starts the web dashboard on 127.0.0.1:8080.
///
/// Spawns one dedicated OS thread that owns the model and serves inference
/// jobs from an mpsc channel (keeps GPU work off the async runtime), then
/// runs the axum HTTP server on the current task. Never returns normally.
pub async fn start_gui(device: WgpuDevice) {
    let addr = SocketAddr::from(([127, 0, 0, 1], 8080));
    println!("🚀 [DartVision-GUI] Starting on http://127.0.0.1:8080",);

    // Bounded queue between HTTP handlers and the inference worker.
    let (tx, mut rx) = mpsc::channel::<PredictRequest>(10);

    let worker_device = device.clone();
    std::thread::spawn(move || {
        // Load trained weights if present; otherwise fall back to the
        // randomly-initialised model so the UI still works before training.
        let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
        let model = DartVisionModel::<Wgpu>::new(&worker_device);
        let record = match recorder.load("model_weights".into(), &worker_device) {
            Ok(r) => r,
            Err(_) => {
                println!("⚠️ [DartVision] No 'model_weights.bin' yet. Using initial weights...");
                model.clone().into_record()
            }
        };
        let model = model.load_record(record);

        while let Some(req) = rx.blocking_recv() {
            // NOTE(review): unwrap panics the worker thread on an undecodable
            // upload, after which all requests hang — consider handling Err.
            let img = image::load_from_memory(&req.image_bytes).unwrap();
            // 416x416 input -> 26x26 output grid (four /2 pools).
            let resized = img.resize_exact(416, 416, image::imageops::FilterType::Triangle);
            let pixels: Vec<f32> = resized.to_rgb8().pixels()
                .flat_map(|p| vec![p[0] as f32 / 255.0, p[1] as f32 / 255.0, p[2] as f32 / 255.0])
                .collect();

            // HWC buffer permuted to the NCHW layout the model expects.
            let tensor_data = TensorData::new(pixels, [1, 416, 416, 3]);
            let input = Tensor::<Wgpu, 4>::from_data(tensor_data, &worker_device).permute([0, 3, 1, 2]);

            let (out16, _) = model.forward(input);

            // Slots 0..8 hold the 4 calibration corners; darts are appended.
            let mut final_points = vec![0.0f32; 8]; // 4 corners
            let mut max_conf = 0.0f32;

            // 1. Objectness map (channel 4) through a sigmoid.
            //    NOTE(review): only the first anchor's channels are read even
            //    though the head emits 3 anchors * 10 channels.
            let obj = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 4, 1));

            // 2. For calibration classes 1..4, take the argmax cell over the
            //    26x26 = 676 grid of objectness * class probability.
            for cls_idx in 1..=4 {
                let prob = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 5 + cls_idx, 1));
                let score = obj.clone().mul(prob);
                let (val, idx) = score.reshape([1, 676]).max_dim_with_indices(1);

                let s = val.to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];
                let f_idx = idx.to_data().convert::<i32>().as_slice::<i32>().unwrap()[0] as usize;
                let gy = f_idx / 26;
                let gx = f_idx % 26;

                // Sigmoid-squashed x/y offsets at the winning cell
                // (matches the loss formulation used in training).
                let px = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 0, 1).slice([0..1, 0..1, gy..gy+1, gx..gx+1]))
                    .to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];
                let py = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 1, 1).slice([0..1, 0..1, gy..gy+1, gx..gx+1]))
                    .to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];

                final_points[(cls_idx-1)*2] = px;
                final_points[(cls_idx-1)*2+1] = py;
                if s > max_conf { max_conf = s; }
            }

            // 3. Best dart (class 0), kept only above a 0.1 score threshold.
            let d_prob = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 5, 1));
            let d_score = obj.clone().mul(d_prob);
            let (d_val, d_idx) = d_score.reshape([1, 676]).max_dim_with_indices(1);
            let ds = d_val.to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];
            if ds > 0.1 {
                let f_idx = d_idx.to_data().convert::<i32>().as_slice::<i32>().unwrap()[0] as usize;
                let gy = f_idx / 26;
                let gx = f_idx % 26;
                let dx = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 0, 1).slice([0..1, 0..1, gy..gy+1, gx..gx+1]))
                    .to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];
                let dy = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 1, 1).slice([0..1, 0..1, gy..gy+1, gx..gx+1]))
                    .to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];
                final_points.push(dx);
                final_points.push(dy);
            }

            let mut final_scores = vec![];

            // Score darts only when all 4 calibration points plus at least
            // one dart (>= 10 coordinates) are available.
            if final_points.len() >= 10 {
                use crate::scoring::{calculate_dart_score, ScoringConfig};
                let config = ScoringConfig::default();
                let cal_pts = [
                    [final_points[0], final_points[1]],
                    [final_points[2], final_points[3]],
                    [final_points[4], final_points[5]],
                    [final_points[6], final_points[7]],
                ];

                for dart_chunk in final_points[8..].chunks(2) {
                    if dart_chunk.len() == 2 {
                        let dart_pt = [dart_chunk[0], dart_chunk[1]];
                        let (_val, label) = calculate_dart_score(&cal_pts, &dart_pt, &config);
                        final_scores.push(label);
                    }
                }
            }

            // Console summary of the detection for debugging.
            println!("🎯 [Detection Result] Confidence: {:.2}%", max_conf * 100.0);
            let class_names = ["Cal1", "Cal2", "Cal3", "Cal4", "Dart"];
            for (i, pts) in final_points.chunks(2).enumerate() {
                let name = class_names.get(i).unwrap_or(&"Dart");
                let label = final_scores.get(i.saturating_sub(4)).cloned().unwrap_or_default();
                println!("   - {}: [x: {:.3}, y: {:.3}] {}", name, pts[0], pts[1], label);
            }

            // Ignore send errors: the HTTP client may have disconnected.
            let _ = req.response_tx.send(PredictResult {
                confidence: max_conf,
                keypoints: final_points,
                scores: final_scores,
            });
        }
    });

    let state = Arc::new(tx);

    // Static dashboard at "/", inference endpoint at "/api/predict";
    // uploads are capped at 10 MiB.
    let app = Router::new()
        .route("/", get(|| async { Html(include_str!("../static/index.html")) }))
        .route("/api/predict", post(predict_handler))
        .with_state(state)
        .layer(DefaultBodyLimit::max(10 * 1024 * 1024))
        .layer(CorsLayer::permissive());

    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    axum::serve(listener, app).await.unwrap();
}
156
+
157
+ async fn predict_handler(
158
+ State(tx): State<Arc<mpsc::Sender<PredictRequest>>>,
159
+ mut multipart: Multipart,
160
+ ) -> Json<serde_json::Value> {
161
+ while let Ok(Some(field)) = multipart.next_field().await {
162
+ if field.name() == Some("image") {
163
+ let bytes = match field.bytes().await {
164
+ Ok(b) => b.to_vec(),
165
+ Err(_) => continue,
166
+ };
167
+ let (res_tx, res_rx) = oneshot::channel();
168
+ let _ = tx.send(PredictRequest { image_bytes: bytes, response_tx: res_tx }).await;
169
+ let result = res_rx.await.unwrap_or(PredictResult { confidence: 0.0, keypoints: vec![] });
170
+
171
+ return Json(json!({
172
+ "status": "success",
173
+ "confidence": result.confidence,
174
+ "keypoints": result.keypoints,
175
+ "scores": result.scores,
176
+ "message": if result.confidence > 0.1 {
177
+ format!("✅ Found {} darts! High confidence: {:.1}%", result.scores.len(), result.confidence * 100.0)
178
+ } else {
179
+ "⚠️ Low confidence detection - no dart score could be verified.".to_string()
180
+ }
181
+ }));
182
+ }
183
+ }
184
+ Json(json!({"status": "error", "message": "No image field found"}))
185
+ }
src/train.rs ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ use crate::data::{DartBatcher, DartDataset};
2
+ use crate::loss::diou_loss;
3
+ use crate::model::DartVisionModel;
4
+ use burn::data::dataset::Dataset; // Add this trait to scope
5
+ use burn::module::Module;
6
+ use burn::optim::{AdamConfig, GradientsParams, Optimizer};
7
+ use burn::prelude::*;
8
+ use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
9
+ use burn::tensor::backend::AutodiffBackend;
10
+
11
+ pub struct TrainingConfig {
12
+ pub num_epochs: usize,
13
+ pub batch_size: usize,
14
+ pub lr: f64,
15
+ }
16
+
17
+ pub fn train<B: AutodiffBackend>(device: Device<B>, dataset_path: &str, config: TrainingConfig) {
18
+ // 1. Create Model
19
+ let mut model: DartVisionModel<B> = DartVisionModel::new(&device);
20
+
21
+ // 1.5 Load existing weights if they exist (RESUME)
22
+ let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
23
+ let weights_path = "model_weights.bin";
24
+ if std::path::Path::new(weights_path).exists() {
25
+ println!("🚀 Loading existing weights from {}...", weights_path);
26
+ let record = Recorder::load(&recorder, "model_weights".into(), &device)
27
+ .expect("Failed to load weights");
28
+ model = model.load_record(record);
29
+ }
30
+
31
+ // 2. Setup Optimizer
32
+ let mut optim = AdamConfig::new().init();
33
+
34
+ // 3. Create Dataset
35
+ println!("🔍 Mapping annotations from {}...", dataset_path);
36
+ let dataset = DartDataset::load(dataset_path, "dataset/800");
37
+ println!("📊 Dataset loaded with {} examples.", dataset.len());
38
+ let batcher = DartBatcher::new(device.clone());
39
+
40
+ // 4. Create DataLoader
41
+ println!("📦 Initializing DataLoader (Workers: 4)...");
42
+ let dataloader = burn::data::dataloader::DataLoaderBuilder::new(batcher)
43
+ .batch_size(config.batch_size)
44
+ .shuffle(42)
45
+ .num_workers(4)
46
+ .build(dataset);
47
+
48
+ // 5. Training Loop
49
+ println!(
50
+ "📈 Running FULL Training Loop (Epochs: {})...",
51
+ config.num_epochs
52
+ );
53
+
54
+ // Using a simple loop state for ownership safety
55
+ let mut current_model = model;
56
+
57
+ for epoch in 1..=config.num_epochs {
58
+ let mut model_inner = current_model; // Move into epoch
59
+ let mut batch_count = 0;
60
+
61
+ for batch in dataloader.iter() {
62
+ // Forward Pass
63
+ let (out16, _) = model_inner.forward(batch.images);
64
+
65
+ // Calculate Loss
66
+ let loss = diou_loss(out16, batch.targets);
67
+ batch_count += 1;
68
+
69
+ // Print every 10 batches to keep terminal clean and avoid stdout sync lag
70
+ if batch_count % 20 == 0 || batch_count == 1 {
71
+ println!(
72
+ " [Epoch {}] Batch {: >3} | Loss: {:.6}",
73
+ epoch,
74
+ batch_count,
75
+ loss.clone().into_scalar()
76
+ );
77
+ }
78
+
79
+ // Backward & Optimization step
80
+ let grads = loss.backward();
81
+ let grads_params = GradientsParams::from_grads(grads, &model_inner);
82
+ model_inner = optim.step(config.lr, model_inner, grads_params);
83
+
84
+ // 5.5 Periodic Save (every 100 batches and Batch 1)
85
+ if batch_count % 100 == 0 || batch_count == 1 {
86
+ model_inner.clone()
87
+ .save_file("model_weights", &recorder)
88
+ .ok();
89
+ if batch_count == 1 {
90
+ println!("🚀 [Checkpoint] Initial weights saved at Batch 1.");
91
+ } else {
92
+ println!("🚀 [Checkpoint] Saved at Batch {}.", batch_count);
93
+ }
94
+ }
95
+ }
96
+
97
+ // 6. SAVE after EACH Epoch
98
+ model_inner
99
+ .clone()
100
+ .save_file("model_weights", &recorder)
101
+ .expect("Failed to save weights");
102
+ println!("✅ Checkpoint saved: Epoch {} complete.", epoch);
103
+
104
+ current_model = model_inner; // Move back out for next epoch
105
+ }
106
+ }
static/index.html ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>DartVision AI - Smart Scoring Dashboard</title>
7
+ <link rel="preconnect" href="https://fonts.googleapis.com">
8
+ <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
9
+ <link href="https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;600;800&display=swap" rel="stylesheet">
10
+ <style>
11
+ :root {
12
+ --primary: #00ff88;
13
+ --secondary: #00d4ff;
14
+ --accent: #ff4d4d;
15
+ --bg: #0a0e14;
16
+ --card-bg: rgba(255, 255, 255, 0.05);
17
+ --glass-border: rgba(255, 255, 255, 0.1);
18
+ }
19
+
20
+ * {
21
+ margin: 0;
22
+ padding: 0;
23
+ box-sizing: border-box;
24
+ font-family: 'Outfit', sans-serif;
25
+ }
26
+
27
+ body {
28
+ background: var(--bg);
29
+ color: #fff;
30
+ min-height: 100vh;
31
+ display: flex;
32
+ flex-direction: column;
33
+ align-items: center;
34
+ overflow-x: hidden;
35
+ }
36
+
37
+ .bg-glow {
38
+ position: fixed;
39
+ top: 0;
40
+ left: 0;
41
+ right: 0;
42
+ bottom: 0;
43
+ background: radial-gradient(circle at 50% 50%, #00ff8811 0%, transparent 50%),
44
+ radial-gradient(circle at 80% 20%, #00d4ff11 0%, transparent 40%);
45
+ z-index: -1;
46
+ pointer-events: none;
47
+ }
48
+
49
+ header {
50
+ padding: 2rem;
51
+ text-align: center;
52
+ }
53
+
54
+ h1 {
55
+ font-size: 3.5rem;
56
+ font-weight: 800;
57
+ background: linear-gradient(to right, var(--primary), var(--secondary));
58
+ -webkit-background-clip: text;
59
+ background-clip: text;
60
+ -webkit-text-fill-color: transparent;
61
+ margin-bottom: 0.5rem;
62
+ letter-spacing: -1px;
63
+ }
64
+
65
+ p.subtitle {
66
+ color: #8892b0;
67
+ font-size: 1.1rem;
68
+ letter-spacing: 2px;
69
+ text-transform: uppercase;
70
+ }
71
+
72
+ main {
73
+ width: 90%;
74
+ max-width: 1200px;
75
+ display: grid;
76
+ grid-template-columns: 1fr 380px;
77
+ gap: 2rem;
78
+ padding: 2rem 0;
79
+ }
80
+
81
+ .visual-area {
82
+ display: flex;
83
+ flex-direction: column;
84
+ gap: 1rem;
85
+ }
86
+
87
+ .upload-container {
88
+ background: var(--card-bg);
89
+ backdrop-filter: blur(20px);
90
+ border: 1px solid var(--glass-border);
91
+ border-radius: 24px;
92
+ padding: 1rem;
93
+ display: flex;
94
+ flex-direction: column;
95
+ align-items: center;
96
+ justify-content: center;
97
+ cursor: pointer;
98
+ transition: all 0.4s;
99
+ min-height: 600px;
100
+ position: relative;
101
+ overflow: hidden;
102
+ }
103
+
104
+ .upload-container:hover, .upload-container.drag-over {
105
+ border-color: var(--primary);
106
+ box-shadow: 0 0 40px rgba(0, 255, 136, 0.1);
107
+ }
108
+
109
+ .preview-wrapper {
110
+ position: relative;
111
+ width: 100%;
112
+ height: 100%;
113
+ display: none;
114
+ justify-content: center;
115
+ align-items: center;
116
+ }
117
+
118
+ #preview-img {
119
+ max-width: 100%;
120
+ max-height: 550px;
121
+ border-radius: 12px;
122
+ display: block;
123
+ object-fit: contain;
124
+ }
125
+
126
+ #overlay-svg {
127
+ position: absolute;
128
+ top: 0;
129
+ left: 0;
130
+ width: 100%;
131
+ height: 100%;
132
+ pointer-events: none;
133
+ }
134
+
135
+ .upload-icon { font-size: 4rem; margin-bottom: 1rem; }
136
+ .upload-text { font-size: 1.5rem; color: #ccd6f6; }
137
+ .upload-subtext { color: #8892b0; margin-top: 1rem; }
138
+
139
+ .stats-panel {
140
+ display: flex;
141
+ flex-direction: column;
142
+ gap: 1.5rem;
143
+ }
144
+
145
+ .stat-card {
146
+ background: rgba(255, 255, 255, 0.03);
147
+ border: 1px solid var(--glass-border);
148
+ border-radius: 20px;
149
+ padding: 1.5rem;
150
+ }
151
+
152
+ .stat-label { color: #8892b0; font-size: 0.9rem; margin-bottom: 0.5rem; text-transform: uppercase; }
153
+ .stat-value { font-size: 2rem; font-weight: 600; }
154
+ .conf-bar { height: 8px; background: rgba(255,255,255,0.1); border-radius: 4px; margin-top: 1rem; overflow: hidden; }
155
+ .conf-fill { height: 100%; width: 0%; background: linear-gradient(to right, var(--primary), var(--secondary)); transition: 1s; }
156
+
157
+ .loading-spinner {
158
+ width: 40px; height: 40px;
159
+ border: 3px solid rgba(0, 255, 136, 0.1);
160
+ border-top: 3px solid var(--primary);
161
+ border-radius: 50%;
162
+ animation: spin 1s linear infinite;
163
+ display: none;
164
+ position: absolute;
165
+ }
166
+
167
+ @keyframes spin { 100% { transform: rotate(360deg); } }
168
+
169
+ .keypoint-marker {
170
+ fill: var(--primary);
171
+ stroke: #fff;
172
+ stroke-width: 2px;
173
+ filter: drop-shadow(0 0 5px var(--primary));
174
+ animation: pulse 1.5s infinite;
175
+ }
176
+
177
+ @keyframes pulse {
178
+ 0% { r: 5; opacity: 1; }
179
+ 50% { r: 8; opacity: 0.7; }
180
+ 100% { r: 5; opacity: 1; }
181
+ }
182
+ </style>
183
+ </head>
184
+ <body>
185
+ <div class="bg-glow"></div>
186
+ <header>
187
+ <h1>DARTVISION <span style="font-weight: 300; opacity: 0.5;">AI</span></h1>
188
+ <p class="subtitle">Intelligent Labelling & Scoring Dashboard</p>
189
+ </header>
190
+
191
+ <main>
192
+ <section class="visual-area">
193
+ <div class="upload-container" id="drop-zone">
194
+ <div id="initial-state">
195
+ <div class="upload-icon">🎯</div>
196
+ <div class="upload-text">Drag & Drop Image</div>
197
+ <div class="upload-subtext">Calibration and scoring results will appear here</div>
198
+ </div>
199
+ <div class="preview-wrapper" id="preview-wrapper">
200
+ <img id="preview-img">
201
+ <svg id="overlay-svg"></svg>
202
+ </div>
203
+ <div class="loading-spinner" id="spinner"></div>
204
+ </div>
205
+ <input type="file" id="file-input" style="display: none;" accept="image/*">
206
+ </section>
207
+
208
+ <aside class="stats-panel">
209
+ <div class="stat-card">
210
+ <div class="stat-label">Model Status</div>
211
+ <div class="stat-value" id="status-text" style="font-size: 1.2rem; color: var(--primary);">System Ready</div>
212
+ </div>
213
+ <div class="stat-card">
214
+ <div class="stat-label">AI Confidence</div>
215
+ <div class="stat-value" id="conf-val">0.0%</div>
216
+ <div class="conf-bar"><div class="conf-fill" id="conf-fill"></div></div>
217
+ </div>
218
+ <div class="stat-card">
219
+ <div class="stat-label">Detection Result</div>
220
+ <div class="stat-value" id="result-text" style="font-size: 1.1rem; opacity: 0.8;">No Image Uploaded</div>
221
+ </div>
222
+ </aside>
223
+ </main>
224
+
225
+ <script>
226
+ const dropZone = document.getElementById('drop-zone');
227
+ const fileInput = document.getElementById('file-input');
228
+ const previewImg = document.getElementById('preview-img');
229
+ const previewWrapper = document.getElementById('preview-wrapper');
230
+ const initialState = document.getElementById('initial-state');
231
+ const svgOverlay = document.getElementById('overlay-svg');
232
+ const spinner = document.getElementById('spinner');
233
+
234
+ dropZone.onclick = () => fileInput.click();
235
+ fileInput.onchange = (e) => handleFile(e.target.files[0]);
236
+
237
+ dropZone.ondragover = (e) => { e.preventDefault(); dropZone.classList.add('drag-over'); };
238
+ dropZone.ondragleave = () => dropZone.classList.remove('drag-over');
239
+ dropZone.ondrop = (e) => { e.preventDefault(); dropZone.classList.remove('drag-over'); handleFile(e.dataTransfer.files[0]); };
240
+
241
// Previews the chosen image locally and POSTs it to /api/predict.
// On success, updates the stats panel and draws keypoints on the overlay.
async function handleFile(file) {
    // Ignore non-image files (e.g. text dropped onto the zone).
    if (!file || !file.type.startsWith('image/')) return;

    // Show the preview immediately via a data URL; the upload runs in parallel.
    const reader = new FileReader();
    reader.onload = (e) => {
        previewImg.src = e.target.result;
        previewWrapper.style.display = 'flex';
        initialState.style.display = 'none';
        clearKeypoints();
    };
    reader.readAsDataURL(file);

    spinner.style.display = 'block';
    const formData = new FormData();
    formData.append('image', file);

    try {
        const response = await fetch('/api/predict', { method: 'POST', body: formData });
        const data = await response.json();
        spinner.style.display = 'none';

        if (data.status === 'success') {
            updateUI(data);
            drawKeypoints(data.keypoints);
        }
    } catch (err) {
        // Network/parse failure: stop the spinner and flag the status card.
        spinner.style.display = 'none';
        document.getElementById('status-text').innerText = 'Error';
    }
}
271
+
272
// Renders the confidence bar and the per-keypoint result list.
// Expects the /api/predict JSON payload: { confidence, keypoints, scores, message }.
function updateUI(data) {
    const conf = (data.confidence * 100).toFixed(1);
    document.getElementById('conf-val').innerText = `${conf}%`;
    document.getElementById('conf-fill').style.width = `${conf}%`;

    let resultHtml = `<div style="margin-top: 1rem; font-size: 0.9rem; line-height: 1.5;">${data.message}</div>`;
    // Keypoints come as a flat [x, y, ...] list: indices 0..7 are the four
    // calibration corners, anything after is a dart.
    if (data.keypoints && data.keypoints.length >= 8) {
        const names = ["Cal 1", "Cal 2", "Cal 3", "Cal 4", "Dart"];
        resultHtml += `<div style="font-size: 0.8rem; opacity: 0.6; margin-top: 5px;">Results & Mapping:</div>`;
        for (let i = 0; i < data.keypoints.length; i += 2) {
            const classIdx = i / 2;
            // Extra darts beyond the named five get numbered labels.
            const name = names[classIdx] || `Dart ${Math.floor(classIdx - 3)}`;
            const x = data.keypoints[i].toFixed(3);
            const y = data.keypoints[i+1].toFixed(3);

            // Attach the score badge when this row is a dart (classIdx >= 4).
            let scoreLabel = "";
            if (classIdx >= 4 && data.scores && data.scores[classIdx - 4]) {
                scoreLabel = ` <span style="background: #ff4d4d; color: white; padding: 2px 6px; border-radius: 4px; font-weight: 800; margin-left: 10px;">${data.scores[classIdx - 4]}</span>`;
            }

            // Green for calibration points (first 8 coords), red for darts.
            resultHtml += `<div style="font-size: 0.85rem; padding: 6px 0; border-bottom: 1px solid rgba(255,255,255,0.05); display: flex; align-items: center; justify-content: space-between;">
                <span><span style="color: ${i < 8 ? '#00ff88' : '#ff4d4d'}">${name}</span>: [${x}, ${y}]</span>
                ${scoreLabel}
            </div>`;
        }
    }
    document.getElementById('result-text').innerHTML = resultHtml;
}
301
+
302
// Removes every previously drawn marker and label from the SVG overlay.
function clearKeypoints() {
    for (let node = svgOverlay.lastChild; node !== null; node = svgOverlay.lastChild) {
        svgOverlay.removeChild(node);
    }
}
305
+
306
// Draws a pulsing circle + label for each keypoint on the SVG overlay.
// `pts` is a flat [x, y, ...] list with coordinates normalised to [0, 1]
// relative to the image, so they are scaled to the rendered image size.
function drawKeypoints(pts) {
    clearKeypoints();
    if (!pts || pts.length === 0) return;

    // Defer briefly so the <img> has laid out and getBoundingClientRect
    // reports its final displayed size.
    setTimeout(() => {
        const rect = previewImg.getBoundingClientRect();
        const wrapperRect = previewWrapper.getBoundingClientRect();

        // The SVG covers the wrapper, so offset by the image's position
        // inside the wrapper (letterboxing from object-fit: contain).
        const offsetX = rect.left - wrapperRect.left;
        const offsetY = rect.top - wrapperRect.top;
        const width = rect.width;
        const height = rect.height;

        const classNames = ["Cal 1", "Cal 2", "Cal 3", "Cal 4", "Dart"];
        for (let i = 0; i < pts.length; i += 2) {
            // Normalised coords -> on-screen pixels within the wrapper.
            const x = pts[i] * width + offsetX;
            const y = pts[i+1] * height + offsetY;
            const classIdx = i / 2;
            const name = classNames[classIdx] || `Dart ${Math.floor(classIdx - 3)}`;

            // SVG elements must be created in the SVG namespace.
            const circle = document.createElementNS("http://www.w3.org/2000/svg", "circle");
            circle.setAttribute("cx", x);
            circle.setAttribute("cy", y);
            circle.setAttribute("r", 6);
            circle.setAttribute("class", "keypoint-marker");
            if (classIdx >= 4) circle.style.fill = "#ff4d4d"; // Red for dart

            // Text label offset up-and-right of the marker.
            const label = document.createElementNS("http://www.w3.org/2000/svg", "text");
            label.setAttribute("x", x + 10);
            label.setAttribute("y", y - 10);
            label.setAttribute("fill", classIdx < 4 ? "#00ff88" : "#ff4d4d");
            label.setAttribute("font-size", "14px");
            label.setAttribute("font-weight", "600");
            label.textContent = name;

            svgOverlay.appendChild(circle);
            svgOverlay.appendChild(label);
        }
    }, 50);
}
348
+ </script>
349
+ </body>
350
+ </html>