kapil commited on
Commit
a6578e7
·
1 Parent(s): 9874885

update the code and md file

Browse files
Files changed (8) hide show
  1. README.md +119 -0
  2. src/data.rs +6 -2
  3. src/inference.rs +28 -11
  4. src/loss.rs +2 -1
  5. src/main.rs +1 -1
  6. src/model.rs +16 -16
  7. src/scoring.rs +20 -15
  8. src/server.rs +244 -51
README.md ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # RustAutoScoreEngine
2
+ ### High-Performance AI Dart Scoring Powered by Rust & Burn
3
+
4
+ <div align="center">
5
+
6
+ [![Rust](https://img.shields.io/badge/Rust-1.75%2B-orange?style=for-the-badge&logo=rust)](https://www.rust-lang.org/)
7
+ [![Burn](https://img.shields.io/badge/Burn-AI--Framework-red?style=for-the-badge)](https://burn.dev/)
8
+ [![WGPU](https://img.shields.io/badge/Backend-WGPU%20/%20Cuda-blue?style=for-the-badge)](https://github.com/gfx-rs/wgpu)
9
+ [![License](https://img.shields.io/badge/License-MIT-purple?style=for-the-badge)](LICENSE)
10
+
11
+ **A professional-grade, real-time dart scoring engine built entirely in Rust.**
12
+
13
+ Using the **Burn Deep Learning Framework**, this project achieves sub-millisecond inference and high-precision keypoint detection for automatic dart game tracking. The model optimization pipeline is built using modern Rust patterns for maximum safety and performance.
14
+
15
+ </div>
16
+
17
+ ---
18
+
19
+ ## Features
20
+
21
+ - **Optimized Inference**: Powered by Rust & WGPU for hardware-accelerated performance on Windows, Linux, and macOS.
22
+ - **Multi-Scale Keypoint Detection**: Enhanced YOLO-style heads for detecting dart tips and calibration corners.
23
+ - **BDO Logic Integrated**: Real-time sector calculation based on official board geometry and calibration symmetry.
24
+ - **Modern Web Dashboard**: Axum-based visual interface to monitor detections, scores, and latency in real-time.
25
+ - **Robust Calibration**: Automatic symmetry estimation to recover missing calibration points.
26
+
27
+ ---
28
+
29
+ ## Dataset and Preparation
30
+ The model is trained on the primary dataset used for high-precision dart detection.
31
+
32
+ - **Download Link**: [Dataset Resources (Google Drive)](https://drive.google.com/file/d/1ZEvuzg9zYbPd1FdZgV6v1aT4sqbqmLqp/view?usp=sharing)
33
+ - **Resolution**: 800x800 pre-cropped high-resolution images.
34
+ - **Structure**: Organize your data in the `dataset/800/` directory following the provided `labels.json` schema.
35
+
36
+ ---
37
+
38
+ ## Installation
39
+
40
+ ### 1. Install Rust
41
+ If you do not have Rust installed, use the official installation script:
42
+
43
+ ```bash
44
+ # Official Installation
45
+ curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
46
+ ```
47
+
48
+ ### 2. Clone and Build
49
+ ```bash
50
+ git clone https://github.com/iambhabha/RustAutoScoreEngine.git
51
+ cd RustAutoScoreEngine
52
+ cargo build --release
53
+ ```
54
+
55
+ ---
56
+
57
+ ## Quick Start Guide
58
+
59
+ ### Step 1: Training the AI Model
60
+ To optimize the neural network for your local environment, run the training mode:
61
+ ```bash
62
+ # Starts the training cycle (Configured for 50 Epochs)
63
+ cargo run
64
+ ```
65
+ *Tip: Allow the loss to converge below 0.05 for optimal results.*
66
+
67
+ ### Step 2: Running the Dashboard
68
+ After training is complete (generating the `model_weights.bin` file), launch the testing interface:
69
+ ```bash
70
+ # Starts the Axum web server
71
+ cargo run -- gui
72
+ ```
73
+ **Features:**
74
+ - **Image Upload**: Test local image samples via the dashboard.
75
+ - **Point Visualization**: Inspect detected calibration points and dart locations.
76
+ - **Automatic Scoring**: Instant sector calculation and latency reporting.
77
+
78
+ ---
79
+
80
+ ## Mobile Deployment
81
+
82
+ This engine is built on Burn, supporting multiple paths for Android and iOS integration:
83
+
84
+ ### Path A: Native Rust
85
+ Package the engine as a library for direct hardware-accelerated execution on mobile targets.
86
+ - **Backend**: burn-wgpu with Vulkan (Android) or Metal (iOS).
87
+ - **Integration**: JNI (Android) or FFI (iOS) calls from native code.
88
+
89
+ ### Path B: Weight Migration to TFLite/ONNX
90
+ - **TFLite**: Use the companion export scripts to generate a TensorFlow Lite bundle.
91
+ - **ONNX**: Utilize ONNX Runtime (ORT) for high-performance cross-platform execution.
92
+
93
+ ---
94
+
95
+ ## Technical Status and Contributing
96
+
97
+ > [!IMPORTANT]
98
+ > This project is currently in the experimental phase. We are actively refining the coordinate regression logic to ensure maximum precision across diverse board angles.
99
+
100
+ **Current Priorities:**
101
+ - Enhancing offset regression stability.
102
+ - Memory optimization for low-VRAM devices.
103
+
104
+ **Contribution Guidelines:**
105
+ If you encounter a bug or wish to provide performance optimizations, please submit a Pull Request.
106
+
107
+ ---
108
+
109
+ ## Resources
110
+
111
+ - **Original Inspiration**: [Paper: Keypoints as Objects for Automatic Scorekeeping](https://arxiv.org/abs/2105.09880)
112
+ - **Model Training Resources**: [Download from Google Drive](https://drive.google.com/file/d/1ZEvuzg9zYbPd1FdZgV6v1aT4sqbqmLqp/view?usp=sharing)
113
+ - **Official Documentation Reference**: [IEEE Dataport Dataset](https://ieee-dataport.org/open-access/deepdarts-dataset)
114
+
115
+ ---
116
+
117
+ <div align="center">
118
+ Made by the Rust AI Community
119
+ </div>
src/data.rs CHANGED
@@ -90,12 +90,16 @@ impl<B: Backend> DartBatcher<B> {
90
  let gx = (p[0] * grid_size as f32).floor().clamp(0.0, (grid_size - 1) as f32) as usize;
91
  let gy = (p[1] * grid_size as f32).floor().clamp(0.0, (grid_size - 1) as f32) as usize;
92
 
 
 
 
 
93
  let cls = if i < 4 { i + 1 } else { 0 };
94
  let base_idx = (b_idx * num_channels * grid_size * grid_size) + (gy * grid_size) + gx;
95
 
96
  // TF order: [x,y,w,h,obj,p0..p4]
97
- target_raw[base_idx + 0 * grid_size * grid_size] = p[0]; // X
98
- target_raw[base_idx + 1 * grid_size * grid_size] = p[1]; // Y
99
  target_raw[base_idx + 2 * grid_size * grid_size] = 0.05; // W
100
  target_raw[base_idx + 3 * grid_size * grid_size] = 0.05; // H
101
  target_raw[base_idx + 4 * grid_size * grid_size] = 1.0; // Objectness (conf)
 
90
  let gx = (p[0] * grid_size as f32).floor().clamp(0.0, (grid_size - 1) as f32) as usize;
91
  let gy = (p[1] * grid_size as f32).floor().clamp(0.0, (grid_size - 1) as f32) as usize;
92
 
93
+ // Use Grid-Relative Coordinates (Relative to cell top-left)
94
+ let tx = p[0] * grid_size as f32 - gx as f32;
95
+ let ty = p[1] * grid_size as f32 - gy as f32;
96
+
97
  let cls = if i < 4 { i + 1 } else { 0 };
98
  let base_idx = (b_idx * num_channels * grid_size * grid_size) + (gy * grid_size) + gx;
99
 
100
  // TF order: [x,y,w,h,obj,p0..p4]
101
+ target_raw[base_idx + 0 * grid_size * grid_size] = tx; // X (offset in cell)
102
+ target_raw[base_idx + 1 * grid_size * grid_size] = ty; // Y (offset in cell)
103
  target_raw[base_idx + 2 * grid_size * grid_size] = 0.05; // W
104
  target_raw[base_idx + 3 * grid_size * grid_size] = 0.05; // H
105
  target_raw[base_idx + 4 * grid_size * grid_size] = 1.0; // Objectness (conf)
src/inference.rs CHANGED
@@ -1,15 +1,15 @@
 
1
  use burn::module::Module;
2
  use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
3
  use burn::tensor::backend::Backend;
4
  use burn::tensor::{Tensor, TensorData};
5
- use crate::model::DartVisionModel;
6
- use image::{GenericImageView, DynamicImage};
7
 
8
  pub fn run_inference<B: Backend>(device: &B::Device, image_path: &str) {
9
  println!("🔍 Loading model for inference...");
10
  let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
11
  let model: DartVisionModel<B> = DartVisionModel::new(device);
12
-
13
  // Load weights
14
  let record = Recorder::load(&recorder, "model_weights".into(), device)
15
  .expect("Failed to load weights. Make sure model_weights.bin exists.");
@@ -18,12 +18,22 @@ pub fn run_inference<B: Backend>(device: &B::Device, image_path: &str) {
18
  println!("🖼️ Processing image: {}...", image_path);
19
  let img = image::open(image_path).expect("Failed to open image");
20
  let resized = img.resize_exact(800, 800, image::imageops::FilterType::Triangle);
21
- let pixels: Vec<f32> = resized.to_rgb8().pixels()
22
- .flat_map(|p| vec![p[0] as f32 / 255.0, p[1] as f32 / 255.0, p[2] as f32 / 255.0])
 
 
 
 
 
 
 
 
23
  .collect();
24
 
25
  let data = TensorData::new(pixels, [800, 800, 3]);
26
- let input = Tensor::<B, 3>::from_data(data, device).unsqueeze::<4>().permute([0, 3, 1, 2]);
 
 
27
 
28
  println!("🚀 Running MODEL Prediction...");
29
  let (out16, _out32) = model.forward(input);
@@ -31,17 +41,24 @@ pub fn run_inference<B: Backend>(device: &B::Device, image_path: &str) {
31
  // Post-process out16 (size [1, 30, 100, 100])
32
  // Decode objectness part (Channel 4 for Anchor 0)
33
  let obj = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 4, 1));
34
-
35
  // Find highest confidence cell
36
  let (max_val, _) = obj.reshape([1, 10000]).max_dim_with_indices(1);
37
- let confidence: f32 = max_val.to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];
38
-
 
 
 
 
39
  println!("--------------------------------------------------");
40
  println!("📊 RESULTS FOR: {}", image_path);
41
  println!("✨ Max Objectness: {:.2}%", confidence * 100.0);
42
-
43
  if confidence > 0.05 {
44
- println!("✅ Model found something! Confidence Score: {:.4}", confidence);
 
 
 
45
  } else {
46
  println!("⚠️ Model confidence is too low. Training incomplete?");
47
  }
 
1
+ use crate::model::DartVisionModel;
2
  use burn::module::Module;
3
  use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
4
  use burn::tensor::backend::Backend;
5
  use burn::tensor::{Tensor, TensorData};
6
+ use image::{DynamicImage, GenericImageView};
 
7
 
8
  pub fn run_inference<B: Backend>(device: &B::Device, image_path: &str) {
9
  println!("🔍 Loading model for inference...");
10
  let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
11
  let model: DartVisionModel<B> = DartVisionModel::new(device);
12
+
13
  // Load weights
14
  let record = Recorder::load(&recorder, "model_weights".into(), device)
15
  .expect("Failed to load weights. Make sure model_weights.bin exists.");
 
18
  println!("🖼️ Processing image: {}...", image_path);
19
  let img = image::open(image_path).expect("Failed to open image");
20
  let resized = img.resize_exact(800, 800, image::imageops::FilterType::Triangle);
21
+ let pixels: Vec<f32> = resized
22
+ .to_rgb8()
23
+ .pixels()
24
+ .flat_map(|p| {
25
+ vec![
26
+ p[0] as f32 / 255.0,
27
+ p[1] as f32 / 255.0,
28
+ p[2] as f32 / 255.0,
29
+ ]
30
+ })
31
  .collect();
32
 
33
  let data = TensorData::new(pixels, [800, 800, 3]);
34
+ let input = Tensor::<B, 3>::from_data(data, device)
35
+ .unsqueeze::<4>()
36
+ .permute([0, 3, 1, 2]);
37
 
38
  println!("🚀 Running MODEL Prediction...");
39
  let (out16, _out32) = model.forward(input);
 
41
  // Post-process out16 (size [1, 30, 100, 100])
42
  // Decode objectness part (Channel 4 for Anchor 0)
43
  let obj = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 4, 1));
44
+
45
  // Find highest confidence cell
46
  let (max_val, _) = obj.reshape([1, 10000]).max_dim_with_indices(1);
47
+ let confidence: f32 = max_val
48
+ .to_data()
49
+ .convert::<f32>()
50
+ .as_slice::<f32>()
51
+ .unwrap()[0];
52
+
53
  println!("--------------------------------------------------");
54
  println!("📊 RESULTS FOR: {}", image_path);
55
  println!("✨ Max Objectness: {:.2}%", confidence * 100.0);
56
+
57
  if confidence > 0.05 {
58
+ println!(
59
+ "✅ Model found something! Confidence Score: {:.4}",
60
+ confidence
61
+ );
62
  } else {
63
  println!("⚠️ Model confidence is too low. Training incomplete?");
64
  }
src/loss.rs CHANGED
@@ -32,9 +32,10 @@ pub fn diou_loss<B: Backend>(
32
  .mul_scalar(5.0); // Boost class learning
33
 
34
  // 4. Coordinates (Channels 0-3) - Only learn when object exists
 
35
  let b_xy_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 0, 2));
36
  let b_xy_target = t.clone().narrow(2, 0, 2);
37
- let xy_loss = b_xy_pred.sub(b_xy_target).powf_scalar(2.0).mul(obj_target.clone()).mean().mul_scalar(5.0);
38
 
39
  let b_wh_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 2, 2));
40
  let b_wh_target = t.clone().narrow(2, 2, 2);
 
32
  .mul_scalar(5.0); // Boost class learning
33
 
34
  // 4. Coordinates (Channels 0-3) - Only learn when object exists
35
+ // 2. Coordinate Loss (MSE on relative offsets) - Weighted x10 for precision
36
  let b_xy_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 0, 2));
37
  let b_xy_target = t.clone().narrow(2, 0, 2);
38
+ let xy_loss = b_xy_pred.sub(b_xy_target).powf_scalar(2.0).mul(obj_target.clone()).mean().mul_scalar(10.0);
39
 
40
  let b_wh_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 2, 2));
41
  let b_wh_target = t.clone().narrow(2, 2, 2);
src/main.rs CHANGED
@@ -32,7 +32,7 @@ fn main() {
32
  let dataset_path = "dataset/labels.json";
33
 
34
  let config = TrainingConfig {
35
- num_epochs: 10,
36
  batch_size: 1,
37
  lr: 1e-3,
38
  };
 
32
  let dataset_path = "dataset/labels.json";
33
 
34
  let config = TrainingConfig {
35
+ num_epochs: 50,
36
  batch_size: 1,
37
  lr: 1e-3,
38
  };
src/model.rs CHANGED
@@ -30,40 +30,40 @@ impl<B: Backend> ConvBlock<B> {
30
 
31
  #[derive(Module, Debug)]
32
  pub struct DartVisionModel<B: Backend> {
33
- // Lean architecture: High resolution (800x800) but low channel count to fix GPU OOM
34
- l1: ConvBlock<B>, // 3 -> 16
35
  p1: MaxPool2d,
36
- l2: ConvBlock<B>, // 16 -> 16
37
  p2: MaxPool2d,
38
- l3: ConvBlock<B>, // 16 -> 32
39
  p3: MaxPool2d,
40
- l4: ConvBlock<B>, // 32 -> 32
41
  p4: MaxPool2d,
42
- l5: ConvBlock<B>, // 32 -> 64
43
- l6: ConvBlock<B>, // 64 -> 64
44
 
45
- head_32: Conv2d<B>, // Final detection head
46
  }
47
 
48
  impl<B: Backend> DartVisionModel<B> {
49
  pub fn new(device: &B::Device) -> Self {
50
- let l1 = ConvBlock::new(3, 16, [3, 3], device);
51
  let p1 = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init();
52
 
53
- let l2 = ConvBlock::new(16, 16, [3, 3], device);
54
  let p2 = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init();
55
 
56
- let l3 = ConvBlock::new(16, 32, [3, 3], device);
57
  let p3 = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init();
58
 
59
- let l4 = ConvBlock::new(32, 32, [3, 3], device);
60
  let p4 = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init();
61
 
62
- let l5 = ConvBlock::new(32, 64, [3, 3], device);
63
- let l6 = ConvBlock::new(64, 64, [3, 3], device);
64
 
65
- // 30 channels = 3 anchors * (x,y,w,h,obj,p0...p4)
66
- let head_32 = Conv2dConfig::new([64, 30], [1, 1]).init(device);
67
 
68
  Self { l1, p1, l2, p2, l3, p3, l4, p4, l5, l6, head_32 }
69
  }
 
30
 
31
  #[derive(Module, Debug)]
32
  pub struct DartVisionModel<B: Backend> {
33
+ // Increased capacity: High resolution but enough width to map complex features
34
+ l1: ConvBlock<B>, // 3 -> 32
35
  p1: MaxPool2d,
36
+ l2: ConvBlock<B>, // 32 -> 32
37
  p2: MaxPool2d,
38
+ l3: ConvBlock<B>, // 32 -> 64
39
  p3: MaxPool2d,
40
+ l4: ConvBlock<B>, // 64 -> 64
41
  p4: MaxPool2d,
42
+ l5: ConvBlock<B>, // 64 -> 128
43
+ l6: ConvBlock<B>, // 128 -> 128
44
 
45
+ head_32: Conv2d<B>, // Final detection head (30 channels for 3 anchors)
46
  }
47
 
48
  impl<B: Backend> DartVisionModel<B> {
49
  pub fn new(device: &B::Device) -> Self {
50
+ let l1 = ConvBlock::new(3, 32, [3, 3], device);
51
  let p1 = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init();
52
 
53
+ let l2 = ConvBlock::new(32, 32, [3, 3], device);
54
  let p2 = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init();
55
 
56
+ let l3 = ConvBlock::new(32, 64, [3, 3], device);
57
  let p3 = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init();
58
 
59
+ let l4 = ConvBlock::new(64, 64, [3, 3], device);
60
  let p4 = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init();
61
 
62
+ let l5 = ConvBlock::new(64, 128, [3, 3], device);
63
+ let l6 = ConvBlock::new(128, 128, [3, 3], device);
64
 
65
+ // 30 channels = 3 anchors * (x,y,w,h,obj,dart,cal1,cal2,cal3,cal4)
66
+ let head_32 = Conv2dConfig::new([128, 30], [1, 1]).init(device);
67
 
68
  Self { l1, p1, l2, p2, l3, p3, l4, p4, l5, l6, head_32 }
69
  }
src/scoring.rs CHANGED
@@ -23,44 +23,49 @@ impl Default for ScoringConfig {
23
  pub fn get_board_dict() -> HashMap<i32, &'static str> {
24
  let mut m = HashMap::new();
25
  // BDO standard mapping based on degrees
26
- let slices = ["13", "4", "18", "1", "20", "5", "12", "9", "14", "11", "8", "16", "7", "19", "3", "17", "2", "15", "10", "6"];
 
 
 
27
  for (i, &s) in slices.iter().enumerate() {
28
  m.insert(i as i32, s);
29
  }
30
  m
31
  }
32
 
33
- pub fn calculate_dart_score(cal_pts: &[[f32; 2]], dart_pt: &[f32; 2], config: &ScoringConfig) -> (i32, String) {
 
 
 
 
34
  // 1. Calculate Center (Average of 4 calibration points)
35
  let cx = cal_pts.iter().map(|p| p[0]).sum::<f32>() / 4.0;
36
  let cy = cal_pts.iter().map(|p| p[1]).sum::<f32>() / 4.0;
37
 
38
  // 2. Calculate average radius to boundary (doubles wire)
39
- let avg_r_px = cal_pts.iter()
 
40
  .map(|p| ((p[0] - cx).powi(2) + (p[1] - cy).powi(2)).sqrt())
41
- .sum::<f32>() / 4.0;
 
42
 
43
  // 3. Relative distance of dart from center
44
  let dx = dart_pt[0] - cx;
45
  let dy = dart_pt[1] - cy;
46
  let dist_px = (dx.powi(2) + dy.powi(2)).sqrt();
47
-
48
  // Scale distance relative to BDO double radius
49
  let dist_scaled = (dist_px / avg_r_px) * config.r_double;
50
 
51
  // 4. Calculate Angle (0 is 3 o'clock, CCW)
52
  let mut angle_deg = (-dy).atan2(dx).to_degrees();
53
- if angle_deg < 0.0 { angle_deg += 360.0; }
54
-
55
- // Board is rotated such that 20 is at top (90 deg)
56
- // Sector width is 18 deg. Sector 20 is centered at 90 deg.
57
- // 90 deg is index 4 in slices (13, 4, 18, 1, 20...)
58
- // Each index is 18 deg. Offset = 4 * 18 = 72? No.
59
- // Let's use the standard mapping: (angle / 18)
60
- // Wait, the BOARD_DICT in Python uses int(angle / 18) where angle is 0-360.
61
- // We need to match the slice orientation.
62
  let board_dict = get_board_dict();
63
- let sector_idx = ((angle_deg / 18.0).floor() as i32) % 20;
64
  let sector_num = board_dict.get(&sector_idx).unwrap_or(&"0");
65
 
66
  // 5. Determine multipliers based on scaled distance
 
23
  pub fn get_board_dict() -> HashMap<i32, &'static str> {
24
  let mut m = HashMap::new();
25
  // BDO standard mapping based on degrees
26
+ let slices = [
27
+ "6", "13", "4", "18", "1", "20", "5", "12", "9", "14", "11", "8", "16", "7", "19", "3",
28
+ "17", "2", "15", "10",
29
+ ];
30
  for (i, &s) in slices.iter().enumerate() {
31
  m.insert(i as i32, s);
32
  }
33
  m
34
  }
35
 
36
+ pub fn calculate_dart_score(
37
+ cal_pts: &[[f32; 2]],
38
+ dart_pt: &[f32; 2],
39
+ config: &ScoringConfig,
40
+ ) -> (i32, String) {
41
  // 1. Calculate Center (Average of 4 calibration points)
42
  let cx = cal_pts.iter().map(|p| p[0]).sum::<f32>() / 4.0;
43
  let cy = cal_pts.iter().map(|p| p[1]).sum::<f32>() / 4.0;
44
 
45
  // 2. Calculate average radius to boundary (doubles wire)
46
+ let avg_r_px = cal_pts
47
+ .iter()
48
  .map(|p| ((p[0] - cx).powi(2) + (p[1] - cy).powi(2)).sqrt())
49
+ .sum::<f32>()
50
+ / 4.0;
51
 
52
  // 3. Relative distance of dart from center
53
  let dx = dart_pt[0] - cx;
54
  let dy = dart_pt[1] - cy;
55
  let dist_px = (dx.powi(2) + dy.powi(2)).sqrt();
56
+
57
  // Scale distance relative to BDO double radius
58
  let dist_scaled = (dist_px / avg_r_px) * config.r_double;
59
 
60
  // 4. Calculate Angle (0 is 3 o'clock, CCW)
61
  let mut angle_deg = (-dy).atan2(dx).to_degrees();
62
+ if angle_deg < 0.0 {
63
+ angle_deg += 360.0;
64
+ }
65
+
66
+ // Center sectors by adding 9 degrees (half-sector width)
 
 
 
 
67
  let board_dict = get_board_dict();
68
+ let sector_idx = (((angle_deg + 9.0) / 18.0).floor() as i32) % 20;
69
  let sector_num = board_dict.get(&sector_idx).unwrap_or(&"0");
70
 
71
  // 5. Determine multipliers based on scaled distance
src/server.rs CHANGED
@@ -1,14 +1,14 @@
 
1
  use axum::{
2
  extract::{DefaultBodyLimit, Multipart, State},
3
  response::{Html, Json},
4
  routing::{get, post},
5
  Router,
6
  };
 
 
7
  use burn::prelude::*;
8
  use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
9
- use burn::backend::Wgpu;
10
- use burn::backend::wgpu::WgpuDevice;
11
- use crate::model::DartVisionModel;
12
  use serde_json::json;
13
  use std::net::SocketAddr;
14
  use std::sync::Arc;
@@ -47,64 +47,235 @@ pub async fn start_gui(device: WgpuDevice) {
47
  let model = model.load_record(record);
48
 
49
  while let Some(req) = rx.blocking_recv() {
 
 
50
  let img = image::load_from_memory(&req.image_bytes).unwrap();
51
  let resized = img.resize_exact(416, 416, image::imageops::FilterType::Triangle);
52
- let pixels: Vec<f32> = resized.to_rgb8().pixels()
53
- .flat_map(|p| vec![p[0] as f32 / 255.0, p[1] as f32 / 255.0, p[2] as f32 / 255.0])
 
 
 
 
 
 
 
 
54
  .collect();
55
 
56
  let tensor_data = TensorData::new(pixels, [1, 416, 416, 3]);
57
- let input = Tensor::<Wgpu, 4>::from_data(tensor_data, &worker_device).permute([0, 3, 1, 2]);
 
58
 
59
  let (out16, _) = model.forward(input);
60
-
 
 
 
 
 
 
 
 
 
 
61
  let mut final_points = vec![0.0f32; 8]; // 4 corners
62
  let mut max_conf = 0.0f32;
63
 
64
- // 1. Extract Objectness with Sigmoid
65
- let obj = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 4, 1));
66
-
67
- // 2. Extract best calibration corner for each class 1 to 4 (Grid 26x26 = 676)
68
  for cls_idx in 1..=4 {
69
- let prob = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 5 + cls_idx, 1));
70
- let score = obj.clone().mul(prob);
71
- let (val, idx) = score.reshape([1, 676]).max_dim_with_indices(1);
72
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  let s = val.to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];
74
  let f_idx = idx.to_data().convert::<i32>().as_slice::<i32>().unwrap()[0] as usize;
75
- let gy = f_idx / 26;
76
  let gx = f_idx % 26;
 
77
 
78
- // Use Sigmoid for Coordinates (matching new loss logic)
79
- let px = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 0, 1).slice([0..1, 0..1, gy..gy+1, gx..gx+1]))
80
- .to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];
81
- let py = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 1, 1).slice([0..1, 0..1, gy..gy+1, gx..gx+1]))
82
- .to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
- final_points[(cls_idx-1)*2] = px;
85
- final_points[(cls_idx-1)*2+1] = py;
86
- if s > max_conf { max_conf = s; }
 
 
 
 
 
 
 
87
  }
88
 
89
- // 3. Extract best dart (Class 0)
90
- let d_prob = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 5, 1));
91
- let d_score = obj.clone().mul(d_prob);
92
- let (d_val, d_idx) = d_score.reshape([1, 676]).max_dim_with_indices(1);
93
- let ds = d_val.to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];
94
- if ds > 0.1 {
95
- let f_idx = d_idx.to_data().convert::<i32>().as_slice::<i32>().unwrap()[0] as usize;
96
- let gy = f_idx / 26;
97
- let gx = f_idx % 26;
98
- let dx = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 0, 1).slice([0..1, 0..1, gy..gy+1, gx..gx+1]))
99
- .to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];
100
- let dy = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 1, 1).slice([0..1, 0..1, gy..gy+1, gx..gx+1]))
101
- .to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];
102
- final_points.push(dx);
103
- final_points.push(dy);
104
  }
