Spaces:
Build error
Build error
kapil committed on
Commit ·
9874885
0
Parent(s):
Initial commit
Browse files- .gitattributes +5 -0
- .gitignore +27 -0
- Cargo.lock +0 -0
- Cargo.toml +15 -0
- dataset/annotate.py +410 -0
- dataset/labels.json +0 -0
- dataset/labels.pkl +3 -0
- src/data.rs +118 -0
- src/inference.rs +49 -0
- src/loss.rs +44 -0
- src/main.rs +88 -0
- src/model.rs +88 -0
- src/scoring.rs +89 -0
- src/server.rs +185 -0
- src/train.rs +106 -0
- static/index.html +350 -0
.gitattributes
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Auto detect text files and perform LF normalization
|
| 2 |
+
* text=auto
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Rust
|
| 2 |
+
target/
|
| 3 |
+
debug/
|
| 4 |
+
release/
|
| 5 |
+
|
| 6 |
+
# IDEs
|
| 7 |
+
.vscode/
|
| 8 |
+
.idea/
|
| 9 |
+
*.swp
|
| 10 |
+
*.swo
|
| 11 |
+
*~
|
| 12 |
+
.DS_Store
|
| 13 |
+
|
| 14 |
+
# Data & Model Weights
|
| 15 |
+
model_weights.bin
|
| 16 |
+
model_weights/
|
| 17 |
+
dataset/images/
|
| 18 |
+
dataset/cropped_images/
|
| 19 |
+
dataset/800/
|
| 20 |
+
dataset/__pycache__/
|
| 21 |
+
|
| 22 |
+
# Logs & Temp
|
| 23 |
+
*.log
|
| 24 |
+
logs/
|
| 25 |
+
tmp/
|
| 26 |
+
temp/
|
| 27 |
+
*.tmp
|
Cargo.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Cargo.toml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[package]
|
| 2 |
+
name = "rust_auto_score_engine"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
edition = "2021"
|
| 5 |
+
|
| 6 |
+
[dependencies]
|
| 7 |
+
burn = { version = "0.16.0", features = ["train", "wgpu"] }
|
| 8 |
+
serde = { version = "1.0", features = ["derive"] }
|
| 9 |
+
serde_json = "1.0"
|
| 10 |
+
image = "0.25"
|
| 11 |
+
ndarray = "0.16"
|
| 12 |
+
axum = { version = "0.7", features = ["multipart"] }
|
| 13 |
+
tower-http = { version = "0.5", features = ["fs", "cors"] }
|
| 14 |
+
tokio = { version = "1.0", features = ["full"] }
|
| 15 |
+
base64 = "0.22"
|
dataset/annotate.py
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import os.path as osp
|
| 3 |
+
import cv2
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import numpy as np
|
| 6 |
+
from yacs.config import CfgNode as CN
|
| 7 |
+
import argparse
|
| 8 |
+
|
| 9 |
+
# used to convert dart angle to board number
|
| 10 |
+
BOARD_DICT = {
|
| 11 |
+
0: '13', 1: '4', 2: '18', 3: '1', 4: '20', 5: '5', 6: '12', 7: '9', 8: '14', 9: '11',
|
| 12 |
+
10: '8', 11: '16', 12: '7', 13: '19', 14: '3', 15: '17', 16: '2', 17: '15', 18: '10', 19: '6'
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def crop_board(img_path, bbox=None, crop_info=(0, 0, 0), crop_pad=1.1):
    """Load an image from disk and crop it to the dartboard region.

    Args:
        img_path: path to the image file (read with cv2.imread).
        bbox: optional crop box [y1, y2, x1, x2] in pixels; used as-is
            when given.
        crop_info: (x, y, r) board centre and radius used to build a
            square bbox when `bbox` is None.
        crop_pad: padding factor applied to `r` before cropping.

    Returns:
        (cropped image, bbox actually used).
    """
    img = cv2.imread(img_path)
    if bbox is None:
        x, y, r = crop_info
        pad_r = int(r * crop_pad)
        bbox = [y - pad_r, y + pad_r, x - pad_r, x + pad_r]
    y1, y2, x1, x2 = bbox
    crop = img[y1:y2, x1:x2]
    return crop, bbox
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def on_click(event, x, y, flags, param):
    """OpenCV mouse callback: append a clicked point to the global `xy` list.

    Points are stored normalized to the displayed image size (0-1 range),
    capped at 7 total (4 calibration points followed by up to 3 darts).
    """
    global xy, img_copy
    h, w = img_copy.shape[:2]
    if event == cv2.EVENT_LBUTTONDOWN:
        if len(xy) < 7:
            # Normalize to the current display size so points survive rescaling.
            xy.append([x/w, y/h])
            print_xy()
        else:
            print('Already annotated 7 points.')
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def print_xy():
    """Echo the most recently clicked keypoint with its semantic label.

    Labels follow click order: 4 calibration points, then up to 3 darts.
    Reads the module-global `xy` list populated by `on_click`.
    """
    global xy
    labels = (
        'cal_1', 'cal_2', 'cal_3', 'cal_4',
        'dart_1', 'dart_2', 'dart_3')
    print('{}: {}'.format(labels[len(xy) - 1], xy[-1]))
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def get_ellipses(xy, r_double=0.17, r_treble=0.1074):
    """Estimate the board's double and treble rings as ellipses.

    Args:
        xy: keypoint array; rows 0-3 are the calibration points. The
            `xy[3, 1]` indexing below requires an ndarray, not a list.
        r_double, r_treble: physical ring radii (board units) used to
            scale the treble axes down from the measured double axes.

    Returns:
        (center, [double semi-axes], [treble semi-axes], rotation angle in degrees).
    """
    # Board centre = mean of the 4 calibration points.
    c = np.mean(xy[:4], axis=0)
    # Semi-axes of the double ring from the two opposite calibration-point pairs.
    a1_double = ((xy[2][0] - xy[3][0]) ** 2 + (xy[2][1] - xy[3][1]) ** 2) ** 0.5 / 2
    a2_double = ((xy[0][0] - xy[1][0]) ** 2 + (xy[0][1] - xy[1][1]) ** 2) ** 0.5 / 2
    # Treble ring is a proportionally scaled copy of the double ring.
    a1_treble = a1_double * (r_treble / r_double)
    a2_treble = a2_double * (r_treble / r_double)
    # NOTE(review): arctan (not arctan2) loses the quadrant and divides by
    # zero for a vertical offset; presumably fine for the small tilts seen
    # in practice — confirm.
    angle = np.arctan((xy[3, 1] - c[1]) / (xy[3, 0] - c[0])) / np.pi * 180
    return c, [a1_double, a2_double], [a1_treble, a2_treble], angle
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def draw_ellipses(img, xy, num_pts=7):
    """Overlay the estimated double/treble ring ellipses on `img` in white.

    img must be uint8. `xy` may be a flat coordinate vector or (N, 2)
    rows, in normalized (0-1) or pixel coordinates.
    """
    xy = np.array(xy)
    # A flat [x0, y0, x1, y1, ...] vector longer than num_pts entries is
    # reshaped into (N, 2) point rows.
    if xy.shape[0] > num_pts:
        xy = xy.reshape((-1, 2))
    # Heuristic: coordinates with mean < 1 are normalized; scale to pixels.
    if np.mean(xy) < 1:
        h, w = img.shape[:2]
        xy[:, 0] *= w
        xy[:, 1] *= h
    c, a_double, a_treble, angle = get_ellipses(xy)
    # NOTE(review): this recomputes exactly the angle get_ellipses just
    # returned — redundant but harmless.
    angle = np.arctan((xy[3,1]-c[1])/(xy[3,0]-c[0]))/np.pi*180
    cv2.ellipse(img, (int(round(c[0])), int(round(c[1]))),
                (int(round(a_double[0])), int(round(a_double[1]))),
                int(round(angle)), 0, 360, (255, 255, 255))
    cv2.ellipse(img, (int(round(c[0])), int(round(c[1]))),
                (int(round(a_treble[0])), int(round(a_treble[1]))),
                int(round(angle)), 0, 360, (255, 255, 255))
    return img
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def get_circle(xy):
|
| 76 |
+
c = np.mean(xy[:4], axis=0)
|
| 77 |
+
r = np.mean(np.linalg.norm(xy[:4] - c, axis=-1))
|
| 78 |
+
return c, r
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def board_radii(r_d, cfg):
    """Convert the configured board geometry into pixel-space radii.

    Args:
        r_d: double-ring radius in pixels (measured from the image).
        cfg: config with a `board` section providing r_double, r_treble,
            r_inner_bull, r_outer_bull and w_double_treble in board units.

    Returns:
        (r_treble, r_outer_bull, r_inner_bull, w_double_treble), all in pixels.
    """
    board = cfg.board
    r_t = r_d * (board.r_treble / board.r_double)        # treble radius, in px
    r_ib = r_d * (board.r_inner_bull / board.r_double)   # inner bull radius, in px
    r_ob = r_d * (board.r_outer_bull / board.r_double)   # outer bull radius, in px
    w_dt = board.w_double_treble * (r_d / board.r_double)  # double/treble band width
    return r_t, r_ob, r_ib, w_dt
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def draw_circles(img, xy, cfg, color=(255, 255, 255)):
    """Overlay the scoring-ring circles (double, treble, both bulls) on `img`.

    Assumes `xy` is in pixel coordinates. Rings are drawn as circles, so
    this is only accurate for a roughly head-on board view.
    """
    c, r_d = get_circle(xy)  # double radius
    r_t, r_ob, r_ib, w_dt = board_radii(r_d, cfg)
    # Outer and inner edge of the double and treble bands, plus both bulls.
    for r in [r_d, r_d - w_dt, r_t, r_t - w_dt, r_ib, r_ob]:
        cv2.circle(img, (round(c[0]), round(c[1])), round(r), color)
    return img
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def transform(xy, img=None, angle=9, M=None):
    """Perspective-align keypoints (and optionally the image) so the four
    calibration points land on an upright circle.

    Args:
        xy: (N, 2) or (N, 3) float array; rows 0-3 are the calibration
            points, an optional third column is a visibility flag that is
            carried through untouched. May be normalized (0-1) or pixel
            coordinates; normalized input is only detected when `img` is
            given (its size is needed for the conversion).
        img: optional image to warp with the same homography.
        angle: board rotation (degrees) the calibration points are mapped to.
        M: optional precomputed 3x3 homography; when given, the
            calibration-point fit (and cv2) is skipped.

    Returns:
        (transformed xy, warped img or None, homography M).

    Bug fix: previously `xy_dst /= [[w, h]]` ran whenever `img` was given,
    but `w`/`h` were only bound when the input had been detected as
    normalized — pixel-coordinate input plus an image raised NameError.
    The back-normalization now only happens when the input was normalized.
    """
    if xy.shape[-1] == 3:
        has_vis = True
        vis = xy[:, 2:]
        xy = xy[:, :2]
    else:
        has_vis = False

    # Detect normalized coordinates and scale them to pixels for the fit.
    normalized = False
    if img is not None and np.mean(xy[:4]) < 1:
        h, w = img.shape[:2]
        xy *= [[w, h]]
        normalized = True

    if M is None:
        c, r = get_circle(xy)  # not necessarily a circle
        # c is center of 4 calibration points, r is mean distance from center
        # to calibration points.
        src_pts = xy[:4].astype(np.float32)
        dst_pts = np.array([
            [c[0] - r * np.sin(np.deg2rad(angle)), c[1] - r * np.cos(np.deg2rad(angle))],
            [c[0] + r * np.sin(np.deg2rad(angle)), c[1] + r * np.cos(np.deg2rad(angle))],
            [c[0] - r * np.cos(np.deg2rad(angle)), c[1] + r * np.sin(np.deg2rad(angle))],
            [c[0] + r * np.cos(np.deg2rad(angle)), c[1] - r * np.sin(np.deg2rad(angle))]
        ]).astype(np.float32)
        M = cv2.getPerspectiveTransform(src_pts, dst_pts)

    # Apply the homography in homogeneous coordinates.
    xyz = np.concatenate((xy, np.ones((xy.shape[0], 1))), axis=-1).astype(np.float32)
    xyz_dst = np.matmul(M, xyz.T).T
    xy_dst = xyz_dst[:, :2] / xyz_dst[:, 2:]

    if img is not None:
        img = cv2.warpPerspective(img.copy(), M, (img.shape[1], img.shape[0]))
        if normalized:
            # Undo the normalization applied above.
            xy_dst /= [[w, h]]

    if has_vis:
        xy_dst = np.concatenate([xy_dst, vis], axis=-1)

    return xy_dst, img, M
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def get_dart_scores(xy, cfg, numeric=False):
    """Score each dart from its position relative to the calibration points.

    Args:
        xy: ndarray; rows 0-3 are calibration points, rows 4+ are dart tips.
        cfg: board geometry config (consumed by board_radii).
        numeric: when False, return string labels ('0', 'B', 'DB', 'D20',
            'T5', '17', ...); when True, convert to integer point values.

    Returns:
        A list with one score per dart, or [] when any calibration point
        is missing (non-positive coordinates count as missing).
    """
    # Calibration points at (<=0, <=0) are treated as undetected.
    valid_cal_pts = xy[:4][(xy[:4, 0] > 0) & (xy[:4, 1] > 0)]
    if xy.shape[0] <= 4 or valid_cal_pts.shape[0] < 4:  # missing calibration point
        return []
    # Rectify the board so sector angles and ring distances are meaningful.
    xy, _, _ = transform(xy.copy(), angle=0)
    c, r_d = get_circle(xy)
    r_t, r_ob, r_ib, w_dt = board_radii(r_d, cfg)
    # Work in board-centered coordinates; image y points down, hence -y.
    xy -= c
    angles = np.arctan2(-xy[4:, 1], xy[4:, 0]) / np.pi * 180
    angles = [a + 360 if a < 0 else a for a in angles]  # map to 0-360
    distances = np.linalg.norm(xy[4:], axis=-1)
    scores = []
    for angle, dist in zip(angles, distances):
        if dist > r_d:
            # Outside the double ring: miss.
            scores.append('0')
        elif dist <= r_ib:
            scores.append('DB')  # inner (double) bull
        elif dist <= r_ob:
            scores.append('B')   # outer bull
        else:
            # 20 sectors of 18 degrees each; BOARD_DICT maps sector -> number.
            # NOTE(review): an angle of exactly 360 would index sector 20
            # (KeyError) — presumably unreachable after the 0-360 mapping.
            number = BOARD_DICT[int(angle / 18)]
            if dist <= r_d and dist > r_d - w_dt:
                scores.append('D' + number)  # double band
            elif dist <= r_t and dist > r_t - w_dt:
                scores.append('T' + number)  # treble band
            else:
                scores.append(number)        # single
    if numeric:
        # Convert labels to point values: DB=50, B=25, Dn=2n, Tn=3n, n=n.
        for i, s in enumerate(scores):
            if 'B' in s:
                if 'D' in s:
                    scores[i] = 50
                else:
                    scores[i] = 25
            else:
                if 'D' in s or 'T' in s:
                    scores[i] = int(s[1:])
                    scores[i] = scores[i] * 2 if 'D' in s else scores[i] * 3
                else:
                    scores[i] = int(s)
    return scores
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def draw(img, xy, cfg, circles, score, color=(255, 255, 0)):
    """Render keypoints (and optionally rings and dart scores) onto `img`.

    Calibration points (first 4) are drawn green and numbered 1-4; darts
    are drawn in `color` and labeled with either their score or index.

    Args:
        img: image to draw on (modified in place and returned).
        xy: keypoints, flat or (N, 2), normalized or pixel coordinates.
        cfg: board geometry config.
        circles: when True and >=4 points exist, overlay the scoring rings.
        score: when True and darts exist, label darts with their scores.
    """
    xy = np.array(xy)
    # Flat coordinate vectors are reshaped into (N, 2) rows.
    if xy.shape[0] > 7:
        xy = xy.reshape((-1, 2))
    # Heuristic: mean < 1 means normalized coordinates; scale to pixels.
    if np.mean(xy) < 1:
        h, w = img.shape[:2]
        xy[:, 0] *= w
        xy[:, 1] *= h
    if xy.shape[0] >= 4 and circles:
        img = draw_circles(img, xy, cfg)
    if xy.shape[0] > 4 and score:
        # Scores are only needed (and only consumed) in the score branch below.
        scores = get_dart_scores(xy, cfg)
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    line_type = 1
    for i, [x, y] in enumerate(xy):
        if i < 4:
            c = (0, 255, 0)  # green
        else:
            c = color  # cyan
        x = int(round(x))
        y = int(round(y))
        if i >= 4:
            # Dart: dot plus either its score or its 1-based index.
            cv2.circle(img, (x, y), 1, c, 1)
            if score:
                txt = str(scores[i - 4])
            else:
                txt = str(i + 1)
            cv2.putText(img, txt, (x + 8, y), font,
                        font_scale, c, line_type)
        else:
            # Calibration point: dot plus its 1-based index.
            cv2.circle(img, (x, y), 1, c, 1)
            cv2.putText(img, str(i + 1), (x + 8, y), font,
                        font_scale, c, line_type)
    return img
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def adjust_xy(idx):
    """Nudge annotation keypoint(s) by one pixel via the numpad arrow keys.

    Blocks on cv2.waitKey for a single keypress: 4/8/6/2 move left/up/
    right/down. `idx == -1` moves every point, otherwise only point `idx`.
    The global `xy` list (normalized coords) is converted to pixels,
    shifted, then re-normalized in place.
    """
    global xy, img_copy
    key = cv2.waitKey(0) & 0xFF
    xy = np.array(xy)
    h, w = img_copy.shape[:2]
    xy[:, 0] *= w
    xy[:, 1] *= h
    # Numpad key code -> (dx, dy) shift in pixels.
    moves = {52: (-1, 0), 56: (0, -1), 54: (1, 0), 50: (0, 1)}
    if key in moves:
        dx, dy = moves[key]
        target = slice(None) if idx == -1 else idx
        xy[target, 0] += dx
        xy[target, 1] += dy
    xy[:, 0] /= w
    xy[:, 1] /= h
    xy = xy.tolist()
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def add_last_dart(annot, data_path, folder):
    """Attach per-image "last dart" labels from the folder's CSV, if present.

    Reads <data_path>/annotations/<folder>.csv and flattens it row-major
    (all columns of row 0, then row 1, ...) into a string list stored as
    annot['last_dart']. When no CSV exists, `annot` is returned untouched.
    """
    csv_path = osp.join(data_path, 'annotations', folder + '.csv')
    if osp.isfile(csv_path):
        table = pd.read_csv(csv_path)
        dart_labels = [
            str(table.loc[row, col])
            for row in table.index.values
            for col in table.columns
        ]
        annot['last_dart'] = dart_labels
    return annot
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def get_bounding_box(img_path, scale=0.2):
    """Interactively collect a crop box by clicking two corners on a preview.

    Shows the image downscaled by `scale`; the two clicks are mapped back
    to full-resolution pixels. Press 'q' to abort (the assertion below
    then fails).

    Returns:
        bbox as [y1, y2, x1, x2] in full-resolution coordinates.
    """
    img = cv2.imread(img_path)
    img_resized = cv2.resize(img, None, fx=scale, fy=scale)
    h, w = img_resized.shape[:2]
    xy_bbox = []

    def on_click_bbox(event, x, y, flags, param):
        # Map a preview click back to full-resolution pixel coordinates.
        if event == cv2.EVENT_LBUTTONDOWN:
            if len(xy_bbox) < 2:
                xy_bbox.append([
                    round((x / w) * img.shape[1]),
                    round((y / h) * img.shape[0])])

    window = 'get bbox'
    cv2.namedWindow(window)
    cv2.setMouseCallback(window, on_click_bbox)
    while len(xy_bbox) < 2:
        # print(xy_bbox)
        cv2.imshow(window, img_resized)
        key = cv2.waitKey(100)
        if key == ord('q'):  # quit
            cv2.destroyAllWindows()
            break
    cv2.destroyAllWindows()
    assert len(xy_bbox) == 2, 'click 2 points to get bounding box'
    xy_bbox = np.array(xy_bbox)
    # bbox = [y1 y2 x1 x2]
    bbox = [min(xy_bbox[:, 1]), max(xy_bbox[:, 1]), min(xy_bbox[:, 0]), max(xy_bbox[:, 0])]
    return bbox
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def main(cfg, folder, scale, draw_circles, dart_score=True):
    """Interactive annotation loop for one image folder.

    Shows each image cropped to the board, collects clicked keypoints
    (4 calibration points + up to 3 darts via the `on_click` callback)
    and persists them to <data.path>/annotations/<folder>.pkl.

    Keys: q quit · b redraw bounding box · . next image · , previous ·
    z undo last point · x reset this annotation · d delete image (y/n
    confirm) · a accept and save · 1-7 nudge one point / 0 nudge all
    (handled by adjust_xy).
    """
    global xy, img_copy
    img_dir = osp.join(cfg.data.path, 'images', folder)
    imgs = sorted(os.listdir(img_dir))
    annot_path = osp.join(cfg.data.path, 'annotations', folder + '.pkl')
    if osp.isfile(annot_path):
        # Resume an existing annotation table.
        annot = pd.read_pickle(annot_path)
    else:
        # Fresh table: one row per image, optionally joined with the
        # per-image "last dart" labels from the matching CSV.
        annot = pd.DataFrame(columns=['img_name', 'bbox', 'xy'])
        annot['img_name'] = imgs
        annot['bbox'] = None
        annot['xy'] = None
        annot = add_last_dart(annot, cfg.data.path, folder)

    # Resume from the last row that already has a bounding box.
    i = 0
    for j in range(len(annot)):
        a = annot.iloc[j,:]
        if a['bbox'] is not None:
            i = j

    while i < len(imgs):
        xy = []
        a = annot.iloc[i,:]
        print('Annotating {}'.format(a['img_name']))
        if a['bbox'] is None:
            if i == 0:
                bbox = get_bounding_box(osp.join(img_dir, a['img_name']))
            if i > 0:
                # Seed this frame with the previous frame's keypoints.
                # NOTE(review): `bbox` here is carried over from the previous
                # loop iteration's local variable — confirm that is intended.
                last_a = annot.iloc[i-1,:]
                if last_a['xy'] is not None:
                    xy = last_a['xy'].copy()
                else:
                    xy = []
        else:
            bbox, xy = a['bbox'], a['xy']

        crop, _ = crop_board(osp.join(img_dir, a['img_name']), bbox=bbox)
        crop = cv2.resize(crop, (int(crop.shape[1] * scale), int(crop.shape[0] * scale)))
        cv2.putText(crop, '{}/{} {}'.format(i+1, len(annot), a['img_name']), (0, 12), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
        img_copy = crop.copy()

        cv2.namedWindow(folder)
        cv2.setMouseCallback(folder, on_click)
        while True:
            # Redraw overlay and poll keys; on_click mutates `xy` meanwhile.
            img_copy = draw(img_copy, xy, cfg, draw_circles, dart_score)
            cv2.imshow(folder, img_copy)
            key = cv2.waitKey(100) & 0xFF  # update every 100 ms

            if key == ord('q'):  # quit
                cv2.destroyAllWindows()
                i = len(imgs)
                break

            if key == ord('b'):  # draw new bounding box
                idx = annot[(annot['img_name'] == a['img_name'])].index.values[0]
                annot.at[idx, 'bbox'] = get_bounding_box(osp.join(img_dir, a['img_name']), scale)
                break

            if key == ord('.'):  # next image (without saving keypoints)
                i += 1
                img_copy = crop.copy()
                break

            if key == ord(','):  # previous image
                if i > 0:
                    i += -1
                img_copy = crop.copy()
                break

            if key == ord('z'):  # undo keypoint
                xy = xy[:-1]
                img_copy = crop.copy()

            if key == ord('x'):  # reset annotation
                idx = annot[(annot['img_name'] == a['img_name'])].index.values[0]
                # NOTE(review): the trailing comma stores a 1-tuple `(None,)`
                # rather than None as done for 'bbox' — confirm intended.
                annot.at[idx, 'xy'] = None,
                annot.at[idx, 'bbox'] = None
                annot.to_pickle(annot_path)
                break

            if key == ord('d'):  # delete img
                print('Are you sure you want to delete this image? (y/n)')
                key = cv2.waitKey(0) & 0xFF
                if key == ord('y'):
                    idx = annot[(annot['img_name'] == a['img_name'])].index.values[0]
                    annot = annot.drop([idx])
                    annot.to_pickle(annot_path)
                    os.remove(osp.join(img_dir, a['img_name']))
                    print('Deleted image {}'.format(a['img_name']))
                    break
                else:
                    print('Image not deleted.')
                    continue

            if key == ord('a'):  # accept keypoints
                idx = annot[(annot['img_name'] == a['img_name'])].index.values[0]
                annot.at[idx, 'xy'] = xy
                annot.at[idx, 'bbox'] = bbox
                annot.to_pickle(annot_path)
                i += 1
                break

            if key in [ord('1'), ord('2'), ord('3'), ord('4'), ord('5'), ord('6'), ord('7'), ord('0')]:
                # '1'-'7' select one keypoint; '0' (48 - 49 = -1) moves all.
                adjust_xy(idx=key - 49)  # ord('1') = 49
                img_copy = crop.copy()
                continue
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
if __name__ == '__main__':
    # Make repo-root modules importable when running from dataset/.
    import sys
    sys.path.append('../../')
    parser = argparse.ArgumentParser()
    parser.add_argument('-f', '--img-folder', default='d2_04_05_2020')
    parser.add_argument('-s', '--scale', type=float, default=0.5)
    parser.add_argument('-d', '--draw-circles', action='store_true')
    args = parser.parse_args()

    # Board geometry (r_double, r_treble, bulls, band width) comes from the
    # YAML training config.
    cfg = CN(new_allowed=True)
    cfg.merge_from_file('../configs/tiny480_20e.yaml')

    main(cfg, args.img_folder, args.scale, args.draw_circles)
|
dataset/labels.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
dataset/labels.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1bbde9f5cbfa1d623884c86210154867f99d3589309cc062476884952ac4c935
|
| 3 |
+
size 2791670
|
src/data.rs
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
use burn::prelude::*;
|
| 2 |
+
use serde::{Deserialize, Serialize};
|
| 3 |
+
use std::collections::HashMap;
|
| 4 |
+
use std::fs::File;
|
| 5 |
+
use std::io::BufReader;
|
| 6 |
+
|
| 7 |
+
/// One labelled image: keypoints plus the crop box used during annotation.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Annotation {
    /// Dataset sub-folder the image lives in (joined into the load path by the batcher).
    pub img_folder: String,
    /// Image file name within `img_folder`.
    pub img_name: String,
    /// Crop box; presumably [y1, y2, x1, x2] as produced by the Python annotator — TODO confirm.
    pub bbox: Vec<i32>,
    /// Keypoints as [x, y] pairs; presumably normalized 0-1 (the batcher
    /// multiplies them directly by the grid size) — TODO confirm.
    pub xy: Vec<Vec<f32>>,
}
|
| 14 |
+
|
| 15 |
+
/// In-memory dataset of dartboard annotations loaded from a labels JSON.
pub struct DartDataset {
    /// All annotations, sorted by image name (see `load`).
    pub annotations: Vec<Annotation>,
    /// Root directory of the dataset.
    /// NOTE(review): not used in this module — the batcher hard-codes
    /// "dataset/800/"; confirm whether these should agree.
    pub base_path: String,
}
|
| 19 |
+
|
| 20 |
+
impl DartDataset {
    /// Load annotations from a JSON map of `id -> Annotation` and sort them
    /// by image name so iteration order is deterministic.
    ///
    /// Panics if the file is missing or fails to parse.
    pub fn load(json_path: &str, base_path: &str) -> Self {
        let file = File::open(json_path).expect("Labels JSON not found");
        let reader = BufReader::new(file);
        let raw_data: HashMap<String, Annotation> = serde_json::from_reader(reader).expect("JSON parse error");

        // The JSON map keys are not needed; keep only the annotation values.
        let mut annotations: Vec<Annotation> = raw_data.into_values().collect();
        annotations.sort_by(|a, b| a.img_name.cmp(&b.img_name));

        Self {
            annotations,
            base_path: base_path.to_string(),
        }
    }
}
|
| 35 |
+
|
| 36 |
+
/// Minimal `burn` Dataset adapter: items are cloned `Annotation` records.
impl burn::data::dataset::Dataset<Annotation> for DartDataset {
    /// Annotation at `index`, or `None` when out of range.
    fn get(&self, index: usize) -> Option<Annotation> {
        self.annotations.get(index).cloned()
    }

    /// Number of annotated images.
    fn len(&self) -> usize {
        self.annotations.len()
    }
}
|
| 45 |
+
|
| 46 |
+
/// One training batch produced by `DartBatcher`.
#[derive(Clone, Debug)]
pub struct DartBatch<B: Backend> {
    /// Input images, NCHW [batch, 3, 416, 416], RGB scaled to 0-1.
    pub images: Tensor<B, 4>,
    /// Dense YOLO-style target grid [batch, 30, 26, 26]
    /// (3 anchors x 10 channels: x, y, w, h, obj, 5 class probs).
    pub targets: Tensor<B, 4>,
}
|
| 51 |
+
|
| 52 |
+
/// Builds `DartBatch`es, materializing tensors on the configured device.
#[derive(Clone, Debug)]
pub struct DartBatcher<B: Backend> {
    device: Device<B>,
}
|
| 56 |
+
|
| 57 |
+
use burn::data::dataloader::batcher::Batcher;
|
| 58 |
+
|
| 59 |
+
impl<B: Backend> Batcher<Annotation, DartBatch<B>> for DartBatcher<B> {
    /// Trait entry point used by burn's dataloader; defers to the inherent
    /// `batch_manual` implementation.
    fn batch(&self, items: Vec<Annotation>) -> DartBatch<B> {
        self.batch_manual(items)
    }
}
|
| 64 |
+
|
| 65 |
+
impl<B: Backend> DartBatcher<B> {
    /// Create a batcher that places tensors on `device`.
    pub fn new(device: Device<B>) -> Self {
        Self { device }
    }

    /// Turn a list of annotations into an (images, targets) pair.
    ///
    /// Images are loaded from `dataset/800/<folder>/<name>` (unreadable
    /// files fall back to a black image), resized to 416x416, scaled to
    /// 0-1 and permuted to NCHW. Targets are a dense [batch, 30, 26, 26]
    /// grid; note only channels 0-9 (the first anchor's slot) are ever
    /// written — channels 10-29 stay zero.
    pub fn batch_manual(&self, items: Vec<Annotation>) -> DartBatch<B> {
        let batch_size = items.len();
        let input_res: usize = 416; // Standard YOLO 416 resolution for GPU stability
        let grid_size: usize = 26; // 416 / 16 (stride accumulation) = 26
        let num_channels: usize = 30; // 3 anchors * (x,y,w,h,obj,p0...p4)

        let mut images_list = Vec::with_capacity(batch_size);
        // Flat target buffer in [batch, channel, row, col] order.
        let mut target_raw = vec![0.0f32; batch_size * num_channels * grid_size * grid_size];

        for (b_idx, item) in items.iter().enumerate() {
            // 1. Process Image
            let path = format!("dataset/800/{}/{}", item.img_folder, item.img_name);
            let img = image::open(&path).unwrap_or_else(|_| image::DynamicImage::new_rgb8(input_res as u32, input_res as u32));
            let resized = img.resize_exact(input_res as u32, input_res as u32, image::imageops::FilterType::Triangle);
            let pixels: Vec<f32> = resized.to_rgb8().pixels()
                .flat_map(|p| vec![p[0] as f32 / 255.0, p[1] as f32 / 255.0, p[2] as f32 / 255.0])
                .collect();
            // HWC layout here; permuted to CHW after stacking below.
            images_list.push(TensorData::new(pixels, [input_res, input_res, 3]));

            // 2. Write each keypoint into its grid cell.
            for (i, p) in item.xy.iter().enumerate() {
                // Grid cell containing the (presumably normalized) point.
                let gx = (p[0] * grid_size as f32).floor().clamp(0.0, (grid_size - 1) as f32) as usize;
                let gy = (p[1] * grid_size as f32).floor().clamp(0.0, (grid_size - 1) as f32) as usize;

                // Classes: points 0-3 are calibration points (classes 1-4),
                // the rest are darts (class 0).
                let cls = if i < 4 { i + 1 } else { 0 };
                // Offset of cell (gy, gx) for channel 0 of this batch item;
                // per-channel offsets of grid_size*grid_size are added below.
                let base_idx = (b_idx * num_channels * grid_size * grid_size) + (gy * grid_size) + gx;

                // TF order: [x,y,w,h,obj,p0..p4]
                target_raw[base_idx + 0 * grid_size * grid_size] = p[0]; // X
                target_raw[base_idx + 1 * grid_size * grid_size] = p[1]; // Y
                target_raw[base_idx + 2 * grid_size * grid_size] = 0.05; // W (fixed small box)
                target_raw[base_idx + 3 * grid_size * grid_size] = 0.05; // H (fixed small box)
                target_raw[base_idx + 4 * grid_size * grid_size] = 1.0; // Objectness (conf)
                target_raw[base_idx + (5 + cls) * grid_size * grid_size] = 1.0; // Class prob
            }
        }

        // Stack HWC images into NHWC, then permute to NCHW for the model.
        let images = Tensor::stack(
            images_list.into_iter().map(|d| Tensor::<B, 3>::from_data(d, &self.device)).collect(),
            0
        ).permute([0, 3, 1, 2]);

        let targets = Tensor::from_data(
            TensorData::new(target_raw, [batch_size, num_channels, grid_size, grid_size]),
            &self.device
        );

        DartBatch { images, targets }
    }
}
|
src/inference.rs
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
use burn::module::Module;
|
| 2 |
+
use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
|
| 3 |
+
use burn::tensor::backend::Backend;
|
| 4 |
+
use burn::tensor::{Tensor, TensorData};
|
| 5 |
+
use crate::model::DartVisionModel;
|
| 6 |
+
use image::{GenericImageView, DynamicImage};
|
| 7 |
+
|
| 8 |
+
/// Load trained weights, run the model on one image and report the maximum
/// objectness score as a quick sanity check.
///
/// NOTE(review): the image is resized to 800x800 here while the training
/// batcher uses 416x416 — confirm the model is resolution-agnostic or
/// align the two.
pub fn run_inference<B: Backend>(device: &B::Device, image_path: &str) {
    println!("🔍 Loading model for inference...");
    let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
    let model: DartVisionModel<B> = DartVisionModel::new(device);

    // Load weights
    let record = Recorder::load(&recorder, "model_weights".into(), device)
        .expect("Failed to load weights. Make sure model_weights.bin exists.");
    let model = model.load_record(record);

    println!("🖼️ Processing image: {}...", image_path);
    let img = image::open(image_path).expect("Failed to open image");
    let resized = img.resize_exact(800, 800, image::imageops::FilterType::Triangle);
    // RGB, scaled to 0-1, HWC layout (permuted to NCHW below).
    let pixels: Vec<f32> = resized.to_rgb8().pixels()
        .flat_map(|p| vec![p[0] as f32 / 255.0, p[1] as f32 / 255.0, p[2] as f32 / 255.0])
        .collect();

    let data = TensorData::new(pixels, [800, 800, 3]);
    let input = Tensor::<B, 3>::from_data(data, device).unsqueeze::<4>().permute([0, 3, 1, 2]);

    println!("🚀 Running MODEL Prediction...");
    let (out16, _out32) = model.forward(input);

    // Post-process out16 (size [1, 30, 100, 100])
    // Decode objectness part (Channel 4 for Anchor 0)
    let obj = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 4, 1));

    // Find highest confidence cell
    let (max_val, _) = obj.reshape([1, 10000]).max_dim_with_indices(1);
    let confidence: f32 = max_val.to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];

    println!("--------------------------------------------------");
    println!("📊 RESULTS FOR: {}", image_path);
    println!("✨ Max Objectness: {:.2}%", confidence * 100.0);

    // Low threshold: this is a smoke test, not a detection cutoff.
    if confidence > 0.05 {
        println!("✅ Model found something! Confidence Score: {:.4}", confidence);
    } else {
        println!("⚠️ Model confidence is too low. Training incomplete?");
    }
    println!("--------------------------------------------------");
}
|
src/loss.rs
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
use burn::tensor::backend::Backend;
|
| 2 |
+
use burn::tensor::Tensor;
|
| 3 |
+
|
| 4 |
+
/// YOLO-style detection loss over a [batch, 30, H, W] prediction/target pair.
///
/// NOTE(review): despite the name, no distance-IoU term is computed — this
/// is weighted BCE on objectness/class plus sigmoid-MSE on box coordinates.
/// Consider renaming or implementing DIoU.
pub fn diou_loss<B: Backend>(
    bboxes_pred: Tensor<B, 4>,
    target: Tensor<B, 4>,
) -> Tensor<B, 1> {
    // 1. Reshape to separate anchors: [Batch, 3, 10, H, W]
    let [batch, _channels, h, w] = bboxes_pred.dims();
    let bp = bboxes_pred.reshape([batch, 3, 10, h, w]);
    let t = target.reshape([batch, 3, 10, h, w]);

    // 2. Objectness (Channel 4)
    let obj_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 4, 1));
    let obj_target = t.clone().narrow(2, 4, 1);

    // Numerical floor so log() never sees exactly 0.
    let eps = 1e-7;
    // Positive loss (where an object exists): -t * log(p)
    let pos_loss = obj_target.clone().mul(obj_pred.clone().add_scalar(eps).log()).neg();
    // Negative loss (where no object exists): -(1-t) * log(1-p)
    let neg_loss = obj_target.clone().neg().add_scalar(1.0).mul(obj_pred.clone().neg().add_scalar(1.0 + eps).log()).neg();

    // Weight positive samples more to fight imbalance (typical YOLO trick).
    // NOTE(review): the multiplier is 20.0, not the 10x the original
    // comment claimed — confirm which is intended.
    let obj_loss = pos_loss.mul_scalar(20.0).add(neg_loss).mean();

    // 3. Class (Channels 5-9) - Only learn when object exists
    let cls_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 5, 5));
    let cls_target = t.clone().narrow(2, 5, 5);
    let class_loss = cls_target.clone().mul(cls_pred.clone().add_scalar(eps).log()).neg()
        .mul(obj_target.clone()) // Only count where object exists
        .mean()
        .mul_scalar(5.0); // Boost class learning

    // 4. Coordinates (Channels 0-3) - Only learn when object exists
    let b_xy_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 0, 2));
    let b_xy_target = t.clone().narrow(2, 0, 2);
    let xy_loss = b_xy_pred.sub(b_xy_target).powf_scalar(2.0).mul(obj_target.clone()).mean().mul_scalar(5.0);

    let b_wh_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 2, 2));
    let b_wh_target = t.clone().narrow(2, 2, 2);
    let wh_loss = b_wh_pred.sub(b_wh_target).powf_scalar(2.0).mul(obj_target).mean().mul_scalar(5.0);

    // Total = objectness + class + xy + wh terms.
    obj_loss.add(class_loss).add(xy_loss).add(wh_loss)
}
|
src/main.rs
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
use crate::model::DartVisionModel;
|
| 2 |
+
use crate::server::start_gui;
|
| 3 |
+
use crate::train::{train, TrainingConfig};
|
| 4 |
+
use burn::backend::wgpu::WgpuDevice;
|
| 5 |
+
use burn::backend::Wgpu;
|
| 6 |
+
use burn::prelude::*;
|
| 7 |
+
use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
|
| 8 |
+
|
| 9 |
+
pub mod data;
|
| 10 |
+
pub mod loss;
|
| 11 |
+
pub mod model;
|
| 12 |
+
pub mod scoring;
|
| 13 |
+
pub mod server;
|
| 14 |
+
pub mod train;
|
| 15 |
+
|
| 16 |
+
fn main() {
|
| 17 |
+
let device = WgpuDevice::default();
|
| 18 |
+
let args: Vec<String> = std::env::args().collect();
|
| 19 |
+
|
| 20 |
+
if args.len() > 1 && args[1] == "gui" {
|
| 21 |
+
println!("🌐 [Burn-DartVision] Starting Professional Dashboard...");
|
| 22 |
+
tokio::runtime::Builder::new_multi_thread()
|
| 23 |
+
.enable_all()
|
| 24 |
+
.build()
|
| 25 |
+
.unwrap()
|
| 26 |
+
.block_on(start_gui(device));
|
| 27 |
+
} else if args.len() > 1 && args[1] == "test" {
|
| 28 |
+
let img_path = if args.len() > 2 { &args[2] } else { "test.jpg" };
|
| 29 |
+
test_model(device, img_path);
|
| 30 |
+
} else {
|
| 31 |
+
println!("🚀 [Burn-DartVision] Starting Full Project Training...");
|
| 32 |
+
let dataset_path = "dataset/labels.json";
|
| 33 |
+
|
| 34 |
+
let config = TrainingConfig {
|
| 35 |
+
num_epochs: 10,
|
| 36 |
+
batch_size: 1,
|
| 37 |
+
lr: 1e-3,
|
| 38 |
+
};
|
| 39 |
+
|
| 40 |
+
train::<burn::backend::Autodiff<Wgpu>>(device, dataset_path, config);
|
| 41 |
+
}
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
fn test_model(device: WgpuDevice, img_path: &str) {
|
| 45 |
+
println!("🔍 Testing model on: {}", img_path);
|
| 46 |
+
let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
|
| 47 |
+
let model = DartVisionModel::<Wgpu>::new(&device);
|
| 48 |
+
|
| 49 |
+
let record = match recorder.load("model_weights".into(), &device) {
|
| 50 |
+
Ok(r) => r,
|
| 51 |
+
Err(_) => {
|
| 52 |
+
println!("⚠️ Weights not found, using initial model.");
|
| 53 |
+
model.clone().into_record()
|
| 54 |
+
}
|
| 55 |
+
};
|
| 56 |
+
let model = model.load_record(record);
|
| 57 |
+
|
| 58 |
+
let img = image::open(img_path).unwrap_or_else(|_| {
|
| 59 |
+
println!("❌ Image not found at {}. Using random tensor.", img_path);
|
| 60 |
+
image::DynamicImage::new_rgb8(416, 416)
|
| 61 |
+
});
|
| 62 |
+
let resized = img.resize_exact(416, 416, image::imageops::FilterType::Triangle);
|
| 63 |
+
let pixels: Vec<f32> = resized
|
| 64 |
+
.to_rgb8()
|
| 65 |
+
.pixels()
|
| 66 |
+
.flat_map(|p| {
|
| 67 |
+
vec![
|
| 68 |
+
p[0] as f32 / 255.0,
|
| 69 |
+
p[1] as f32 / 255.0,
|
| 70 |
+
p[2] as f32 / 255.0,
|
| 71 |
+
]
|
| 72 |
+
})
|
| 73 |
+
.collect();
|
| 74 |
+
|
| 75 |
+
let tensor_data = TensorData::new(pixels, [1, 416, 416, 3]);
|
| 76 |
+
let input = Tensor::<Wgpu, 4>::from_data(tensor_data, &device).permute([0, 3, 1, 2]);
|
| 77 |
+
let (out, _): (Tensor<Wgpu, 4>, _) = model.forward(input);
|
| 78 |
+
|
| 79 |
+
let obj = burn::tensor::activation::sigmoid(out.clone().narrow(1, 4, 1));
|
| 80 |
+
let (max_val, _) = obj.reshape([1, 676]).max_dim_with_indices(1);
|
| 81 |
+
|
| 82 |
+
let score = max_val
|
| 83 |
+
.to_data()
|
| 84 |
+
.convert::<f32>()
|
| 85 |
+
.as_slice::<f32>()
|
| 86 |
+
.unwrap()[0];
|
| 87 |
+
println!("📊 Max Objectness Score: {:.6}", score);
|
| 88 |
+
}
|
src/model.rs
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
use burn::nn::conv::{Conv2d, Conv2dConfig};
|
| 2 |
+
use burn::nn::{BatchNorm, BatchNormConfig};
|
| 3 |
+
use burn::module::Module;
|
| 4 |
+
use burn::tensor::backend::Backend;
|
| 5 |
+
use burn::tensor::Tensor;
|
| 6 |
+
use burn::nn::PaddingConfig2d;
|
| 7 |
+
use burn::nn::pool::{MaxPool2d, MaxPool2dConfig};
|
| 8 |
+
|
| 9 |
+
#[derive(Module, Debug)]
|
| 10 |
+
pub struct ConvBlock<B: Backend> {
|
| 11 |
+
conv: Conv2d<B>,
|
| 12 |
+
bn: BatchNorm<B, 2>,
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
impl<B: Backend> ConvBlock<B> {
|
| 16 |
+
pub fn new(in_channels: usize, out_channels: usize, kernel_size: [usize; 2], device: &B::Device) -> Self {
|
| 17 |
+
let config = Conv2dConfig::new([in_channels, out_channels], kernel_size)
|
| 18 |
+
.with_padding(PaddingConfig2d::Same);
|
| 19 |
+
let conv = config.init(device);
|
| 20 |
+
let bn = BatchNormConfig::new(out_channels).init(device);
|
| 21 |
+
Self { conv, bn }
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
| 25 |
+
let x = self.conv.forward(x);
|
| 26 |
+
let x = self.bn.forward(x);
|
| 27 |
+
burn::tensor::activation::leaky_relu(x, 0.1)
|
| 28 |
+
}
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
#[derive(Module, Debug)]
|
| 32 |
+
pub struct DartVisionModel<B: Backend> {
|
| 33 |
+
// Lean architecture: High resolution (800x800) but low channel count to fix GPU OOM
|
| 34 |
+
l1: ConvBlock<B>, // 3 -> 16
|
| 35 |
+
p1: MaxPool2d,
|
| 36 |
+
l2: ConvBlock<B>, // 16 -> 16
|
| 37 |
+
p2: MaxPool2d,
|
| 38 |
+
l3: ConvBlock<B>, // 16 -> 32
|
| 39 |
+
p3: MaxPool2d,
|
| 40 |
+
l4: ConvBlock<B>, // 32 -> 32
|
| 41 |
+
p4: MaxPool2d,
|
| 42 |
+
l5: ConvBlock<B>, // 32 -> 64
|
| 43 |
+
l6: ConvBlock<B>, // 64 -> 64
|
| 44 |
+
|
| 45 |
+
head_32: Conv2d<B>, // Final detection head
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
impl<B: Backend> DartVisionModel<B> {
|
| 49 |
+
pub fn new(device: &B::Device) -> Self {
|
| 50 |
+
let l1 = ConvBlock::new(3, 16, [3, 3], device);
|
| 51 |
+
let p1 = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init();
|
| 52 |
+
|
| 53 |
+
let l2 = ConvBlock::new(16, 16, [3, 3], device);
|
| 54 |
+
let p2 = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init();
|
| 55 |
+
|
| 56 |
+
let l3 = ConvBlock::new(16, 32, [3, 3], device);
|
| 57 |
+
let p3 = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init();
|
| 58 |
+
|
| 59 |
+
let l4 = ConvBlock::new(32, 32, [3, 3], device);
|
| 60 |
+
let p4 = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init();
|
| 61 |
+
|
| 62 |
+
let l5 = ConvBlock::new(32, 64, [3, 3], device);
|
| 63 |
+
let l6 = ConvBlock::new(64, 64, [3, 3], device);
|
| 64 |
+
|
| 65 |
+
// 30 channels = 3 anchors * (x,y,w,h,obj,p0...p4)
|
| 66 |
+
let head_32 = Conv2dConfig::new([64, 30], [1, 1]).init(device);
|
| 67 |
+
|
| 68 |
+
Self { l1, p1, l2, p2, l3, p3, l4, p4, l5, l6, head_32 }
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
pub fn forward(&self, x: Tensor<B, 4>) -> (Tensor<B, 4>, Tensor<B, 4>) {
|
| 72 |
+
let x = self.l1.forward(x); // 800
|
| 73 |
+
let x = self.p1.forward(x); // 400
|
| 74 |
+
let x = self.l2.forward(x); // 400
|
| 75 |
+
let x = self.p2.forward(x); // 200
|
| 76 |
+
let x = self.l3.forward(x); // 200
|
| 77 |
+
let x = self.p3.forward(x); // 100
|
| 78 |
+
let x = self.l4.forward(x); // 100
|
| 79 |
+
let x = self.p4.forward(x); // 50
|
| 80 |
+
|
| 81 |
+
let x50 = self.l5.forward(x); // 50
|
| 82 |
+
let x50 = self.l6.forward(x50); // 50
|
| 83 |
+
|
| 84 |
+
let out50 = self.head_32.forward(x50);
|
| 85 |
+
|
| 86 |
+
(out50.clone(), out50)
|
| 87 |
+
}
|
| 88 |
+
}
|
src/scoring.rs
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
use std::collections::HashMap;
|
| 2 |
+
|
| 3 |
+
/// Dartboard ring geometry used to classify where a dart landed.
/// Values appear to be metres on a regulation board (e.g. 0.170 m
/// double-ring radius) — the absolute unit cancels out because distances
/// are rescaled relative to `r_double` before comparison.
pub struct ScoringConfig {
    pub r_double: f32,        // Outer radius of the double ring (edge of scoring area).
    pub r_treble: f32,        // Outer radius of the treble ring.
    pub r_outer_bull: f32,    // Radius of the outer bull (scores 25).
    pub r_inner_bull: f32,    // Radius of the inner bull (scores 50).
    pub w_double_treble: f32, // Radial width of the double and treble bands.
}

impl Default for ScoringConfig {
    fn default() -> Self {
        Self {
            r_double: 0.170,
            r_treble: 0.1074,
            r_outer_bull: 0.0159,
            r_inner_bull: 0.00635,
            w_double_treble: 0.01,
        }
    }
}
|
| 22 |
+
|
| 23 |
+
/// Returns the sector-index -> printed-number mapping for a standard board.
/// The index is `floor(angle / 18°)` with the angle measured CCW from the
/// 3 o'clock direction, so index 0 covers [0°, 18°) which contains "13".
pub fn get_board_dict() -> HashMap<i32, &'static str> {
    let slices = [
        "13", "4", "18", "1", "20", "5", "12", "9", "14", "11",
        "8", "16", "7", "19", "3", "17", "2", "15", "10", "6",
    ];
    slices
        .iter()
        .enumerate()
        .map(|(i, &s)| (i as i32, s))
        .collect()
}
|
| 32 |
+
|
| 33 |
+
pub fn calculate_dart_score(cal_pts: &[[f32; 2]], dart_pt: &[f32; 2], config: &ScoringConfig) -> (i32, String) {
|
| 34 |
+
// 1. Calculate Center (Average of 4 calibration points)
|
| 35 |
+
let cx = cal_pts.iter().map(|p| p[0]).sum::<f32>() / 4.0;
|
| 36 |
+
let cy = cal_pts.iter().map(|p| p[1]).sum::<f32>() / 4.0;
|
| 37 |
+
|
| 38 |
+
// 2. Calculate average radius to boundary (doubles wire)
|
| 39 |
+
let avg_r_px = cal_pts.iter()
|
| 40 |
+
.map(|p| ((p[0] - cx).powi(2) + (p[1] - cy).powi(2)).sqrt())
|
| 41 |
+
.sum::<f32>() / 4.0;
|
| 42 |
+
|
| 43 |
+
// 3. Relative distance of dart from center
|
| 44 |
+
let dx = dart_pt[0] - cx;
|
| 45 |
+
let dy = dart_pt[1] - cy;
|
| 46 |
+
let dist_px = (dx.powi(2) + dy.powi(2)).sqrt();
|
| 47 |
+
|
| 48 |
+
// Scale distance relative to BDO double radius
|
| 49 |
+
let dist_scaled = (dist_px / avg_r_px) * config.r_double;
|
| 50 |
+
|
| 51 |
+
// 4. Calculate Angle (0 is 3 o'clock, CCW)
|
| 52 |
+
let mut angle_deg = (-dy).atan2(dx).to_degrees();
|
| 53 |
+
if angle_deg < 0.0 { angle_deg += 360.0; }
|
| 54 |
+
|
| 55 |
+
// Board is rotated such that 20 is at top (90 deg)
|
| 56 |
+
// Sector width is 18 deg. Sector 20 is centered at 90 deg.
|
| 57 |
+
// 90 deg is index 4 in slices (13, 4, 18, 1, 20...)
|
| 58 |
+
// Each index is 18 deg. Offset = 4 * 18 = 72? No.
|
| 59 |
+
// Let's use the standard mapping: (angle / 18)
|
| 60 |
+
// Wait, the BOARD_DICT in Python uses int(angle / 18) where angle is 0-360.
|
| 61 |
+
// We need to match the slice orientation.
|
| 62 |
+
let board_dict = get_board_dict();
|
| 63 |
+
let sector_idx = ((angle_deg / 18.0).floor() as i32) % 20;
|
| 64 |
+
let sector_num = board_dict.get(§or_idx).unwrap_or(&"0");
|
| 65 |
+
|
| 66 |
+
// 5. Determine multipliers based on scaled distance
|
| 67 |
+
let r_t = config.r_treble;
|
| 68 |
+
let r_d = config.r_double;
|
| 69 |
+
let w = config.w_double_treble;
|
| 70 |
+
let r_ib = config.r_inner_bull;
|
| 71 |
+
let r_ob = config.r_outer_bull;
|
| 72 |
+
|
| 73 |
+
if dist_scaled > r_d {
|
| 74 |
+
(0, "Miss".to_string())
|
| 75 |
+
} else if dist_scaled <= r_ib {
|
| 76 |
+
(50, "DB".to_string())
|
| 77 |
+
} else if dist_scaled <= r_ob {
|
| 78 |
+
(25, "B".to_string())
|
| 79 |
+
} else if dist_scaled <= r_d && dist_scaled > (r_d - w) {
|
| 80 |
+
let val = sector_num.parse::<i32>().unwrap_or(0);
|
| 81 |
+
(val * 2, format!("D{}", sector_num))
|
| 82 |
+
} else if dist_scaled <= r_t && dist_scaled > (r_t - w) {
|
| 83 |
+
let val = sector_num.parse::<i32>().unwrap_or(0);
|
| 84 |
+
(val * 3, format!("T{}", sector_num))
|
| 85 |
+
} else {
|
| 86 |
+
let val = sector_num.parse::<i32>().unwrap_or(0);
|
| 87 |
+
(val, sector_num.to_string())
|
| 88 |
+
}
|
| 89 |
+
}
|
src/server.rs
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
use axum::{
|
| 2 |
+
extract::{DefaultBodyLimit, Multipart, State},
|
| 3 |
+
response::{Html, Json},
|
| 4 |
+
routing::{get, post},
|
| 5 |
+
Router,
|
| 6 |
+
};
|
| 7 |
+
use burn::prelude::*;
|
| 8 |
+
use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
|
| 9 |
+
use burn::backend::Wgpu;
|
| 10 |
+
use burn::backend::wgpu::WgpuDevice;
|
| 11 |
+
use crate::model::DartVisionModel;
|
| 12 |
+
use serde_json::json;
|
| 13 |
+
use std::net::SocketAddr;
|
| 14 |
+
use std::sync::Arc;
|
| 15 |
+
use tokio::sync::{mpsc, oneshot};
|
| 16 |
+
use tower_http::cors::CorsLayer;
|
| 17 |
+
|
| 18 |
+
/// Result of a single inference pass, sent back from the GPU worker thread.
#[derive(Debug)]
struct PredictResult {
    // Highest objectness*class probability observed across the keypoints.
    confidence: f32,
    // Flattened [x, y] pairs: indices 0..8 are the 4 calibration corners,
    // any further pairs are detected darts. Values are sigmoid outputs in
    // [0, 1] — presumably offsets within the winning grid cell, not
    // full-image coordinates; TODO confirm against the frontend's use.
    keypoints: Vec<f32>,
    // Human-readable labels (e.g. "T20", "DB"), one per detected dart.
    scores: Vec<String>,
}

/// A prediction job queued from the HTTP handler to the worker thread.
struct PredictRequest {
    // Raw uploaded image bytes (any format `image::load_from_memory` accepts).
    image_bytes: Vec<u8>,
    // One-shot channel used to deliver the result back to the request task.
    response_tx: oneshot::Sender<PredictResult>,
}
|
| 29 |
+
|
| 30 |
+
/// Starts the dashboard web server on 127.0.0.1:8080 and a dedicated OS
/// thread that owns the model and serves inference requests sequentially
/// over an mpsc channel (keeps all GPU work off the async runtime).
pub async fn start_gui(device: WgpuDevice) {
    let addr = SocketAddr::from(([127, 0, 0, 1], 8080));
    println!("🚀 [DartVision-GUI] Starting on http://127.0.0.1:8080",);

    // Bounded queue: at most 10 pending prediction jobs.
    let (tx, mut rx) = mpsc::channel::<PredictRequest>(10);

    let worker_device = device.clone();
    // Model worker thread: loads weights once, then loops over requests.
    std::thread::spawn(move || {
        let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
        let model = DartVisionModel::<Wgpu>::new(&worker_device);
        // Fall back to freshly initialized weights when no checkpoint exists.
        let record = match recorder.load("model_weights".into(), &worker_device) {
            Ok(r) => r,
            Err(_) => {
                println!("⚠️ [DartVision] No 'model_weights.bin' yet. Using initial weights...");
                model.clone().into_record()
            }
        };
        let model = model.load_record(record);

        while let Some(req) = rx.blocking_recv() {
            // NOTE(review): unwrap panics the worker thread on a malformed
            // upload — consider returning an error result instead.
            let img = image::load_from_memory(&req.image_bytes).unwrap();
            let resized = img.resize_exact(416, 416, image::imageops::FilterType::Triangle);
            // Normalize RGB bytes to [0, 1] floats in HWC order.
            let pixels: Vec<f32> = resized.to_rgb8().pixels()
                .flat_map(|p| vec![p[0] as f32 / 255.0, p[1] as f32 / 255.0, p[2] as f32 / 255.0])
                .collect();

            // [1, H, W, C] -> NCHW for the convolutional model.
            let tensor_data = TensorData::new(pixels, [1, 416, 416, 3])
            let input = Tensor::<Wgpu, 4>::from_data(tensor_data, &worker_device).permute([0, 3, 1, 2]);

            let (out16, _) = model.forward(input);

            // Slots 0..8 hold the 4 calibration corners; darts are appended.
            let mut final_points = vec![0.0f32; 8]; // 4 corners
            let mut max_conf = 0.0f32;

            // 1. Extract Objectness with Sigmoid (channel 4).
            let obj = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 4, 1));

            // 2. Extract best calibration corner for each class 1 to 4
            //    (grid is 26x26 = 676 cells; class channels start at 5).
            for cls_idx in 1..=4 {
                let prob = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 5 + cls_idx, 1));
                let score = obj.clone().mul(prob);
                let (val, idx) = score.reshape([1, 676]).max_dim_with_indices(1);

                let s = val.to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];
                let f_idx = idx.to_data().convert::<i32>().as_slice::<i32>().unwrap()[0] as usize;
                // Unflatten cell index into grid row/column.
                let gy = f_idx / 26;
                let gx = f_idx % 26;

                // Use Sigmoid for Coordinates (matching the loss, which also
                // applies sigmoid to channels 0/1 before regression).
                let px = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 0, 1).slice([0..1, 0..1, gy..gy+1, gx..gx+1]))
                    .to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];
                let py = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 1, 1).slice([0..1, 0..1, gy..gy+1, gx..gx+1]))
                    .to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];

                final_points[(cls_idx-1)*2] = px;
                final_points[(cls_idx-1)*2+1] = py;
                if s > max_conf { max_conf = s; }
            }

            // 3. Extract best dart (Class 0, channel 5), kept only above a
            //    0.1 confidence threshold.
            let d_prob = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 5, 1));
            let d_score = obj.clone().mul(d_prob);
            let (d_val, d_idx) = d_score.reshape([1, 676]).max_dim_with_indices(1);
            let ds = d_val.to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];
            if ds > 0.1 {
                let f_idx = d_idx.to_data().convert::<i32>().as_slice::<i32>().unwrap()[0] as usize;
                let gy = f_idx / 26;
                let gx = f_idx % 26;
                let dx = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 0, 1).slice([0..1, 0..1, gy..gy+1, gx..gx+1]))
                    .to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];
                let dy = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 1, 1).slice([0..1, 0..1, gy..gy+1, gx..gx+1]))
                    .to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];
                final_points.push(dx);
                final_points.push(dy);
            }

            let mut final_scores = vec![];

            // Calculate scores if we have calibration points and at least one
            // dart (8 corner values + >=1 dart pair => len >= 10).
            if final_points.len() >= 10 {
                use crate::scoring::{calculate_dart_score, ScoringConfig};
                let config = ScoringConfig::default();
                let cal_pts = [
                    [final_points[0], final_points[1]],
                    [final_points[2], final_points[3]],
                    [final_points[4], final_points[5]],
                    [final_points[6], final_points[7]],
                ];

                for dart_chunk in final_points[8..].chunks(2) {
                    if dart_chunk.len() == 2 {
                        let dart_pt = [dart_chunk[0], dart_chunk[1]];
                        let (_val, label) = calculate_dart_score(&cal_pts, &dart_pt, &config);
                        final_scores.push(label);
                    }
                }
            }

            println!("🎯 [Detection Result] Confidence: {:.2}%", max_conf * 100.0);
            let class_names = ["Cal1", "Cal2", "Cal3", "Cal4", "Dart"];
            for (i, pts) in final_points.chunks(2).enumerate() {
                let name = class_names.get(i).unwrap_or(&"Dart");
                // NOTE(review): saturating_sub(4) makes the calibration rows
                // (i < 4) all print the FIRST dart's label — likely cosmetic
                // bug in logging only.
                let label = final_scores.get(i.saturating_sub(4)).cloned().unwrap_or_default();
                println!(" - {}: [x: {:.3}, y: {:.3}] {}", name, pts[0], pts[1], label);
            }

            // Receiver may have hung up (client disconnected); ignore that.
            let _ = req.response_tx.send(PredictResult {
                confidence: max_conf,
                keypoints: final_points,
                scores: final_scores,
            });
        }
    });

    let state = Arc::new(tx);

    // Routes: static dashboard page plus the prediction API, with a 10 MiB
    // upload limit and permissive CORS for local development.
    let app = Router::new()
        .route("/", get(|| async { Html(include_str!("../static/index.html")) }))
        .route("/api/predict", post(predict_handler))
        .with_state(state)
        .layer(DefaultBodyLimit::max(10 * 1024 * 1024))
        .layer(CorsLayer::permissive());

    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    axum::serve(listener, app).await.unwrap();
}
|
| 156 |
+
|
| 157 |
+
async fn predict_handler(
|
| 158 |
+
State(tx): State<Arc<mpsc::Sender<PredictRequest>>>,
|
| 159 |
+
mut multipart: Multipart,
|
| 160 |
+
) -> Json<serde_json::Value> {
|
| 161 |
+
while let Ok(Some(field)) = multipart.next_field().await {
|
| 162 |
+
if field.name() == Some("image") {
|
| 163 |
+
let bytes = match field.bytes().await {
|
| 164 |
+
Ok(b) => b.to_vec(),
|
| 165 |
+
Err(_) => continue,
|
| 166 |
+
};
|
| 167 |
+
let (res_tx, res_rx) = oneshot::channel();
|
| 168 |
+
let _ = tx.send(PredictRequest { image_bytes: bytes, response_tx: res_tx }).await;
|
| 169 |
+
let result = res_rx.await.unwrap_or(PredictResult { confidence: 0.0, keypoints: vec![] });
|
| 170 |
+
|
| 171 |
+
return Json(json!({
|
| 172 |
+
"status": "success",
|
| 173 |
+
"confidence": result.confidence,
|
| 174 |
+
"keypoints": result.keypoints,
|
| 175 |
+
"scores": result.scores,
|
| 176 |
+
"message": if result.confidence > 0.1 {
|
| 177 |
+
format!("✅ Found {} darts! High confidence: {:.1}%", result.scores.len(), result.confidence * 100.0)
|
| 178 |
+
} else {
|
| 179 |
+
"⚠️ Low confidence detection - no dart score could be verified.".to_string()
|
| 180 |
+
}
|
| 181 |
+
}));
|
| 182 |
+
}
|
| 183 |
+
}
|
| 184 |
+
Json(json!({"status": "error", "message": "No image field found"}))
|
| 185 |
+
}
|
src/train.rs
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
use crate::data::{DartBatcher, DartDataset};
|
| 2 |
+
use crate::loss::diou_loss;
|
| 3 |
+
use crate::model::DartVisionModel;
|
| 4 |
+
use burn::data::dataset::Dataset; // Add this trait to scope
|
| 5 |
+
use burn::module::Module;
|
| 6 |
+
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
|
| 7 |
+
use burn::prelude::*;
|
| 8 |
+
use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
|
| 9 |
+
use burn::tensor::backend::AutodiffBackend;
|
| 10 |
+
|
| 11 |
+
/// Hyper-parameters for a training run.
pub struct TrainingConfig {
    pub num_epochs: usize, // Number of full passes over the dataset.
    pub batch_size: usize, // Images per optimizer step.
    pub lr: f64,           // Adam learning rate.
}
|
| 16 |
+
|
| 17 |
+
/// Full training loop: builds the model, resumes from `model_weights.bin`
/// when present, iterates the JSON-annotated dataset for
/// `config.num_epochs` epochs with Adam, and checkpoints both periodically
/// within an epoch and after each epoch.
pub fn train<B: AutodiffBackend>(device: Device<B>, dataset_path: &str, config: TrainingConfig) {
    // 1. Create Model
    let mut model: DartVisionModel<B> = DartVisionModel::new(&device);

    // 1.5 Load existing weights if they exist (RESUME).
    // The recorder appends ".bin" itself, so we check the full path but
    // load/save by the "model_weights" stem.
    let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
    let weights_path = "model_weights.bin";
    if std::path::Path::new(weights_path).exists() {
        println!("🚀 Loading existing weights from {}...", weights_path);
        let record = Recorder::load(&recorder, "model_weights".into(), &device)
            .expect("Failed to load weights");
        model = model.load_record(record);
    }

    // 2. Setup Optimizer
    let mut optim = AdamConfig::new().init();

    // 3. Create Dataset
    println!("🔍 Mapping annotations from {}...", dataset_path);
    let dataset = DartDataset::load(dataset_path, "dataset/800");
    println!("📊 Dataset loaded with {} examples.", dataset.len());
    let batcher = DartBatcher::new(device.clone());

    // 4. Create DataLoader
    println!("📦 Initializing DataLoader (Workers: 4)...");
    let dataloader = burn::data::dataloader::DataLoaderBuilder::new(batcher)
        .batch_size(config.batch_size)
        .shuffle(42) // Fixed seed -> reproducible shuffling order.
        .num_workers(4)
        .build(dataset);

    // 5. Training Loop
    println!(
        "📈 Running FULL Training Loop (Epochs: {})...",
        config.num_epochs
    );

    // Using a simple loop state for ownership safety: optimizer.step takes
    // the model by value, so it is moved into each epoch and moved back out.
    let mut current_model = model;

    for epoch in 1..=config.num_epochs {
        let mut model_inner = current_model; // Move into epoch
        let mut batch_count = 0;

        for batch in dataloader.iter() {
            // Forward Pass
            let (out16, _) = model_inner.forward(batch.images);

            // Calculate Loss
            let loss = diou_loss(out16, batch.targets);
            batch_count += 1;

            // Print every 20 batches (plus batch 1) to keep the terminal
            // clean and avoid stdout sync lag.
            if batch_count % 20 == 0 || batch_count == 1 {
                println!(
                    " [Epoch {}] Batch {: >3} | Loss: {:.6}",
                    epoch,
                    batch_count,
                    loss.clone().into_scalar()
                );
            }

            // Backward & Optimization step
            let grads = loss.backward();
            let grads_params = GradientsParams::from_grads(grads, &model_inner);
            model_inner = optim.step(config.lr, model_inner, grads_params);

            // 5.5 Periodic Save (every 100 batches and Batch 1).
            // `.ok()` deliberately ignores save errors mid-epoch; the
            // end-of-epoch save below is the one that must succeed.
            if batch_count % 100 == 0 || batch_count == 1 {
                model_inner.clone()
                    .save_file("model_weights", &recorder)
                    .ok();
                if batch_count == 1 {
                    println!("🚀 [Checkpoint] Initial weights saved at Batch 1.");
                } else {
                    println!("🚀 [Checkpoint] Saved at Batch {}.", batch_count);
                }
            }
        }

        // 6. SAVE after EACH Epoch
        model_inner
            .clone()
            .save_file("model_weights", &recorder)
            .expect("Failed to save weights");
        println!("✅ Checkpoint saved: Epoch {} complete.", epoch);

        current_model = model_inner; // Move back out for next epoch
    }
}
|
static/index.html
ADDED
|
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 6 |
+
<title>DartVision AI - Smart Scoring Dashboard</title>
|
| 7 |
+
<link rel="preconnect" href="https://fonts.googleapis.com">
|
| 8 |
+
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
| 9 |
+
<link href="https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;600;800&display=swap" rel="stylesheet">
|
| 10 |
+
<style>
|
| 11 |
+
:root {
|
| 12 |
+
--primary: #00ff88;
|
| 13 |
+
--secondary: #00d4ff;
|
| 14 |
+
--accent: #ff4d4d;
|
| 15 |
+
--bg: #0a0e14;
|
| 16 |
+
--card-bg: rgba(255, 255, 255, 0.05);
|
| 17 |
+
--glass-border: rgba(255, 255, 255, 0.1);
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
* {
|
| 21 |
+
margin: 0;
|
| 22 |
+
padding: 0;
|
| 23 |
+
box-sizing: border-box;
|
| 24 |
+
font-family: 'Outfit', sans-serif;
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
body {
|
| 28 |
+
background: var(--bg);
|
| 29 |
+
color: #fff;
|
| 30 |
+
min-height: 100vh;
|
| 31 |
+
display: flex;
|
| 32 |
+
flex-direction: column;
|
| 33 |
+
align-items: center;
|
| 34 |
+
overflow-x: hidden;
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
.bg-glow {
|
| 38 |
+
position: fixed;
|
| 39 |
+
top: 0;
|
| 40 |
+
left: 0;
|
| 41 |
+
right: 0;
|
| 42 |
+
bottom: 0;
|
| 43 |
+
background: radial-gradient(circle at 50% 50%, #00ff8811 0%, transparent 50%),
|
| 44 |
+
radial-gradient(circle at 80% 20%, #00d4ff11 0%, transparent 40%);
|
| 45 |
+
z-index: -1;
|
| 46 |
+
pointer-events: none;
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
header {
|
| 50 |
+
padding: 2rem;
|
| 51 |
+
text-align: center;
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
h1 {
|
| 55 |
+
font-size: 3.5rem;
|
| 56 |
+
font-weight: 800;
|
| 57 |
+
background: linear-gradient(to right, var(--primary), var(--secondary));
|
| 58 |
+
-webkit-background-clip: text;
|
| 59 |
+
background-clip: text;
|
| 60 |
+
-webkit-text-fill-color: transparent;
|
| 61 |
+
margin-bottom: 0.5rem;
|
| 62 |
+
letter-spacing: -1px;
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
p.subtitle {
|
| 66 |
+
color: #8892b0;
|
| 67 |
+
font-size: 1.1rem;
|
| 68 |
+
letter-spacing: 2px;
|
| 69 |
+
text-transform: uppercase;
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
main {
|
| 73 |
+
width: 90%;
|
| 74 |
+
max-width: 1200px;
|
| 75 |
+
display: grid;
|
| 76 |
+
grid-template-columns: 1fr 380px;
|
| 77 |
+
gap: 2rem;
|
| 78 |
+
padding: 2rem 0;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
.visual-area {
|
| 82 |
+
display: flex;
|
| 83 |
+
flex-direction: column;
|
| 84 |
+
gap: 1rem;
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
.upload-container {
|
| 88 |
+
background: var(--card-bg);
|
| 89 |
+
backdrop-filter: blur(20px);
|
| 90 |
+
border: 1px solid var(--glass-border);
|
| 91 |
+
border-radius: 24px;
|
| 92 |
+
padding: 1rem;
|
| 93 |
+
display: flex;
|
| 94 |
+
flex-direction: column;
|
| 95 |
+
align-items: center;
|
| 96 |
+
justify-content: center;
|
| 97 |
+
cursor: pointer;
|
| 98 |
+
transition: all 0.4s;
|
| 99 |
+
min-height: 600px;
|
| 100 |
+
position: relative;
|
| 101 |
+
overflow: hidden;
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
.upload-container:hover, .upload-container.drag-over {
|
| 105 |
+
border-color: var(--primary);
|
| 106 |
+
box-shadow: 0 0 40px rgba(0, 255, 136, 0.1);
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
.preview-wrapper {
|
| 110 |
+
position: relative;
|
| 111 |
+
width: 100%;
|
| 112 |
+
height: 100%;
|
| 113 |
+
display: none;
|
| 114 |
+
justify-content: center;
|
| 115 |
+
align-items: center;
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
#preview-img {
|
| 119 |
+
max-width: 100%;
|
| 120 |
+
max-height: 550px;
|
| 121 |
+
border-radius: 12px;
|
| 122 |
+
display: block;
|
| 123 |
+
object-fit: contain;
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
#overlay-svg {
|
| 127 |
+
position: absolute;
|
| 128 |
+
top: 0;
|
| 129 |
+
left: 0;
|
| 130 |
+
width: 100%;
|
| 131 |
+
height: 100%;
|
| 132 |
+
pointer-events: none;
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
.upload-icon { font-size: 4rem; margin-bottom: 1rem; }
|
| 136 |
+
.upload-text { font-size: 1.5rem; color: #ccd6f6; }
|
| 137 |
+
.upload-subtext { color: #8892b0; margin-top: 1rem; }
|
| 138 |
+
|
| 139 |
+
.stats-panel {
|
| 140 |
+
display: flex;
|
| 141 |
+
flex-direction: column;
|
| 142 |
+
gap: 1.5rem;
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
.stat-card {
|
| 146 |
+
background: rgba(255, 255, 255, 0.03);
|
| 147 |
+
border: 1px solid var(--glass-border);
|
| 148 |
+
border-radius: 20px;
|
| 149 |
+
padding: 1.5rem;
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
.stat-label { color: #8892b0; font-size: 0.9rem; margin-bottom: 0.5rem; text-transform: uppercase; }
|
| 153 |
+
.stat-value { font-size: 2rem; font-weight: 600; }
|
| 154 |
+
.conf-bar { height: 8px; background: rgba(255,255,255,0.1); border-radius: 4px; margin-top: 1rem; overflow: hidden; }
|
| 155 |
+
.conf-fill { height: 100%; width: 0%; background: linear-gradient(to right, var(--primary), var(--secondary)); transition: 1s; }
|
| 156 |
+
|
| 157 |
+
.loading-spinner {
|
| 158 |
+
width: 40px; height: 40px;
|
| 159 |
+
border: 3px solid rgba(0, 255, 136, 0.1);
|
| 160 |
+
border-top: 3px solid var(--primary);
|
| 161 |
+
border-radius: 50%;
|
| 162 |
+
animation: spin 1s linear infinite;
|
| 163 |
+
display: none;
|
| 164 |
+
position: absolute;
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
@keyframes spin { 100% { transform: rotate(360deg); } }
|
| 168 |
+
|
| 169 |
+
.keypoint-marker {
|
| 170 |
+
fill: var(--primary);
|
| 171 |
+
stroke: #fff;
|
| 172 |
+
stroke-width: 2px;
|
| 173 |
+
filter: drop-shadow(0 0 5px var(--primary));
|
| 174 |
+
animation: pulse 1.5s infinite;
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
@keyframes pulse {
|
| 178 |
+
0% { r: 5; opacity: 1; }
|
| 179 |
+
50% { r: 8; opacity: 0.7; }
|
| 180 |
+
100% { r: 5; opacity: 1; }
|
| 181 |
+
}
|
| 182 |
+
</style>
|
| 183 |
+
</head>
|
| 184 |
+
<body>
|
| 185 |
+
<div class="bg-glow"></div>
|
| 186 |
+
<header>
|
| 187 |
+
<h1>DARTVISION <span style="font-weight: 300; opacity: 0.5;">AI</span></h1>
|
| 188 |
+
<p class="subtitle">Intelligent Labelling & Scoring Dashboard</p>
|
| 189 |
+
</header>
|
| 190 |
+
|
| 191 |
+
<main>
|
| 192 |
+
<section class="visual-area">
|
| 193 |
+
<div class="upload-container" id="drop-zone">
|
| 194 |
+
<div id="initial-state">
|
| 195 |
+
<div class="upload-icon">🎯</div>
|
| 196 |
+
<div class="upload-text">Drag & Drop Image</div>
|
| 197 |
+
<div class="upload-subtext">Calibration and scoring results will appear here</div>
|
| 198 |
+
</div>
|
| 199 |
+
<div class="preview-wrapper" id="preview-wrapper">
|
| 200 |
+
<img id="preview-img">
|
| 201 |
+
<svg id="overlay-svg"></svg>
|
| 202 |
+
</div>
|
| 203 |
+
<div class="loading-spinner" id="spinner"></div>
|
| 204 |
+
</div>
|
| 205 |
+
<input type="file" id="file-input" style="display: none;" accept="image/*">
|
| 206 |
+
</section>
|
| 207 |
+
|
| 208 |
+
<aside class="stats-panel">
|
| 209 |
+
<div class="stat-card">
|
| 210 |
+
<div class="stat-label">Model Status</div>
|
| 211 |
+
<div class="stat-value" id="status-text" style="font-size: 1.2rem; color: var(--primary);">System Ready</div>
|
| 212 |
+
</div>
|
| 213 |
+
<div class="stat-card">
|
| 214 |
+
<div class="stat-label">AI Confidence</div>
|
| 215 |
+
<div class="stat-value" id="conf-val">0.0%</div>
|
| 216 |
+
<div class="conf-bar"><div class="conf-fill" id="conf-fill"></div></div>
|
| 217 |
+
</div>
|
| 218 |
+
<div class="stat-card">
|
| 219 |
+
<div class="stat-label">Detection Result</div>
|
| 220 |
+
<div class="stat-value" id="result-text" style="font-size: 1.1rem; opacity: 0.8;">No Image Uploaded</div>
|
| 221 |
+
</div>
|
| 222 |
+
</aside>
|
| 223 |
+
</main>
|
| 224 |
+
|
| 225 |
+
<script>
|
| 226 |
+
// --- DOM element handles (shared by the rest of the script) ---
const $id = (id) => document.getElementById(id);

const dropZone = $id('drop-zone');
const fileInput = $id('file-input');
const previewImg = $id('preview-img');
const previewWrapper = $id('preview-wrapper');
const initialState = $id('initial-state');
const svgOverlay = $id('overlay-svg');
const spinner = $id('spinner');

// Clicking anywhere in the drop zone opens the native file picker.
dropZone.onclick = () => {
    fileInput.click();
};
fileInput.onchange = (event) => {
    handleFile(event.target.files[0]);
};

// Drag-and-drop wiring: highlight while hovering, process the file on drop.
dropZone.ondragover = (event) => {
    event.preventDefault();
    dropZone.classList.add('drag-over');
};
dropZone.ondragleave = () => {
    dropZone.classList.remove('drag-over');
};
dropZone.ondrop = (event) => {
    event.preventDefault();
    dropZone.classList.remove('drag-over');
    handleFile(event.dataTransfer.files[0]);
};
|
| 240 |
+
|
| 241 |
+
/**
 * Validate, preview, and score an uploaded image.
 *
 * Shows a local preview immediately via FileReader, then POSTs the file
 * to /api/predict and hands the JSON payload to updateUI/drawKeypoints.
 *
 * @param {File|undefined} file - File from the picker or a drop event.
 */
async function handleFile(file) {
    // Ignore empty drops and non-image files.
    if (!file || !file.type.startsWith('image/')) return;

    // Render a local preview straight away, before the server replies.
    const reader = new FileReader();
    reader.onload = (e) => {
        previewImg.src = e.target.result;
        previewWrapper.style.display = 'flex';
        initialState.style.display = 'none';
        clearKeypoints();
    };
    reader.readAsDataURL(file);

    spinner.style.display = 'block';
    const formData = new FormData();
    formData.append('image', file);

    try {
        const response = await fetch('/api/predict', { method: 'POST', body: formData });
        // FIX: surface HTTP-level failures instead of parsing an error page as JSON.
        if (!response.ok) {
            throw new Error(`Server responded with ${response.status}`);
        }
        const data = await response.json();

        if (data.status === 'success') {
            updateUI(data);
            drawKeypoints(data.keypoints);
        } else {
            // FIX: a non-success payload previously left the UI silently stale.
            document.getElementById('status-text').innerText = 'Error';
            document.getElementById('result-text').innerText = data.message || 'Prediction failed';
        }
    } catch (err) {
        // FIX: don't swallow the error — log it for debugging.
        console.error('Prediction request failed:', err);
        document.getElementById('status-text').innerText = 'Error';
    } finally {
        // Always stop the spinner, whichever path we took.
        spinner.style.display = 'none';
    }
}
|
| 271 |
+
|
| 272 |
+
/**
 * Populate the stats panel from a successful prediction payload.
 *
 * Renders the confidence gauge and a per-keypoint listing. Keypoints arrive
 * as a flat [x0, y0, x1, y1, ...] array; indices 0-3 are calibration points,
 * 4+ are darts. Dart scores (when present) come from `data.scores`, one entry
 * per dart in order.
 *
 * @param {{confidence: number, message: string, keypoints?: number[], scores?: Array}} data
 */
function updateUI(data) {
    const conf = (data.confidence * 100).toFixed(1);
    document.getElementById('conf-val').innerText = `${conf}%`;
    document.getElementById('conf-fill').style.width = `${conf}%`;

    let resultHtml = `<div style="margin-top: 1rem; font-size: 0.9rem; line-height: 1.5;">${data.message}</div>`;
    if (data.keypoints && data.keypoints.length >= 8) {
        const names = ["Cal 1", "Cal 2", "Cal 3", "Cal 4", "Dart"];
        resultHtml += `<div style="font-size: 0.8rem; opacity: 0.6; margin-top: 5px;">Results & Mapping:</div>`;
        for (let i = 0; i < data.keypoints.length; i += 2) {
            const classIdx = i / 2;
            // classIdx is always an integer (i is even), so no Math.floor needed.
            const name = names[classIdx] || `Dart ${classIdx - 3}`;
            const x = data.keypoints[i].toFixed(3);
            const y = data.keypoints[i + 1].toFixed(3);

            // Attach a score badge if this keypoint is a dart.
            // FIX: use a nullish check so legitimate falsy scores (e.g. 0)
            // are still displayed; the old truthiness test dropped them.
            let scoreLabel = "";
            if (classIdx >= 4 && data.scores && data.scores[classIdx - 4] != null) {
                scoreLabel = ` <span style="background: #ff4d4d; color: white; padding: 2px 6px; border-radius: 4px; font-weight: 800; margin-left: 10px;">${data.scores[classIdx - 4]}</span>`;
            }

            resultHtml += `<div style="font-size: 0.85rem; padding: 6px 0; border-bottom: 1px solid rgba(255,255,255,0.05); display: flex; align-items: center; justify-content: space-between;">
                <span><span style="color: ${i < 8 ? '#00ff88' : '#ff4d4d'}">${name}</span>: [${x}, ${y}]</span>
                ${scoreLabel}
            </div>`;
        }
    }
    document.getElementById('result-text').innerHTML = resultHtml;
}
|
| 301 |
+
|
| 302 |
+
// Remove every rendered keypoint marker/label from the SVG overlay.
function clearKeypoints() {
    svgOverlay.replaceChildren();
}
|
| 305 |
+
|
| 306 |
+
/**
 * Render keypoint markers and text labels on the SVG overlay.
 *
 * Coordinates arrive normalised to [0, 1] as a flat [x0, y0, x1, y1, ...]
 * array. They are scaled to the displayed image size and offset into the
 * wrapper's coordinate space so markers line up with the <img> element.
 *
 * @param {number[]} pts - Flat array of normalised (x, y) pairs.
 */
function drawKeypoints(pts) {
    clearKeypoints();
    if (!pts || pts.length === 0) return;

    // Defer one tick so the freshly-set <img> has laid out and
    // getBoundingClientRect() reports its real on-screen dimensions.
    setTimeout(() => {
        const imgRect = previewImg.getBoundingClientRect();
        const wrapRect = previewWrapper.getBoundingClientRect();

        // Translate image-relative coordinates into wrapper space.
        const offsetX = imgRect.left - wrapRect.left;
        const offsetY = imgRect.top - wrapRect.top;

        const classNames = ["Cal 1", "Cal 2", "Cal 3", "Cal 4", "Dart"];
        const SVG_NS = "http://www.w3.org/2000/svg";

        for (let i = 0; i < pts.length; i += 2) {
            const classIdx = i / 2;
            const cx = pts[i] * imgRect.width + offsetX;
            const cy = pts[i + 1] * imgRect.height + offsetY;
            const labelText = classNames[classIdx] || `Dart ${Math.floor(classIdx - 3)}`;
            const isCalibration = classIdx < 4;

            const marker = document.createElementNS(SVG_NS, "circle");
            marker.setAttribute("cx", cx);
            marker.setAttribute("cy", cy);
            marker.setAttribute("r", 6);
            marker.setAttribute("class", "keypoint-marker");
            if (!isCalibration) marker.style.fill = "#ff4d4d"; // Red for dart

            const label = document.createElementNS(SVG_NS, "text");
            label.setAttribute("x", cx + 10);
            label.setAttribute("y", cy - 10);
            label.setAttribute("fill", isCalibration ? "#00ff88" : "#ff4d4d");
            label.setAttribute("font-size", "14px");
            label.setAttribute("font-weight", "600");
            label.textContent = labelText;

            svgOverlay.appendChild(marker);
            svgOverlay.appendChild(label);
        }
    }, 50);
}
|
| 348 |
+
</script>
|
| 349 |
+
</body>
|
| 350 |
+
</html>
|