kapil committed
Commit · a6578e7
Parent(s): 9874885
update the code and md file

Files changed:
- README.md +119 -0
- src/data.rs +6 -2
- src/inference.rs +28 -11
- src/loss.rs +2 -1
- src/main.rs +1 -1
- src/model.rs +16 -16
- src/scoring.rs +20 -15
- src/server.rs +244 -51
README.md
ADDED
@@ -0,0 +1,119 @@
# RustAutoScoreEngine

### High-Performance AI Dart Scoring Powered by Rust & Burn

<div align="center">

[](https://www.rust-lang.org/)
[](https://burn.dev/)
[](https://github.com/gfx-rs/wgpu)
[](LICENSE)

**A professional-grade, real-time dart scoring engine built entirely in Rust.**

Using the **Burn Deep Learning Framework**, this project achieves sub-millisecond inference and high-precision keypoint detection for automatic dart game tracking. The model optimization pipeline is built using modern Rust patterns for maximum safety and performance.

</div>

---

## Features

- **Optimized Inference**: Powered by Rust & WGPU for hardware-accelerated performance on Windows, Linux, and macOS.
- **Multi-Scale Keypoint Detection**: Enhanced YOLO-style heads for detecting dart tips and calibration corners.
- **BDO Logic Integrated**: Real-time sector calculation based on official board geometry and calibration symmetry.
- **Modern Web Dashboard**: Axum-based visual interface to monitor detections, scores, and latency in real time.
- **Robust Calibration**: Automatic symmetry estimation to recover missing calibration points.

---

## Dataset and Preparation

The model is trained on the primary dataset used for high-precision dart detection.

- **Download Link**: [Dataset Resources (Google Drive)](https://drive.google.com/file/d/1ZEvuzg9zYbPd1FdZgV6v1aT4sqbqmLqp/view?usp=sharing)
- **Resolution**: 800x800 pre-cropped high-resolution images.
- **Structure**: Organize your data in the `dataset/800/` directory following the provided `labels.json` schema.

---

## Installation

### 1. Install Rust

If you do not have Rust installed, use the official installation script:

```bash
# Official installation
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
```

### 2. Clone and Build

```bash
git clone https://github.com/iambhabha/RustAutoScoreEngine.git
cd RustAutoScoreEngine
cargo build --release
```

---

## Quick Start Guide

### Step 1: Training the AI Model

To optimize the neural network for your local environment, run the training mode:

```bash
# Starts the training cycle (configured for 50 epochs)
cargo run
```

*Tip: Allow the loss to converge below 0.05 for optimal results.*

### Step 2: Running the Dashboard

After training is complete (generating the `model_weights.bin` file), launch the testing interface:

```bash
# Starts the Axum web server
cargo run -- gui
```

**Features:**

- **Image Upload**: Test local image samples via the dashboard.
- **Point Visualization**: Inspect detected calibration points and dart locations.
- **Automatic Scoring**: Instant sector calculation and latency reporting.

---

## Mobile Deployment

This engine is built on Burn, supporting multiple paths for Android and iOS integration:

### Path A: Native Rust

Package the engine as a library for direct hardware-accelerated execution on mobile targets.

- **Backend**: `burn-wgpu` with Vulkan (Android) or Metal (iOS).
- **Integration**: JNI (Android) or FFI (iOS) calls from native code.

### Path B: Weight Migration to TFLite/ONNX

- **TFLite**: Use the companion export scripts to generate a TensorFlow Lite bundle.
- **ONNX**: Utilize ONNX Runtime (ORT) for high-performance cross-platform execution.

---

## Technical Status and Contributing

> [!IMPORTANT]
> This project is currently in the experimental phase. We are actively refining the coordinate regression logic to ensure maximum precision across diverse board angles.

**Current Priorities:**

- Enhancing offset regression stability.
- Memory optimization for low-VRAM devices.

**Contribution Guidelines:**
If you encounter a bug or wish to contribute performance optimizations, please submit a Pull Request.

---

## Resources

- **Original Inspiration**: [Paper: Keypoints as Objects for Automatic Scorekeeping](https://arxiv.org/abs/2105.09880)
- **Model Training Resources**: [Download from Google Drive](https://drive.google.com/file/d/1ZEvuzg9zYbPd1FdZgV6v1aT4sqbqmLqp/view?usp=sharing)
- **Official Documentation Reference**: [IEEE Dataport Dataset](https://ieee-dataport.org/open-access/deepdarts-dataset)

---

<div align="center">
Made by the Rust AI Community
</div>
src/data.rs
CHANGED

@@ -90,12 +90,16 @@ impl<B: Backend> DartBatcher<B> {
     let gx = (p[0] * grid_size as f32).floor().clamp(0.0, (grid_size - 1) as f32) as usize;
     let gy = (p[1] * grid_size as f32).floor().clamp(0.0, (grid_size - 1) as f32) as usize;

+    // Use Grid-Relative Coordinates (Relative to cell top-left)
+    let tx = p[0] * grid_size as f32 - gx as f32;
+    let ty = p[1] * grid_size as f32 - gy as f32;
+
     let cls = if i < 4 { i + 1 } else { 0 };
     let base_idx = (b_idx * num_channels * grid_size * grid_size) + (gy * grid_size) + gx;

     // TF order: [x,y,w,h,obj,p0..p4]
-    target_raw[base_idx + 0 * grid_size * grid_size] =
-    target_raw[base_idx + 1 * grid_size * grid_size] =
+    target_raw[base_idx + 0 * grid_size * grid_size] = tx; // X (offset in cell)
+    target_raw[base_idx + 1 * grid_size * grid_size] = ty; // Y (offset in cell)
     target_raw[base_idx + 2 * grid_size * grid_size] = 0.05; // W
     target_raw[base_idx + 3 * grid_size * grid_size] = 0.05; // H
     target_raw[base_idx + 4 * grid_size * grid_size] = 1.0; // Objectness (conf)
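The grid-relative encoding introduced in this hunk can be sketched in isolation: a normalized keypoint maps to a grid cell index plus an in-cell offset. The helper below is illustrative only (`encode_keypoint` is not a function in the repo; the 100-cell grid matches the 100x100 head):

```rust
/// Map a normalized keypoint (0..1) to its grid cell and the offset
/// inside that cell, mirroring the target encoding in `DartBatcher`.
/// Illustrative sketch, not repo code.
fn encode_keypoint(p: [f32; 2], grid_size: usize) -> ((usize, usize), (f32, f32)) {
    let gx = (p[0] * grid_size as f32)
        .floor()
        .clamp(0.0, (grid_size - 1) as f32) as usize;
    let gy = (p[1] * grid_size as f32)
        .floor()
        .clamp(0.0, (grid_size - 1) as f32) as usize;
    // Offset of the point inside its cell
    let tx = p[0] * grid_size as f32 - gx as f32;
    let ty = p[1] * grid_size as f32 - gy as f32;
    ((gx, gy), (tx, ty))
}
```

Regressing the small in-cell offset instead of the absolute coordinate is the standard YOLO trick: the network only has to learn values in a narrow range, which stabilizes training.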
src/inference.rs
CHANGED

@@ -1,15 +1,15 @@
+use crate::model::DartVisionModel;
 use burn::module::Module;
 use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
 use burn::tensor::backend::Backend;
 use burn::tensor::{Tensor, TensorData};
-use
-use image::{GenericImageView, DynamicImage};
+use image::{DynamicImage, GenericImageView};

 pub fn run_inference<B: Backend>(device: &B::Device, image_path: &str) {
     println!("🔍 Loading model for inference...");
     let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
     let model: DartVisionModel<B> = DartVisionModel::new(device);

     // Load weights
     let record = Recorder::load(&recorder, "model_weights".into(), device)
         .expect("Failed to load weights. Make sure model_weights.bin exists.");

@@ -18,12 +18,22 @@ pub fn run_inference<B: Backend>(device: &B::Device, image_path: &str) {
     println!("🖼️ Processing image: {}...", image_path);
     let img = image::open(image_path).expect("Failed to open image");
     let resized = img.resize_exact(800, 800, image::imageops::FilterType::Triangle);
-    let pixels: Vec<f32> = resized
-        .
+    let pixels: Vec<f32> = resized
+        .to_rgb8()
+        .pixels()
+        .flat_map(|p| {
+            vec![
+                p[0] as f32 / 255.0,
+                p[1] as f32 / 255.0,
+                p[2] as f32 / 255.0,
+            ]
+        })
         .collect();

     let data = TensorData::new(pixels, [800, 800, 3]);
-    let input = Tensor::<B, 3>::from_data(data, device)
+    let input = Tensor::<B, 3>::from_data(data, device)
+        .unsqueeze::<4>()
+        .permute([0, 3, 1, 2]);

     println!("🚀 Running MODEL Prediction...");
     let (out16, _out32) = model.forward(input);

@@ -31,17 +41,24 @@ pub fn run_inference<B: Backend>(device: &B::Device, image_path: &str) {
     // Post-process out16 (size [1, 30, 100, 100])
     // Decode objectness part (Channel 4 for Anchor 0)
     let obj = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 4, 1));

     // Find highest confidence cell
     let (max_val, _) = obj.reshape([1, 10000]).max_dim_with_indices(1);
-    let confidence: f32 = max_val
+    let confidence: f32 = max_val
+        .to_data()
+        .convert::<f32>()
+        .as_slice::<f32>()
+        .unwrap()[0];

     println!("--------------------------------------------------");
     println!("📊 RESULTS FOR: {}", image_path);
     println!("✨ Max Objectness: {:.2}%", confidence * 100.0);

     if confidence > 0.05 {
-        println!(
+        println!(
+            "✅ Model found something! Confidence Score: {:.4}",
+            confidence
+        );
     } else {
         println!("⚠️ Model confidence is too low. Training incomplete?");
     }
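The `.permute([0, 3, 1, 2])` added above converts the input from HWC (interleaved channels) to the CHW layout the convolutions expect. Written out on a plain `Vec` for clarity (a sketch only; the real code keeps everything on the tensor), the reordering is:

```rust
/// Reorder an interleaved HWC pixel buffer into planar CHW layout,
/// the same transposition `.permute([0, 3, 1, 2])` performs on the
/// batched tensor. Illustrative sketch, not repo code.
fn hwc_to_chw(pixels: &[f32], h: usize, w: usize, c: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; pixels.len()];
    for y in 0..h {
        for x in 0..w {
            for ch in 0..c {
                // source: pixel (y, x), channel ch, interleaved
                // dest: plane ch, then row-major (y, x)
                out[ch * h * w + y * w + x] = pixels[(y * w + x) * c + ch];
            }
        }
    }
    out
}
```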
src/loss.rs
CHANGED

@@ -32,9 +32,10 @@ pub fn diou_loss<B: Backend>(
         .mul_scalar(5.0); // Boost class learning

     // 4. Coordinates (Channels 0-3) - Only learn when object exists
+    // Coordinate Loss (MSE on relative offsets) - Weighted x10 for precision
     let b_xy_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 0, 2));
     let b_xy_target = t.clone().narrow(2, 0, 2);
-    let xy_loss = b_xy_pred.sub(b_xy_target).powf_scalar(2.0).mul(obj_target.clone()).mean().mul_scalar(
+    let xy_loss = b_xy_pred.sub(b_xy_target).powf_scalar(2.0).mul(obj_target.clone()).mean().mul_scalar(10.0);

     let b_wh_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 2, 2));
     let b_wh_target = t.clone().narrow(2, 2, 2);
src/main.rs
CHANGED

@@ -32,7 +32,7 @@ fn main() {
     let dataset_path = "dataset/labels.json";

     let config = TrainingConfig {
-        num_epochs:
+        num_epochs: 50,
         batch_size: 1,
         lr: 1e-3,
     };
src/model.rs
CHANGED

@@ -30,40 +30,40 @@ impl<B: Backend> ConvBlock<B> {

 #[derive(Module, Debug)]
 pub struct DartVisionModel<B: Backend> {
-    //
-    l1: ConvBlock<B>, // 3 ->
+    // Increased capacity: High resolution but enough width to map complex features
+    l1: ConvBlock<B>, // 3 -> 32
     p1: MaxPool2d,
-    l2: ConvBlock<B>, //
+    l2: ConvBlock<B>, // 32 -> 32
     p2: MaxPool2d,
-    l3: ConvBlock<B>, //
+    l3: ConvBlock<B>, // 32 -> 64
     p3: MaxPool2d,
-    l4: ConvBlock<B>, //
+    l4: ConvBlock<B>, // 64 -> 64
     p4: MaxPool2d,
-    l5: ConvBlock<B>, //
-    l6: ConvBlock<B>, //
+    l5: ConvBlock<B>, // 64 -> 128
+    l6: ConvBlock<B>, // 128 -> 128

-    head_32: Conv2d<B>, // Final detection head
+    head_32: Conv2d<B>, // Final detection head (30 channels for 3 anchors)
 }

 impl<B: Backend> DartVisionModel<B> {
     pub fn new(device: &B::Device) -> Self {
-        let l1 = ConvBlock::new(3,
+        let l1 = ConvBlock::new(3, 32, [3, 3], device);
         let p1 = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init();

-        let l2 = ConvBlock::new(
+        let l2 = ConvBlock::new(32, 32, [3, 3], device);
         let p2 = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init();

-        let l3 = ConvBlock::new(
+        let l3 = ConvBlock::new(32, 64, [3, 3], device);
         let p3 = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init();

-        let l4 = ConvBlock::new(
+        let l4 = ConvBlock::new(64, 64, [3, 3], device);
         let p4 = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init();

-        let l5 = ConvBlock::new(
-        let l6 = ConvBlock::new(
+        let l5 = ConvBlock::new(64, 128, [3, 3], device);
+        let l6 = ConvBlock::new(128, 128, [3, 3], device);

-        // 30 channels = 3 anchors * (x,y,w,h,obj,
-        let head_32 = Conv2dConfig::new([
+        // 30 channels = 3 anchors * (x,y,w,h,obj,dart,cal1,cal2,cal3,cal4)
+        let head_32 = Conv2dConfig::new([128, 30], [1, 1]).init(device);

         Self { l1, p1, l2, p2, l3, p3, l4, p4, l5, l6, head_32 }
     }
src/scoring.rs
CHANGED

@@ -23,44 +23,49 @@ impl Default for ScoringConfig {
 pub fn get_board_dict() -> HashMap<i32, &'static str> {
     let mut m = HashMap::new();
     // BDO standard mapping based on degrees
-    let slices = [
+    let slices = [
+        "6", "13", "4", "18", "1", "20", "5", "12", "9", "14", "11", "8", "16", "7", "19", "3",
+        "17", "2", "15", "10",
+    ];
     for (i, &s) in slices.iter().enumerate() {
         m.insert(i as i32, s);
     }
     m
 }

-pub fn calculate_dart_score(
+pub fn calculate_dart_score(
+    cal_pts: &[[f32; 2]],
+    dart_pt: &[f32; 2],
+    config: &ScoringConfig,
+) -> (i32, String) {
     // 1. Calculate Center (Average of 4 calibration points)
     let cx = cal_pts.iter().map(|p| p[0]).sum::<f32>() / 4.0;
     let cy = cal_pts.iter().map(|p| p[1]).sum::<f32>() / 4.0;

     // 2. Calculate average radius to boundary (doubles wire)
-    let avg_r_px = cal_pts
+    let avg_r_px = cal_pts
+        .iter()
         .map(|p| ((p[0] - cx).powi(2) + (p[1] - cy).powi(2)).sqrt())
-        .sum::<f32>()
+        .sum::<f32>()
+        / 4.0;

     // 3. Relative distance of dart from center
     let dx = dart_pt[0] - cx;
     let dy = dart_pt[1] - cy;
     let dist_px = (dx.powi(2) + dy.powi(2)).sqrt();

     // Scale distance relative to BDO double radius
     let dist_scaled = (dist_px / avg_r_px) * config.r_double;

     // 4. Calculate Angle (0 is 3 o'clock, CCW)
     let mut angle_deg = (-dy).atan2(dx).to_degrees();
-    if angle_deg < 0.0 {
-    //
-    // Each index is 18 deg. Offset = 4 * 18 = 72? No.
-    // Let's use the standard mapping: (angle / 18)
-    // Wait, the BOARD_DICT in Python uses int(angle / 18) where angle is 0-360.
-    // We need to match the slice orientation.
+    if angle_deg < 0.0 {
+        angle_deg += 360.0;
+    }
+
+    // Center sectors by adding 9 degrees (half-sector width)
     let board_dict = get_board_dict();
-    let sector_idx = ((angle_deg / 18.0).floor() as i32) % 20;
+    let sector_idx = (((angle_deg + 9.0) / 18.0).floor() as i32) % 20;
     let sector_num = board_dict.get(&sector_idx).unwrap_or(&"0");

     // 5. Determine multipliers based on scaled distance
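The sector lookup in this hunk (shift by half a sector, divide by the 18-degree wedge width, wrap modulo 20) can be sketched standalone. `sector_label` is a hypothetical helper; the slice order is taken from the diff above:

```rust
/// Resolve a board angle (degrees, 0 = 3 o'clock, counter-clockwise)
/// to its sector number, using the half-sector (9 deg) shift from the
/// diff so each 18-degree wedge is centred on its number. Sketch only.
fn sector_label(mut angle_deg: f32) -> &'static str {
    const SLICES: [&'static str; 20] = [
        "6", "13", "4", "18", "1", "20", "5", "12", "9", "14",
        "11", "8", "16", "7", "19", "3", "17", "2", "15", "10",
    ];
    if angle_deg < 0.0 {
        angle_deg += 360.0;
    }
    let idx = ((((angle_deg + 9.0) / 18.0).floor() as i32) % 20) as usize;
    SLICES[idx]
}
```

The +9 degree shift is what makes the wedge boundaries fall on the wires rather than through the middle of a number: 3 o'clock (0 degrees) lands squarely in sector 6, and straight up (90 degrees) in sector 20, matching a standard board.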
src/server.rs
CHANGED
|
@@ -1,14 +1,14 @@
|
|
|
|
|
| 1 |
use axum::{
|
| 2 |
extract::{DefaultBodyLimit, Multipart, State},
|
| 3 |
response::{Html, Json},
|
| 4 |
routing::{get, post},
|
| 5 |
Router,
|
| 6 |
};
|
|
|
|
|
|
|
| 7 |
use burn::prelude::*;
|
| 8 |
use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
|
| 9 |
-
use burn::backend::Wgpu;
|
| 10 |
-
use burn::backend::wgpu::WgpuDevice;
|
| 11 |
-
use crate::model::DartVisionModel;
|
| 12 |
use serde_json::json;
|
| 13 |
use std::net::SocketAddr;
|
| 14 |
use std::sync::Arc;
|
|
@@ -47,64 +47,235 @@ pub async fn start_gui(device: WgpuDevice) {
|
|
| 47 |
let model = model.load_record(record);
|
| 48 |
|
| 49 |
while let Some(req) = rx.blocking_recv() {
|
|
|
|
|
|
|
| 50 |
let img = image::load_from_memory(&req.image_bytes).unwrap();
|
| 51 |
let resized = img.resize_exact(416, 416, image::imageops::FilterType::Triangle);
|
| 52 |
-
let pixels: Vec<f32> = resized
|
| 53 |
-
.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
.collect();
|
| 55 |
|
| 56 |
let tensor_data = TensorData::new(pixels, [1, 416, 416, 3]);
|
| 57 |
-
let input =
|
|
|
|
| 58 |
|
| 59 |
let (out16, _) = model.forward(input);
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
let mut final_points = vec![0.0f32; 8]; // 4 corners
|
| 62 |
let mut max_conf = 0.0f32;
|
| 63 |
|
| 64 |
-
//
|
| 65 |
-
let obj = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 4, 1));
|
| 66 |
-
|
| 67 |
-
// 2. Extract best calibration corner for each class 1 to 4 (Grid 26x26 = 676)
|
| 68 |
for cls_idx in 1..=4 {
|
| 69 |
-
let
|
| 70 |
-
let
|
| 71 |
-
let
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
let s = val.to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];
|
| 74 |
let f_idx = idx.to_data().convert::<i32>().as_slice::<i32>().unwrap()[0] as usize;
|
| 75 |
-
|
| 76 |
let gx = f_idx % 26;
|
|
|
|
| 77 |
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
}
|
| 88 |
|
| 89 |
-
//
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
let (
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
final_points.push(dx);
|
| 103 |
-
final_points.push(dy);
|
| 104 |
}
|
| 105 |
|
| 106 |
let mut final_scores = vec![];
|
| 107 |
-
|
| 108 |
// Calculate scores if we have calibration points and at least one dart
|
| 109 |
if final_points.len() >= 10 {
|
| 110 |
use crate::scoring::{calculate_dart_score, ScoringConfig};
|
|
@@ -115,22 +286,32 @@ pub async fn start_gui(device: WgpuDevice) {
|
|
| 115 |
[final_points[4], final_points[5]],
|
| 116 |
[final_points[6], final_points[7]],
|
| 117 |
];
|
| 118 |
-
|
| 119 |
for dart_chunk in final_points[8..].chunks(2) {
|
| 120 |
if dart_chunk.len() == 2 {
|
| 121 |
let dart_pt = [dart_chunk[0], dart_chunk[1]];
|
| 122 |
-
let (
|
| 123 |
-
final_scores.push(label);
|
|
|
|
| 124 |
}
|
| 125 |
}
|
| 126 |
}
|
| 127 |
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
| 129 |
let class_names = ["Cal1", "Cal2", "Cal3", "Cal4", "Dart"];
|
| 130 |
for (i, pts) in final_points.chunks(2).enumerate() {
|
| 131 |
let name = class_names.get(i).unwrap_or(&"Dart");
|
| 132 |
-
let label = final_scores
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
}
|
| 135 |
|
| 136 |
let _ = req.response_tx.send(PredictResult {
|
|
@@ -144,7 +325,10 @@ pub async fn start_gui(device: WgpuDevice) {
|
|
| 144 |
let state = Arc::new(tx);
|
| 145 |
|
| 146 |
let app = Router::new()
|
| 147 |
-
.route(
|
|
|
|
|
|
|
|
|
|
| 148 |
.route("/api/predict", post(predict_handler))
|
| 149 |
.with_state(state)
|
| 150 |
.layer(DefaultBodyLimit::max(10 * 1024 * 1024))
|
|
@@ -165,17 +349,26 @@ async fn predict_handler(
|
|
| 165 |
Err(_) => continue,
|
| 166 |
};
|
| 167 |
let (res_tx, res_rx) = oneshot::channel();
|
| 168 |
-
let _ = tx
|
| 169 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
|
| 171 |
return Json(json!({
|
| 172 |
"status": "success",
|
| 173 |
"confidence": result.confidence,
|
| 174 |
"keypoints": result.keypoints,
|
| 175 |
"scores": result.scores,
|
| 176 |
-
"message": if result.confidence > 0.1 {
|
| 177 |
format!("✅ Found {} darts! High confidence: {:.1}%", result.scores.len(), result.confidence * 100.0)
|
| 178 |
-
} else {
|
| 179 |
"⚠️ Low confidence detection - no dart score could be verified.".to_string()
|
| 180 |
}
|
| 181 |
}));
|
|
|
|
| 1 |
+
use crate::model::DartVisionModel;
|
| 2 |
use axum::{
|
| 3 |
extract::{DefaultBodyLimit, Multipart, State},
|
| 4 |
response::{Html, Json},
|
| 5 |
routing::{get, post},
|
| 6 |
Router,
|
| 7 |
};
|
| 8 |
+
use burn::backend::wgpu::WgpuDevice;
|
| 9 |
+
use burn::backend::Wgpu;
|
| 10 |
use burn::prelude::*;
|
| 11 |
use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
|
|
|
|
|
|
|
|
|
|
| 12 |
use serde_json::json;
|
| 13 |
use std::net::SocketAddr;
|
| 14 |
use std::sync::Arc;
|
|
|
|
| 47 |
let model = model.load_record(record);
|
| 48 |
|
| 49 |
while let Some(req) = rx.blocking_recv() {
|
| 50 |
+
let start_time = std::time::Instant::now();
|
| 51 |
+
|
| 52 |
let img = image::load_from_memory(&req.image_bytes).unwrap();
|
| 53 |
let resized = img.resize_exact(416, 416, image::imageops::FilterType::Triangle);
|
| 54 |
+
let pixels: Vec<f32> = resized
|
| 55 |
+
.to_rgb8()
|
| 56 |
+
.pixels()
|
| 57 |
+
.flat_map(|p| {
|
| 58 |
+
vec![
|
| 59 |
+
p[0] as f32 / 255.0,
|
| 60 |
+
p[1] as f32 / 255.0,
|
| 61 |
+
p[2] as f32 / 255.0,
|
| 62 |
+
]
|
| 63 |
+
})
|
| 64 |
.collect();
|
| 65 |
|
| 66 |
let tensor_data = TensorData::new(pixels, [1, 416, 416, 3]);
|
| 67 |
+
let input =
|
| 68 |
+
Tensor::<Wgpu, 4>::from_data(tensor_data, &worker_device).permute([0, 3, 1, 2]);
|
| 69 |
|
| 70 |
let (out16, _) = model.forward(input);
|
| 71 |
+
|
| 72 |
+
// 1. Reshape to separate anchors: [1, 3, 10, 26, 26]
|
| 73 |
+
let out_reshaped = out16.reshape([1, 3, 10, 26, 26]);
|
| 74 |
+
|
| 75 |
+
// 1.5 Debug: Raw Statistics
|
| 76 |
+
println!(
|
| 77 |
+
"🔍 [Model Stats] Raw Min: {:.4}, Max: {:.4}",
|
| 78 |
+
out_reshaped.clone().min().into_scalar(),
|
| 79 |
+
out_reshaped.clone().max().into_scalar()
|
| 80 |
+
);
|
| 81 |
+
|
| 82 |
let mut final_points = vec![0.0f32; 8]; // 4 corners
|
| 83 |
let mut max_conf = 0.0f32;
|
| 84 |
|
| 85 |
+
// 2. Extract best calibration corner for each class 1 to 4
|
|
|
|
|
|
|
|
|
|
| 86 |
for cls_idx in 1..=4 {
|
| 87 |
+
let mut best_s = -1.0f32;
|
| 88 |
+
let mut best_pt = [0.0f32; 2];
|
| 89 |
+
let mut best_anchor = 0;
|
| 90 |
+
let mut best_grid = (0, 0);
|
| 91 |
+
|
| 92 |
+
for anchor in 0..3 {
|
| 93 |
+
let obj = burn::tensor::activation::sigmoid(
|
| 94 |
+
out_reshaped.clone().narrow(1, anchor, 1).narrow(2, 4, 1),
|
| 95 |
+
);
|
| 96 |
+
let prob = burn::tensor::activation::sigmoid(
|
| 97 |
+
out_reshaped
|
| 98 |
+
.clone()
|
| 99 |
+
.narrow(1, anchor, 1)
|
| 100 |
+
.narrow(2, 5 + cls_idx, 1),
|
| 101 |
+
);
|
| 102 |
+
let score = obj.mul(prob);
|
| 103 |
+
|
| 104 |
+
let (val, idx) = score.reshape([1, 676]).max_dim_with_indices(1);
|
| 105 |
+
let s = val.to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];
|
| 106 |
+
if s > best_s {
|
| 107 |
+
best_s = s;
|
| 108 |
+
best_anchor = anchor;
|
| 109 |
+
let f_idx =
|
| 110 |
+
idx.to_data().convert::<i32>().as_slice::<i32>().unwrap()[0] as usize;
|
| 111 |
+
best_grid = (f_idx % 26, f_idx / 26);
|
| 112 |
+
|
| 113 |
+
let sx = burn::tensor::activation::sigmoid(
|
| 114 |
+
out_reshaped
|
| 115 |
+
.clone()
|
| 116 |
+
.narrow(1, anchor, 1)
|
| 117 |
+
.narrow(2, 0, 1)
|
| 118 |
+
.slice([
|
| 119 |
+
0..1,
|
| 120 |
+
0..1,
|
| 121 |
+
0..1,
|
| 122 |
+
best_grid.1..best_grid.1 + 1,
|
| 123 |
+
best_grid.0..best_grid.0 + 1,
|
| 124 |
+
]),
|
| 125 |
+
)
|
| 126 |
+
.to_data()
|
| 127 |
+
.convert::<f32>()
|
| 128 |
+
.as_slice::<f32>()
|
| 129 |
+
.unwrap()[0];
|
| 130 |
+
let sy = burn::tensor::activation::sigmoid(
|
| 131 |
+
out_reshaped
|
| 132 |
+
.clone()
|
| 133 |
+
.narrow(1, anchor, 1)
|
| 134 |
+
.narrow(2, 1, 1)
|
| 135 |
+
.slice([
|
| 136 |
+
0..1,
|
| 137 |
+
0..1,
|
| 138 |
+
0..1,
|
| 139 |
+
best_grid.1..best_grid.1 + 1,
|
| 140 |
+
best_grid.0..best_grid.0 + 1,
|
| 141 |
+
]),
|
| 142 |
+
)
|
| 143 |
+
.to_data()
|
| 144 |
+
.convert::<f32>()
|
| 145 |
+
.as_slice::<f32>()
|
| 146 |
+
.unwrap()[0];
|
| 147 |
+
|
| 148 |
+
// Reconstruct Absolute Normalized Coord (0-1)
|
| 149 |
+
best_pt = [
|
| 150 |
+
(best_grid.0 as f32 + sx) / 26.0,
|
| 151 |
+
(best_grid.1 as f32 + sy) / 26.0,
|
| 152 |
+
];
|
| 153 |
+
}
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
final_points[(cls_idx - 1) * 2] = best_pt[0];
|
| 157 |
+
final_points[(cls_idx - 1) * 2 + 1] = best_pt[1];
|
| 158 |
+
if best_s > max_conf {
|
| 159 |
+
max_conf = best_s;
|
| 160 |
+
}
|
| 161 |
+
println!(
|
| 162 |
+
" [Debug Cal{}] Anchor: {}, Conf: {:.4}, Cell: {:?}, Coord: [{:.3}, {:.3}]",
|
| 163 |
+
cls_idx, best_anchor, best_s, best_grid, best_pt[0], best_pt[1]
|
| 164 |
+
);
|
| 165 |
+
}
|
```diff
@@ src/server.rs:166-206 @@
+
+    // 3. Calibration estimation (Python logic: est_cal_pts).
+    // If exactly one calibration point is missing, estimate it by symmetry.
+    let mut valid_cal_count = 0;
+    let mut missing_idx = -1;
+    for i in 0..4 {
+        if final_points[i * 2] > 0.01 || final_points[i * 2 + 1] > 0.01 {
+            valid_cal_count += 1;
+        } else {
+            missing_idx = i as i32;
+        }
+    }
+
+    if valid_cal_count == 3 {
+        println!("⚠️ [Calibration Recovery] Estimating missing point Cal{}...", missing_idx + 1);
+        match missing_idx {
+            0 | 1 => {
+                // Top point missing: reflect its partner about the bottom points' center.
+                let cx = (final_points[4] + final_points[6]) / 2.0;
+                let cy = (final_points[5] + final_points[7]) / 2.0;
+                if missing_idx == 0 {
+                    final_points[0] = 2.0 * cx - final_points[2];
+                    final_points[1] = 2.0 * cy - final_points[3];
+                } else {
+                    final_points[2] = 2.0 * cx - final_points[0];
+                    final_points[3] = 2.0 * cy - final_points[1];
+                }
+            }
+            2 | 3 => {
+                // Bottom point missing: reflect its partner about the top points' center.
+                let cx = (final_points[0] + final_points[2]) / 2.0;
+                let cy = (final_points[1] + final_points[3]) / 2.0;
+                if missing_idx == 2 {
+                    final_points[4] = 2.0 * cx - final_points[6];
+                    final_points[5] = 2.0 * cy - final_points[7];
+                } else {
+                    final_points[6] = 2.0 * cx - final_points[4];
+                    final_points[7] = 2.0 * cy - final_points[5];
+                }
+            }
+            _ => {}
+        }
+    }
```
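The recovery formula is a point reflection: with the calibration points arranged in diametrically opposite pairs around the board center, a missing point equals twice the midpoint of the intact pair minus its own partner. A minimal sketch of that identity (`estimate_missing` and the coordinates are illustrative, assuming the paired layout just described):

```rust
/// Estimate a missing calibration point by reflecting its diametric
/// `partner` through the midpoint of the intact opposite pair (a, b):
/// missing = 2 * midpoint(a, b) - partner.
fn estimate_missing(a: [f32; 2], b: [f32; 2], partner: [f32; 2]) -> [f32; 2] {
    let cx = (a[0] + b[0]) / 2.0;
    let cy = (a[1] + b[1]) / 2.0;
    [2.0 * cx - partner[0], 2.0 * cy - partner[1]]
}

fn main() {
    // Board center (0.5, 0.5): Cal3 = (0.0, 0.5) and Cal4 = (1.0, 0.5) are intact,
    // Cal2 = (0.5, 1.0) is the partner of the missing Cal1 = (0.5, 0.0).
    let cal1 = estimate_missing([0.0, 0.5], [1.0, 0.5], [0.5, 1.0]);
    assert!((cal1[0] - 0.5).abs() < 1e-6);
    assert!(cal1[1].abs() < 1e-6);
    println!("recovered Cal1: {:?}", cal1);
}
```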
```diff
@@ src/server.rs:207-275 @@
+
+    // 4. Extract the best dart (class 0): gather candidates across all anchors.
+    println!("  [Debug Dart] Searching for Candidates...");
+    let mut dart_candidates = vec![];
+    for anchor in 0..3 {
+        let obj = burn::tensor::activation::sigmoid(
+            out_reshaped.clone().narrow(1, anchor, 1).narrow(2, 4, 1),
+        );
+        let prob = burn::tensor::activation::sigmoid(
+            out_reshaped.clone().narrow(1, anchor, 1).narrow(2, 5, 1),
+        );
+        // Objectness times class probability, flattened over the 26x26 grid.
+        let score = obj.mul(prob).reshape([1, 676]);
+
+        let (val, idx) = score.max_dim_with_indices(1);
         let s = val.to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];
         let f_idx = idx.to_data().convert::<i32>().as_slice::<i32>().unwrap()[0] as usize;
+
         let gx = f_idx % 26;
+        let gy = f_idx / 26;
 
+        let dsx = burn::tensor::activation::sigmoid(
+            out_reshaped
+                .clone()
+                .narrow(1, anchor, 1)
+                .narrow(2, 0, 1)
+                .slice([0..1, 0..1, 0..1, gy..gy + 1, gx..gx + 1]),
+        )
+        .to_data()
+        .convert::<f32>()
+        .as_slice::<f32>()
+        .unwrap()[0];
+        let dsy = burn::tensor::activation::sigmoid(
+            out_reshaped
+                .clone()
+                .narrow(1, anchor, 1)
+                .narrow(2, 1, 1)
+                .slice([0..1, 0..1, 0..1, gy..gy + 1, gx..gx + 1]),
+        )
+        .to_data()
+        .convert::<f32>()
+        .as_slice::<f32>()
+        .unwrap()[0];
 
+        let dx = (gx as f32 + dsx) / 26.0;
+        let dy = (gy as f32 + dsy) / 26.0;
+
+        if s > 0.005 {
+            println!(
+                "    - A{} Best Cell: ({},{}), Conf: {:.4}, Coord: [{:.3}, {:.3}]",
+                anchor, gx, gy, s, dx, dy
+            );
+            dart_candidates.push((s, [dx, dy]));
+        }
     }
 
+    // Pick the best dart candidate among all anchors.
+    dart_candidates
+        .sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
+    if let Some((s, pt)) = dart_candidates.first() {
+        if *s > 0.05 {
+            final_points.push(pt[0]);
+            final_points.push(pt[1]);
+            println!(
+                " ✅ Best Dart Picked: Conf: {:.2}%, Coord: {:?}",
+                s * 100.0,
+                pt
+            );
+        }
     }
```
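A note on the sort: `f32` only implements `PartialOrd` (NaN breaks total ordering), so the candidate sort goes through `partial_cmp` with an `Ordering::Equal` fallback rather than a plain `sort`. The same best-candidate selection in isolation (`best_candidate` is an illustrative helper, not from the repo):

```rust
use std::cmp::Ordering;

/// Sort (confidence, point) candidates descending and return the best one.
fn best_candidate(mut cands: Vec<(f32, [f32; 2])>) -> Option<(f32, [f32; 2])> {
    // Comparing b against a yields descending order; NaN confidences fall
    // back to Equal instead of panicking.
    cands.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(Ordering::Equal));
    cands.first().copied()
}

fn main() {
    let best = best_candidate(vec![
        (0.10, [0.2, 0.2]),
        (0.85, [0.5, 0.5]),
        (0.40, [0.7, 0.1]),
    ]);
    assert_eq!(best.map(|(s, _)| s), Some(0.85));
    println!("best candidate: {:?}", best);
}
```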
```diff
@@ src/server.rs:276-298 @@
 
     let mut final_scores = vec![];
+
     // Calculate scores if we have calibration points and at least one dart
     if final_points.len() >= 10 {
         use crate::scoring::{calculate_dart_score, ScoringConfig};
@@ lines 282-285 unchanged @@
             [final_points[4], final_points[5]],
             [final_points[6], final_points[7]],
         ];
+
         for dart_chunk in final_points[8..].chunks(2) {
             if dart_chunk.len() == 2 {
                 let dart_pt = [dart_chunk[0], dart_chunk[1]];
+                let (s_val, label) = calculate_dart_score(&cal_pts, &dart_pt, &config);
+                final_scores.push(label.clone());
+                println!("  [Debug Score] Label: {} (Val: {})", label, s_val);
             }
         }
     }
```
```diff
@@ src/server.rs:299-317 @@
 
+    let duration = start_time.elapsed();
+    println!("⚡ [Inference Performance] Total Latency: {:.2?}", duration);
+
+    println!("🎯 [Final Result] Top Confidence: {:.2}%", max_conf * 100.0);
     let class_names = ["Cal1", "Cal2", "Cal3", "Cal4", "Dart"];
     for (i, pts) in final_points.chunks(2).enumerate() {
         let name = class_names.get(i).unwrap_or(&"Dart");
+        // Only dart entries (index 4 onward) carry a score label;
+        // calibration points print none.
+        let label = if i >= 4 {
+            final_scores.get(i - 4).cloned().unwrap_or_default()
+        } else {
+            String::new()
+        };
+        println!(
+            "  - {}: [x: {:.3}, y: {:.3}] {}",
+            name, pts[0], pts[1], label
+        );
     }
 
     let _ = req.response_tx.send(PredictResult {
```
```diff
@@ src/server.rs:325-334 (lines 318-324 unchanged) @@
     let state = Arc::new(tx);
 
     let app = Router::new()
+        // Serve the bundled UI at the root path.
+        .route(
+            "/",
+            get(|| async { Html(include_str!("../static/index.html")) }),
+        )
         .route("/api/predict", post(predict_handler))
         .with_state(state)
         .layer(DefaultBodyLimit::max(10 * 1024 * 1024))
```
```diff
@@ src/server.rs:349-374 (lines 335-348 unchanged) @@
         Err(_) => continue,
     };
     let (res_tx, res_rx) = oneshot::channel();
+    let _ = tx
+        .send(PredictRequest {
+            image_bytes: bytes,
+            response_tx: res_tx,
+        })
+        .await;
+    // Fall back to an empty result if the inference worker dropped the channel.
+    let result = res_rx.await.unwrap_or(PredictResult {
+        confidence: 0.0,
+        keypoints: vec![],
+        scores: vec![],
+    });
 
     return Json(json!({
         "status": "success",
         "confidence": result.confidence,
         "keypoints": result.keypoints,
         "scores": result.scores,
+        "message": if result.confidence > 0.1 {
             format!("✅ Found {} darts! High confidence: {:.1}%", result.scores.len(), result.confidence * 100.0)
+        } else {
             "⚠️ Low confidence detection - no dart score could be verified.".to_string()
         }
     }));
```
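The handler and worker above follow a request/response (actor) pattern: each HTTP request enqueues a job that carries its own oneshot reply channel, and the single inference task answers on it, so the model never needs a lock. The same shape, sketched with std channels instead of tokio (the `Job` type and confidence values are illustrative stand-ins, not the server's real types):

```rust
use std::sync::mpsc;
use std::thread;

/// Illustrative stand-in for the server's PredictRequest: the job carries
/// its own reply channel, playing the role of the tokio oneshot.
struct Job {
    image_len: usize,
    response_tx: mpsc::Sender<f32>,
}

/// Queue one job to a single worker thread and wait for its confidence reply.
fn round_trip(image_len: usize) -> f32 {
    let (job_tx, job_rx) = mpsc::channel::<Job>();

    // Single worker owning the "model"; the loop ends when the queue closes.
    let worker = thread::spawn(move || {
        for job in job_rx {
            let conf = if job.image_len > 0 { 0.9 } else { 0.0 };
            let _ = job.response_tx.send(conf);
        }
    });

    let (res_tx, res_rx) = mpsc::channel();
    job_tx.send(Job { image_len, response_tx: res_tx }).unwrap();
    let confidence = res_rx.recv().unwrap_or(0.0);

    drop(job_tx); // close the queue so the worker loop exits
    worker.join().unwrap();
    confidence
}

fn main() {
    assert!(round_trip(1024) > 0.1);
    assert!(round_trip(0) == 0.0);
    println!("request/response round trip ok");
}
```

Keeping the model on one thread this way sidesteps `Send`/`Sync` requirements on the backend state and naturally serializes GPU access.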