105
 
106
  let mut final_scores = vec![];
107
-
108
  // Calculate scores if we have calibration points and at least one dart
109
  if final_points.len() >= 10 {
110
  use crate::scoring::{calculate_dart_score, ScoringConfig};
@@ -115,22 +286,32 @@ pub async fn start_gui(device: WgpuDevice) {
115
  [final_points[4], final_points[5]],
116
  [final_points[6], final_points[7]],
117
  ];
118
-
119
  for dart_chunk in final_points[8..].chunks(2) {
120
  if dart_chunk.len() == 2 {
121
  let dart_pt = [dart_chunk[0], dart_chunk[1]];
122
- let (_val, label) = calculate_dart_score(&cal_pts, &dart_pt, &config);
123
- final_scores.push(label);
 
124
  }
125
  }
126
  }
127
 
128
- println!("🎯 [Detection Result] Confidence: {:.2}%", max_conf * 100.0);
 
 
 
129
  let class_names = ["Cal1", "Cal2", "Cal3", "Cal4", "Dart"];
130
  for (i, pts) in final_points.chunks(2).enumerate() {
131
  let name = class_names.get(i).unwrap_or(&"Dart");
132
- let label = final_scores.get(i.saturating_sub(4)).cloned().unwrap_or_default();
133
- println!(" - {}: [x: {:.3}, y: {:.3}] {}", name, pts[0], pts[1], label);
 
 
 
 
 
 
134
  }
135
 
136
  let _ = req.response_tx.send(PredictResult {
@@ -144,7 +325,10 @@ pub async fn start_gui(device: WgpuDevice) {
144
  let state = Arc::new(tx);
145
 
146
  let app = Router::new()
147
- .route("/", get(|| async { Html(include_str!("../static/index.html")) }))
 
 
 
148
  .route("/api/predict", post(predict_handler))
149
  .with_state(state)
150
  .layer(DefaultBodyLimit::max(10 * 1024 * 1024))
@@ -165,17 +349,26 @@ async fn predict_handler(
165
  Err(_) => continue,
166
  };
167
  let (res_tx, res_rx) = oneshot::channel();
168
- let _ = tx.send(PredictRequest { image_bytes: bytes, response_tx: res_tx }).await;
169
- let result = res_rx.await.unwrap_or(PredictResult { confidence: 0.0, keypoints: vec![] });
 
 
 
 
 
 
 
 
 
170
 
171
  return Json(json!({
172
  "status": "success",
173
  "confidence": result.confidence,
174
  "keypoints": result.keypoints,
175
  "scores": result.scores,
176
- "message": if result.confidence > 0.1 {
177
  format!("✅ Found {} darts! High confidence: {:.1}%", result.scores.len(), result.confidence * 100.0)
178
- } else {
179
  "⚠️ Low confidence detection - no dart score could be verified.".to_string()
180
  }
181
  }));
 
1
+ use crate::model::DartVisionModel;
2
  use axum::{
3
  extract::{DefaultBodyLimit, Multipart, State},
4
  response::{Html, Json},
5
  routing::{get, post},
6
  Router,
7
  };
8
+ use burn::backend::wgpu::WgpuDevice;
9
+ use burn::backend::Wgpu;
10
  use burn::prelude::*;
11
  use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
 
 
 
12
  use serde_json::json;
13
  use std::net::SocketAddr;
14
  use std::sync::Arc;
 
47
  let model = model.load_record(record);
48
 
49
  while let Some(req) = rx.blocking_recv() {
50
+ let start_time = std::time::Instant::now();
51
+
52
  let img = image::load_from_memory(&req.image_bytes).unwrap();
53
  let resized = img.resize_exact(416, 416, image::imageops::FilterType::Triangle);
54
+ let pixels: Vec<f32> = resized
55
+ .to_rgb8()
56
+ .pixels()
57
+ .flat_map(|p| {
58
+ vec![
59
+ p[0] as f32 / 255.0,
60
+ p[1] as f32 / 255.0,
61
+ p[2] as f32 / 255.0,
62
+ ]
63
+ })
64
  .collect();
65
 
66
  let tensor_data = TensorData::new(pixels, [1, 416, 416, 3]);
67
+ let input =
68
+ Tensor::<Wgpu, 4>::from_data(tensor_data, &worker_device).permute([0, 3, 1, 2]);
69
 
70
  let (out16, _) = model.forward(input);
71
+
72
+ // 1. Reshape to separate anchors: [1, 3, 10, 26, 26]
73
+ let out_reshaped = out16.reshape([1, 3, 10, 26, 26]);
74
+
75
+ // 1.5 Debug: Raw Statistics
76
+ println!(
77
+ "🔍 [Model Stats] Raw Min: {:.4}, Max: {:.4}",
78
+ out_reshaped.clone().min().into_scalar(),
79
+ out_reshaped.clone().max().into_scalar()
80
+ );
81
+
82
  let mut final_points = vec![0.0f32; 8]; // 4 corners
83
  let mut max_conf = 0.0f32;
84
 
85
+ // 2. Extract best calibration corner for each class 1 to 4
 
 
 
86
  for cls_idx in 1..=4 {
87
+ let mut best_s = -1.0f32;
88
+ let mut best_pt = [0.0f32; 2];
89
+ let mut best_anchor = 0;
90
+ let mut best_grid = (0, 0);
91
+
92
+ for anchor in 0..3 {
93
+ let obj = burn::tensor::activation::sigmoid(
94
+ out_reshaped.clone().narrow(1, anchor, 1).narrow(2, 4, 1),
95
+ );
96
+ let prob = burn::tensor::activation::sigmoid(
97
+ out_reshaped
98
+ .clone()
99
+ .narrow(1, anchor, 1)
100
+ .narrow(2, 5 + cls_idx, 1),
101
+ );
102
+ let score = obj.mul(prob);
103
+
104
+ let (val, idx) = score.reshape([1, 676]).max_dim_with_indices(1);
105
+ let s = val.to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];
106
+ if s > best_s {
107
+ best_s = s;
108
+ best_anchor = anchor;
109
+ let f_idx =
110
+ idx.to_data().convert::<i32>().as_slice::<i32>().unwrap()[0] as usize;
111
+ best_grid = (f_idx % 26, f_idx / 26);
112
+
113
+ let sx = burn::tensor::activation::sigmoid(
114
+ out_reshaped
115
+ .clone()
116
+ .narrow(1, anchor, 1)
117
+ .narrow(2, 0, 1)
118
+ .slice([
119
+ 0..1,
120
+ 0..1,
121
+ 0..1,
122
+ best_grid.1..best_grid.1 + 1,
123
+ best_grid.0..best_grid.0 + 1,
124
+ ]),
125
+ )
126
+ .to_data()
127
+ .convert::<f32>()
128
+ .as_slice::<f32>()
129
+ .unwrap()[0];
130
+ let sy = burn::tensor::activation::sigmoid(
131
+ out_reshaped
132
+ .clone()
133
+ .narrow(1, anchor, 1)
134
+ .narrow(2, 1, 1)
135
+ .slice([
136
+ 0..1,
137
+ 0..1,
138
+ 0..1,
139
+ best_grid.1..best_grid.1 + 1,
140
+ best_grid.0..best_grid.0 + 1,
141
+ ]),
142
+ )
143
+ .to_data()
144
+ .convert::<f32>()
145
+ .as_slice::<f32>()
146
+ .unwrap()[0];
147
+
148
+ // Reconstruct Absolute Normalized Coord (0-1)
149
+ best_pt = [
150
+ (best_grid.0 as f32 + sx) / 26.0,
151
+ (best_grid.1 as f32 + sy) / 26.0,
152
+ ];
153
+ }
154
+ }
155
+
156
+ final_points[(cls_idx - 1) * 2] = best_pt[0];
157
+ final_points[(cls_idx - 1) * 2 + 1] = best_pt[1];
158
+ if best_s > max_conf {
159
+ max_conf = best_s;
160
+ }
161
+ println!(
162
+ " [Debug Cal{}] Anchor: {}, Conf: {:.4}, Cell: {:?}, Coord: [{:.3}, {:.3}]",
163
+ cls_idx, best_anchor, best_s, best_grid, best_pt[0], best_pt[1]
164
+ );
165
+ }
166
+
167
+ // 3. Calibration Estimation (Python logic: est_cal_pts)
168
+ // If one calibration point is missing, estimate it using symmetry
169
+ let mut valid_cal_count = 0;
170
+ let mut missing_idx = -1;
171
+ for i in 0..4 {
172
+ if final_points[i*2] > 0.01 || final_points[i*2+1] > 0.01 {
173
+ valid_cal_count += 1;
174
+ } else {
175
+ missing_idx = i as i32;
176
+ }
177
+ }
178
+
179
+ if valid_cal_count == 3 {
180
+ println!("⚠️ [Calibration Recovery] Estimating missing point Cal{}...", missing_idx + 1);
181
+ match missing_idx {
182
+ 0 | 1 => { // Top points missing, use bottom points center
183
+ let cx = (final_points[4] + final_points[6]) / 2.0;
184
+ let cy = (final_points[5] + final_points[7]) / 2.0;
185
+ if missing_idx == 0 {
186
+ final_points[0] = 2.0 * cx - final_points[2];
187
+ final_points[1] = 2.0 * cy - final_points[3];
188
+ } else {
189
+ final_points[2] = 2.0 * cx - final_points[0];
190
+ final_points[3] = 2.0 * cy - final_points[1];
191
+ }
192
+ },
193
+ 2 | 3 => { // Bottom points missing, use top points center
194
+ let cx = (final_points[0] + final_points[2]) / 2.0;
195
+ let cy = (final_points[1] + final_points[3]) / 2.0;
196
+ if missing_idx == 2 {
197
+ final_points[4] = 2.0 * cx - final_points[6];
198
+ final_points[5] = 2.0 * cy - final_points[7];
199
+ } else {
200
+ final_points[6] = 2.0 * cx - final_points[4];
201
+ final_points[7] = 2.0 * cy - final_points[5];
202
+ }
203
+ },
204
+ _ => {}
205
+ }
206
+ }
207
+
208
+ // 4. Extract best dart (Class 0) - Find candidates across all anchors
209
+ println!(" [Debug Dart] Searching for Candidates...");
210
+ let mut dart_candidates = vec![];
211
+ for anchor in 0..3 {
212
+ let obj = burn::tensor::activation::sigmoid(
213
+ out_reshaped.clone().narrow(1, anchor, 1).narrow(2, 4, 1),
214
+ );
215
+ let prob = burn::tensor::activation::sigmoid(
216
+ out_reshaped.clone().narrow(1, anchor, 1).narrow(2, 5, 1),
217
+ );
218
+ let score = obj.mul(prob).reshape([1, 676]);
219
+
220
+ let (val, idx) = score.max_dim_with_indices(1);
221
  let s = val.to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];
222
  let f_idx = idx.to_data().convert::<i32>().as_slice::<i32>().unwrap()[0] as usize;
223
+
224
  let gx = f_idx % 26;
225
+ let gy = f_idx / 26;
226
 
227
+ let dsx = burn::tensor::activation::sigmoid(
228
+ out_reshaped
229
+ .clone()
230
+ .narrow(1, anchor, 1)
231
+ .narrow(2, 0, 1)
232
+ .slice([0..1, 0..1, 0..1, gy..gy + 1, gx..gx + 1]),
233
+ )
234
+ .to_data()
235
+ .convert::<f32>()
236
+ .as_slice::<f32>()
237
+ .unwrap()[0];
238
+ let dsy = burn::tensor::activation::sigmoid(
239
+ out_reshaped
240
+ .clone()
241
+ .narrow(1, anchor, 1)
242
+ .narrow(2, 1, 1)
243
+ .slice([0..1, 0..1, 0..1, gy..gy + 1, gx..gx + 1]),
244
+ )
245
+ .to_data()
246
+ .convert::<f32>()
247
+ .as_slice::<f32>()
248
+ .unwrap()[0];
249
 
250
+ let dx = (gx as f32 + dsx) / 26.0;
251
+ let dy = (gy as f32 + dsy) / 26.0;
252
+
253
+ if s > 0.005 {
254
+ println!(
255
+ " - A{} Best Cell: ({},{}), Conf: {:.4}, Coord: [{:.3}, {:.3}]",
256
+ anchor, gx, gy, s, dx, dy
257
+ );
258
+ dart_candidates.push((s, [dx, dy]));
259
+ }
260
  }
261
 
262
+ // Pick the best dart candidate among all anchors
263
+ dart_candidates
264
+ .sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
265
+ if let Some((s, pt)) = dart_candidates.first() {
266
+ if *s > 0.05 {
267
+ final_points.push(pt[0]);
268
+ final_points.push(pt[1]);
269
+ println!(
270
+ " ✅ Best Dart Picked: Conf: {:.2}%, Coord: {:?}",
271
+ s * 100.0,
272
+ pt
273
+ );
274
+ }
 
 
275
  }
276
 
277
  let mut final_scores = vec![];
278
+
279
  // Calculate scores if we have calibration points and at least one dart
280
  if final_points.len() >= 10 {
281
  use crate::scoring::{calculate_dart_score, ScoringConfig};
 
286
  [final_points[4], final_points[5]],
287
  [final_points[6], final_points[7]],
288
  ];
289
+
290
  for dart_chunk in final_points[8..].chunks(2) {
291
  if dart_chunk.len() == 2 {
292
  let dart_pt = [dart_chunk[0], dart_chunk[1]];
293
+ let (s_val, label) = calculate_dart_score(&cal_pts, &dart_pt, &config);
294
+ final_scores.push(label.clone());
295
+ println!(" [Debug Score] Label: {} (Val: {})", label, s_val);
296
  }
297
  }
298
  }
299
 
300
+ let duration = start_time.elapsed();
301
+ println!("⚡ [Inference Performance] Total Latency: {:.2?}", duration);
302
+
303
+ println!("🎯 [Final Result] Top Confidence: {:.2}%", max_conf * 100.0);
304
  let class_names = ["Cal1", "Cal2", "Cal3", "Cal4", "Dart"];
305
  for (i, pts) in final_points.chunks(2).enumerate() {
306
  let name = class_names.get(i).unwrap_or(&"Dart");
307
+ let label = final_scores
308
+ .get(i.saturating_sub(4))
309
+ .cloned()
310
+ .unwrap_or_default();
311
+ println!(
312
+ " - {}: [x: {:.3}, y: {:.3}] {}",
313
+ name, pts[0], pts[1], label
314
+ );
315
  }
316
 
317
  let _ = req.response_tx.send(PredictResult {
 
325
  let state = Arc::new(tx);
326
 
327
  let app = Router::new()
328
+ .route(
329
+ "/",
330
+ get(|| async { Html(include_str!("../static/index.html")) }),
331
+ )
332
  .route("/api/predict", post(predict_handler))
333
  .with_state(state)
334
  .layer(DefaultBodyLimit::max(10 * 1024 * 1024))
 
349
  Err(_) => continue,
350
  };
351
  let (res_tx, res_rx) = oneshot::channel();
352
+ let _ = tx
353
+ .send(PredictRequest {
354
+ image_bytes: bytes,
355
+ response_tx: res_tx,
356
+ })
357
+ .await;
358
+ let result = res_rx.await.unwrap_or(PredictResult {
359
+ confidence: 0.0,
360
+ keypoints: vec![],
361
+ scores: vec![],
362
+ });
363
 
364
  return Json(json!({
365
  "status": "success",
366
  "confidence": result.confidence,
367
  "keypoints": result.keypoints,
368
  "scores": result.scores,
369
+ "message": if result.confidence > 0.1 {
370
  format!("✅ Found {} darts! High confidence: {:.1}%", result.scores.len(), result.confidence * 100.0)
371
+ } else {
372
  "⚠️ Low confidence detection - no dart score could be verified.".to_string()
373
  }
374
  }));