Commit 8e03aff
Parent(s): 850d827
kapil committed
feat: implement core RustAutoScoreEngine framework including data loading, model architecture, training loop, inference, and GUI server
Files changed:
- .gitignore +13 -23
- README.md +25 -126
- model_weights.bin +3 -0
- src/data.rs +39 -19
- src/inference.rs +6 -6
- src/main.rs +3 -3
- src/model.rs +43 -40
- src/server.rs +25 -13
- src/tests.rs +4 -4
- src/train.rs +3 -2
- static/index.html +31 -12
.gitignore
CHANGED
@@ -1,27 +1,17 @@
 # Rust
-target/
-
-
+/target/
+**/*.rs.bk
+Cargo.lock
 
-#
-
-
-
-*.swo
-*~
-.DS_Store
+# Dataset (DONT COMMIT Large 16K+ Images)
+/dataset/
+/dataset/*/
+!/dataset/labels.json
 
-#
-
-
-dataset/images/
-dataset/cropped_images/
-dataset/800/
-dataset/__pycache__/
+# Operating System
+.DS_Store
+Thumbs.db
 
-#
-
-
-tmp/
-temp/
-*.tmp
+# Optional: You can keep model_weights.bin if it's small (~1.1 MB)
+# to let others use the GUI immediately.
+# model_weights.bin
README.md
CHANGED
@@ -1,144 +1,43 @@
-#
-### High-Performance AI Dart Scoring Powered by Rust & Burn
-
-
-[](https://burn.dev/)
-[](https://github.com/gfx-rs/wgpu)
-[](LICENSE)
-
-Using the **Burn Deep Learning Framework**, this project achieves sub-millisecond inference and high-precision keypoint detection for automatic dart game tracking. The model optimization pipeline is built using modern Rust patterns for maximum safety and performance.
-
-</div>
-
----
-
-## Features
-
-- **Optimized Inference**: Powered by Rust & WGPU for hardware-accelerated performance on Windows, Linux, and macOS.
-- **Multi-Scale Keypoint Detection**: Enhanced YOLO-style heads for detecting dart tips and calibration corners.
-- **BDO Logic Integrated**: Real-time sector calculation based on official board geometry and calibration symmetry.
-- **Modern Web Dashboard**: Axum-based visual interface to monitor detections, scores, and latency in real-time.
-- **Robust Calibration**: Automatic symmetry estimation to recover missing calibration points.
-
----
-
-## Dataset and Preparation
-The model is trained on the primary dataset used for high-precision dart detection.
-
-- **Model Weights Link**: [Neural Weights & TFLite (Google Drive)](https://drive.google.com/file/d/1ZEvuzg9zYbPd1FdZgV6v1aT4sqbqmLqp/view?usp=sharing)
-- **Dataset Source**: [DeepDarts (IEEE Dataport)](https://ieee-dataport.org/open-access/deepdarts-dataset)
-- **Resolution**: 800x800 pre-cropped high-resolution images.
-- **Structure**: Organize your data in the `dataset/800/` directory following the provided `labels.json` schema.
-
----
-
-## Installation
-
-### 1. Install Rust
-If you do not have Rust installed, use the official installation script:
-
-```bash
-# Official Installation
-curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
-```
-
-### 2. Clone and Build
-```bash
-cd RustAutoScoreEngine
-cargo build --release
-```
-
-## Quick Start Guide
-
-### Step 1: Training the AI Model
-To optimize the neural network for your local environment, run the training mode:
-```bash
-cargo run -- train
-```
-
-##
-
-```bash
-# Starts the modular Axum web server
-cargo run -- gui
-```
-**Features:**
-- **Dynamic Image Upload**: Test board imagery via the premium glassmorphism dashboard.
-- **Neural Point Mapping**: Inspect detected calibration corners and dart locations with hover effects.
-- **Real-time Scoring**: Instant sector calculation based on official BDO geometry.
-
-### Step 3: CLI Model Testing
-Test individual images directly from the terminal:
-```bash
-cargo run -- test path/to/image.jpg
-```
-
-##
-
-### Path B: Weight Migration to TFLite/ONNX
-- **TFLite**: Use the companion export scripts to generate a TensorFlow Lite bundle.
-- **ONNX**: Utilize ONNX Runtime (ORT) for high-performance cross-platform execution.
-
----
-
-## Hardware Optimization
-
-This engine is optimized for GPU execution using the WGPU backend. Depending on your specific hardware, you may need to adjust the training intensity:
-
-### GPU VRAM Management
-If you encounter **Out-of-Memory (OOM)** errors during training, you should reduce the **Batch Size**.
-
-- **Where to change**: Open `src/main.rs` and modify the `batch_size` parameter.
-- **Recommendations**:
-  - **4GB VRAM**: Batch Size 1 (Safe default)
-  - **8GB VRAM**: Batch Size 4
-  - **12GB+ VRAM**: Batch Size 8
-  - **RTX 5080 High-End**: Batch Size 16 (Optimal for ultra-fast convergence)
-- **Impact**: Larger batch sizes provide more stable gradients but require exponentially more VRAM.
-
----
-
-## Technical Status and Contributing
-
-> [!IMPORTANT]
-> This project is currently in the experimental phase. We are actively refining the coordinate regression logic to ensure maximum precision across diverse board angles.
-
-**Current Priorities:**
-- Enhancing offset regression stability.
-- Memory optimization for low-VRAM devices.
-
-**Contribution Guidelines:**
-If you encounter a bug or wish to provide performance optimizations, please submit a Pull Request.
-
----
-
-## Resources
-
-- **Core AI Framework**: [Burn - A Flexible & Comprehensive Deep Learning Framework](https://burn.dev/)
-- **Original Inspiration**: [Paper: Keypoints as Objects for Automatic Scorekeeping](https://arxiv.org/abs/2105.09880)
-- **Model Training Resources**: [Download from Google Drive](https://drive.google.com/file/d/1ZEvuzg9zYbPd1FdZgV6v1aT4sqbqmLqp/view?usp=sharing)
-- **Official Documentation Reference**: [IEEE Dataport Dataset](https://ieee-dataport.org/open-access/deepdarts-dataset)
-
----
-
-<div align="center">
-Made by the Rust AI Community
-</div>
+# 🎯 DartVision AI - Rust AutoScore Engine
+
+A high-performance dart scoring system built with **Rust** and the **Burn** Deep Learning framework. This project is a port of the original YOLOv4-tiny based DartVision, optimized for speed and safety.
+
+
+
+## 🚀 Quick Start (GUI Dashboard)
+The project comes with pre-trained weights (`model_weights.bin`). You can start the professional dashboard immediately:
+
+```bash
+cargo run --release -- gui
+```
+Then open: **[http://127.0.0.1:8080](http://127.0.0.1:8080)**
+
+## 📈 Training
+To train the model on your own dataset (requires `dataset/labels.json` and images):
+
+```bash
+cargo run --release -- train
+```
+*Note: The model saves checkpoints every 100 batches. You can stop and resume training anytime.*
+
+## 🔬 Testing
+To run a single image inference and see the neural mapping results:
+
+```bash
+cargo run --release -- test <path_to_image>
+```
+
+## ✨ Features
+- **Neural Mapping:** Real-time detection of darts and 4 calibration corners.
+- **Smart Scoring:** Automatic coordinate reconstruction and BDO standard scoring.
+- **Reliability Checks:** GUI displays per-point confidence percentages (CAL Sync) to ensure accuracy.
+- **GPU Accelerated:** Powered by `WGPUDevice` and `Burn` for ultra-fast inference.
+
+## 🛠 Project Structure
+- `src/model.rs`: YOLOv4-tiny architecture in Burn.
+- `src/loss.rs`: DIOU Loss + Objectness + Class entropy implementation.
+- `src/server.rs`: Axum-based web server for the GUI.
+- `static/index.html`: Premium Glassmorphism interface with SVG overlays.
+
+---
+*Created by [iambhabha](https://github.com/iambhabha)*
model_weights.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aa0641e49b7b65cd2e4529cb89f52cba057f42dd6bdfd4044ac7c5aced492ef5
+size 1171656
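The `model_weights.bin` entry above is not the binary itself but a Git LFS pointer file: a three-line text stub recording the sha256 object id and byte size of the real blob. As a standalone illustration (not part of the commit; the function name `parse_lfs_pointer` is hypothetical), a minimal parser for that format might look like:

```rust
// Hypothetical sketch: parse a Git LFS pointer stub like the one committed above.
// The repository stores only this small text file; the actual weights blob is
// fetched from LFS storage, keyed by the sha256 oid.
fn parse_lfs_pointer(text: &str) -> Option<(String, u64)> {
    let mut oid = None;
    let mut size = None;
    for line in text.lines() {
        if let Some(rest) = line.strip_prefix("oid sha256:") {
            oid = Some(rest.trim().to_string());
        } else if let Some(rest) = line.strip_prefix("size ") {
            size = rest.trim().parse::<u64>().ok();
        }
    }
    Some((oid?, size?))
}

fn main() {
    let pointer = "version https://git-lfs.github.com/spec/v1\n\
                   oid sha256:aa0641e49b7b65cd2e4529cb89f52cba057f42dd6bdfd4044ac7c5aced492ef5\n\
                   size 1171656";
    let (oid, size) = parse_lfs_pointer(pointer).unwrap();
    assert_eq!(size, 1171656); // ~1.1 MB, matching the .gitignore note
    assert_eq!(oid.len(), 64); // sha256 hex digest
    println!("oid={} size={}", oid, size);
}
```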
src/data.rs
CHANGED
@@ -69,17 +69,24 @@ impl<B: Backend> DartBatcher<B> {
 
     pub fn batch_manual(&self, items: Vec<Annotation>) -> DartBatch<B> {
         let batch_size = items.len();
-
-        let
-
+        // Use 800 to match original Python training config (configs/deepdarts_d1.yaml: input_size: 800)
+        let input_res: usize = 800;
+        // For tiny YOLO: grid = input_res / 16. 800/16 = 50
+        let grid_size: usize = 50;
+        let num_anchors: usize = 3;
+        let num_attrs: usize = 10; // x, y, w, h, obj, cls0..cls4
+        let num_channels: usize = num_anchors * num_attrs; // = 30
+
         let mut images_list = Vec::with_capacity(batch_size);
         let mut target_raw = vec![0.0f32; batch_size * num_channels * grid_size * grid_size];
 
         for (b_idx, item) in items.iter().enumerate() {
            // 1. Process Image
            let path = format!("dataset/800/{}/{}", item.img_folder, item.img_name);
-            let img = image::open(&path).unwrap_or_else(|_|
+            let img = image::open(&path).unwrap_or_else(|_| {
+                println!("⚠️ [Data] Image not found: {}", path);
+                image::DynamicImage::new_rgb8(input_res as u32, input_res as u32)
+            });
            let resized = img.resize_exact(input_res as u32, input_res as u32, image::imageops::FilterType::Triangle);
            let pixels: Vec<f32> = resized.to_rgb8().pixels()
                .flat_map(|p| vec![p[0] as f32 / 255.0, p[1] as f32 / 255.0, p[2] as f32 / 255.0])
@@ -87,23 +94,36 @@ impl<B: Backend> DartBatcher<B> {
            images_list.push(TensorData::new(pixels, [input_res, input_res, 3]));
 
            for (i, p) in item.xy.iter().enumerate() {
-                let
-
-                let
-                let ty = p[1] * grid_size as f32 - gy as f32;
+                // Clamp coordinates to valid grid range
+                let norm_x = p[0].clamp(0.0, 1.0 - 1e-5);
+                let norm_y = p[1].clamp(0.0, 1.0 - 1e-5);
+
+                let gx = (norm_x * grid_size as f32).floor() as usize;
+                let gy = (norm_y * grid_size as f32).floor() as usize;
+
+                // Grid-relative offset (0..1 within cell)
+                let tx = norm_x * grid_size as f32 - gx as f32;
+                let ty = norm_y * grid_size as f32 - gy as f32;
+
+                // Python convention: cal points i=0..3 -> cls=1..4, dart i>=4 -> cls=0
                let cls = if i < 4 { i + 1 } else { 0 };
-
-
-
-
-
+
+                // Assign this keypoint to anchor (cls % num_anchors) so all 3 anchors get used
+                let anchor_idx = cls % num_anchors;
+
+                // Flat index layout: [batch, anchor, attr, gy, gx]
+                // => flat = b * (3*10*G*G) + anchor * (10*G*G) + attr * (G*G) + gy*G + gx
+                let cell_base = b_idx * num_channels * grid_size * grid_size
+                    + anchor_idx * num_attrs * grid_size * grid_size
+                    + gy * grid_size
+                    + gx;
+
+                target_raw[cell_base + 0 * grid_size * grid_size] = tx; // x offset
+                target_raw[cell_base + 1 * grid_size * grid_size] = ty; // y offset
+                target_raw[cell_base + 2 * grid_size * grid_size] = 0.025; // w (bbox_size from config)
+                target_raw[cell_base + 3 * grid_size * grid_size] = 0.025; // h
+                target_raw[cell_base + 4 * grid_size * grid_size] = 1.0; // objectness
+                target_raw[cell_base + (5 + cls) * grid_size * grid_size] = 1.0; // class prob
            }
        }
 
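The index arithmetic added in `batch_manual` is easy to get wrong, so here is a standalone sketch of just that math (plain functions, no Burn dependency; `cell_base` mirrors the diff's `[batch, anchor, attr, gy, gx]` flat layout, the example coordinate is made up):

```rust
// Sketch of the target-tensor index math from batch_manual above.
// Verifies that a normalized keypoint maps to the expected grid cell and
// that every attribute write for that cell stays inside one batch item.
fn cell_base(b: usize, anchor: usize, gy: usize, gx: usize) -> usize {
    const G: usize = 50;     // grid_size (800 / 16)
    const ATTRS: usize = 10; // x, y, w, h, obj, cls0..cls4
    const ANCHORS: usize = 3;
    b * ANCHORS * ATTRS * G * G + anchor * ATTRS * G * G + gy * G + gx
}

fn main() {
    const G: usize = 50;
    // Example keypoint at normalized (0.503, 0.261) lands in cell (25, 13):
    let (norm_x, norm_y) = (0.503f32, 0.261f32);
    let gx = (norm_x * G as f32).floor() as usize;
    let gy = (norm_y * G as f32).floor() as usize;
    assert_eq!((gx, gy), (25, 13));

    // The within-cell offset stays in [0, 1):
    let tx = norm_x * G as f32 - gx as f32;
    assert!(tx >= 0.0 && tx < 1.0);

    // Writing attr 9 (cls4) for anchor 2 stays inside 30*G*G floats per item:
    let base = cell_base(0, 2, gy, gx);
    let last = base + 9 * G * G;
    assert!(last < 30 * G * G);
    println!("cell=({}, {}) base={} last={}", gx, gy, base, last);
}
```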
src/inference.rs
CHANGED
@@ -17,7 +17,7 @@ pub fn run_inference<B: Backend>(device: &B::Device, image_path: &str) {
 
    println!("🖼️ Processing image: {}...", image_path);
    let img = image::open(image_path).expect("Failed to open image");
-    let resized = img.resize_exact(
+    let resized = img.resize_exact(800, 800, image::imageops::FilterType::Triangle);
    let pixels: Vec<f32> = resized
        .to_rgb8()
        .pixels()
@@ -30,7 +30,7 @@ pub fn run_inference<B: Backend>(device: &B::Device, image_path: &str) {
        })
        .collect();
 
-    let data = TensorData::new(pixels, [
+    let data = TensorData::new(pixels, [800, 800, 3]);
    let input = Tensor::<B, 3>::from_data(data, device)
        .unsqueeze::<4>()
        .permute([0, 3, 1, 2]);
@@ -38,12 +38,12 @@ pub fn run_inference<B: Backend>(device: &B::Device, image_path: &str) {
    println!("🚀 Running MODEL Prediction...");
    let (out16, _out32) = model.forward(input);
 
-    // out16 shape: [1, 30,
-    //
+    // out16 shape: [1, 30, 50, 50] — 800/16 = 50
+    // Extract Objectness (Channel 4 of first anchor)
    let obj = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 4, 1));
 
-    //
-    let (max_val, _) = obj.reshape([
+    // Find highest confidence cell in 50x50 grid
+    let (max_val, _) = obj.reshape([1_usize, 2500]).max_dim_with_indices(1);
    let confidence: f32 = max_val
        .to_data()
        .convert::<f32>()
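The confidence lookup in this diff flattens the 50x50 objectness map and takes the argmax over 2500 cells. A minimal sketch of the same idea on a plain `Vec<f32>` (no Burn tensors; the planted peak value is arbitrary):

```rust
// Sketch of the max-confidence lookup run_inference performs: flatten the
// 50x50 objectness grid, find the maximum, and recover (gx, gy) from the
// flat index with % and /.
fn main() {
    const G: usize = 50;
    let mut obj = vec![0.01f32; G * G]; // 2500 cells, uniform low confidence
    obj[13 * G + 25] = 0.97; // plant a peak at gx=25, gy=13

    let (flat_idx, max_val) = obj
        .iter()
        .enumerate()
        .fold((0, f32::MIN), |acc, (i, &v)| if v > acc.1 { (i, v) } else { acc });

    // Same arithmetic as the diff: gx = idx % G, gy = idx / G
    let (gx, gy) = (flat_idx % G, flat_idx / G);
    assert_eq!((gx, gy), (25, 13));
    assert!((max_val - 0.97).abs() < 1e-6);
    println!("peak at ({}, {}) conf={}", gx, gy, max_val);
}
```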
src/main.rs
CHANGED
@@ -1,9 +1,9 @@
+use burn::backend::wgpu::WgpuDevice;
+use burn::backend::Wgpu;
 use rust_auto_score_engine::args::{AppArgs, Command};
 use rust_auto_score_engine::server::start_gui;
-use rust_auto_score_engine::train::{train, TrainingConfig};
 use rust_auto_score_engine::tests::test_model;
-use
-use burn::backend::Wgpu;
+use rust_auto_score_engine::train::{train, TrainingConfig};
 
 fn main() {
     let app_args = AppArgs::parse();
src/model.rs
CHANGED
@@ -28,61 +28,64 @@ impl<B: Backend> ConvBlock<B> {
    }
 }
 
+/// DartVision model ported from YOLOv4-tiny.
+/// Input: [B, 3, 800, 800] (matching Python config: input_size=800)
+/// Output grid: [B, 30, 50, 50] — 800 / 2^4 = 50
+/// 30 channels = 3 anchors × 10 attrs (x, y, w, h, obj, cls0..cls4)
 #[derive(Module, Debug)]
 pub struct DartVisionModel<B: Backend> {
-    //
-    head_32: Conv2d<B>, // Final detection head (30 channels for 3 anchors)
+    l1: ConvBlock<B>, // 3 -> 32
+    p1: MaxPool2d,    // /2 -> 400
+    l2: ConvBlock<B>, // 32 -> 32
+    p2: MaxPool2d,    // /2 -> 200
+    l3: ConvBlock<B>, // 32 -> 64
+    p3: MaxPool2d,    // /2 -> 100
+    l4: ConvBlock<B>, // 64 -> 64
+    p4: MaxPool2d,    // /2 -> 50
+    l5: ConvBlock<B>, // 64 -> 128
+    l6: ConvBlock<B>, // 128 -> 128
+    head: Conv2d<B>,  // 128 -> 30 (detection head)
 }
 
 impl<B: Backend> DartVisionModel<B> {
    pub fn new(device: &B::Device) -> Self {
-        let l1 = ConvBlock::new(3,
+        let l1 = ConvBlock::new(3, 32, [3, 3], device);
        let p1 = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init();
-        let l2 = ConvBlock::new(32,
+
+        let l2 = ConvBlock::new(32, 32, [3, 3], device);
        let p2 = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init();
-        let l3 = ConvBlock::new(32,
+
+        let l3 = ConvBlock::new(32, 64, [3, 3], device);
        let p3 = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init();
-        let l4 = ConvBlock::new(64,
+
+        let l4 = ConvBlock::new(64, 64, [3, 3], device);
        let p4 = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init();
-        let l5 = ConvBlock::new(64,
+
+        let l5 = ConvBlock::new(64, 128, [3, 3], device);
        let l6 = ConvBlock::new(128, 128, [3, 3], device);
 
-        // 30
-        let
+        // 30 = 3 anchors × (x, y, w, h, obj, dart, cal1, cal2, cal3, cal4)
+        let head = Conv2dConfig::new([128, 30], [1, 1]).init(device);
 
-        Self { l1, p1, l2, p2, l3, p3, l4, p4, l5, l6,
+        Self { l1, p1, l2, p2, l3, p3, l4, p4, l5, l6, head }
    }
 
+    /// Returns (output_50, output_50) — second is a clone kept for API compat.
    pub fn forward(&self, x: Tensor<B, 4>) -> (Tensor<B, 4>, Tensor<B, 4>) {
-        let x = self.l1.forward(x);
-        let x = self.p1.forward(x);
-        let x = self.l2.forward(x);
-        let x = self.p2.forward(x);
-        let x = self.l3.forward(x);
-        let x = self.p3.forward(x);
-        let x = self.l4.forward(x);
-        let x = self.p4.forward(x);
-        let
-        let
-        (
+        let x = self.l1.forward(x); // [B, 32, 800, 800]
+        let x = self.p1.forward(x); // [B, 32, 400, 400]
+        let x = self.l2.forward(x); // [B, 32, 400, 400]
+        let x = self.p2.forward(x); // [B, 32, 200, 200]
+        let x = self.l3.forward(x); // [B, 64, 200, 200]
+        let x = self.p3.forward(x); // [B, 64, 100, 100]
+        let x = self.l4.forward(x); // [B, 64, 100, 100]
+        let x = self.p4.forward(x); // [B, 64, 50, 50]
+        let x = self.l5.forward(x); // [B, 128, 50, 50]
+        let x = self.l6.forward(x); // [B, 128, 50, 50]
+        // NOTE: Do NOT clone here — cloning an autodiff tensor duplicates the full
+        // computation graph in memory. train.rs only uses the first output.
+        let out = self.head.forward(x); // [B, 30, 50, 50]
+        let out2 = out.clone().detach(); // detached copy for API compat (no grad graph)
+        (out, out2)
    }
 }
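The shape comments in the new `forward` all follow from one fact: four stride-2 max-pools give an overall stride of 2^4 = 16, so an 800-pixel input yields a 50x50 grid. That bookkeeping can be checked with a few lines of plain arithmetic (a standalone sketch, independent of Burn):

```rust
// Sketch of the spatial bookkeeping in DartVisionModel::forward:
// four stride-2 pools take 800 -> 400 -> 200 -> 100 -> 50, which is why
// the detection head emits a 50x50 grid (overall stride 16).
fn main() {
    let mut res: usize = 800;
    let mut sizes = vec![res];
    for _ in 0..4 {
        res /= 2; // MaxPool2d [2, 2] with stride [2, 2]
        sizes.push(res);
    }
    assert_eq!(sizes, vec![800, 400, 200, 100, 50]);
    assert_eq!(800 / 16, 50); // overall stride 2^4 = 16

    // 30 output channels = 3 anchors x 10 attrs (x, y, w, h, obj, 5 classes)
    assert_eq!(3 * 10, 30);
    println!("grid: {}x{}, channels: 30", res, res);
}
```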
src/server.rs
CHANGED
|
@@ -19,6 +19,7 @@ use tower_http::cors::CorsLayer;
|
|
| 19 |
struct PredictResult {
|
| 20 |
confidence: f32,
|
| 21 |
keypoints: Vec<f32>,
|
|
|
|
| 22 |
scores: Vec<String>,
|
| 23 |
}
|
| 24 |
|
|
@@ -50,7 +51,7 @@ pub async fn start_gui(device: WgpuDevice) {
|
|
| 50 |
let start_time = std::time::Instant::now();
|
| 51 |
|
| 52 |
let img = image::load_from_memory(&req.image_bytes).unwrap();
|
| 53 |
-
let resized = img.resize_exact(
|
| 54 |
let pixels: Vec<f32> = resized
|
| 55 |
.to_rgb8()
|
| 56 |
.pixels()
|
|
@@ -63,14 +64,17 @@ pub async fn start_gui(device: WgpuDevice) {
|
|
| 63 |
})
|
| 64 |
.collect();
|
| 65 |
|
| 66 |
-
let tensor_data = TensorData::new(pixels, [1,
|
| 67 |
let input =
|
| 68 |
Tensor::<Wgpu, 4>::from_data(tensor_data, &worker_device).permute([0, 3, 1, 2]);
|
| 69 |
|
| 70 |
let (out16, _) = model.forward(input);
|
| 71 |
|
| 72 |
-
//
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
// 1.5 Debug: Raw Statistics
|
| 76 |
println!(
|
|
@@ -80,6 +84,7 @@ pub async fn start_gui(device: WgpuDevice) {
|
|
| 80 |
);
|
| 81 |
|
| 82 |
let mut final_points = vec![0.0f32; 8]; // 4 corners
|
|
|
|
| 83 |
let mut max_conf = 0.0f32;
|
| 84 |
|
| 85 |
// 2. Extract best calibration corner for each class 1 to 4
|
|
@@ -101,14 +106,14 @@ pub async fn start_gui(device: WgpuDevice) {
|
|
| 101 |
);
|
| 102 |
let score = obj.mul(prob);
|
| 103 |
|
| 104 |
-
let (val, idx) = score.reshape([
|
| 105 |
let s = val.to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];
|
| 106 |
if s > best_s {
|
| 107 |
best_s = s;
|
| 108 |
best_anchor = anchor;
|
| 109 |
let f_idx =
|
| 110 |
idx.to_data().convert::<i32>().as_slice::<i32>().unwrap()[0] as usize;
|
| 111 |
-
best_grid = (f_idx %
|
| 112 |
|
| 113 |
let sx = burn::tensor::activation::sigmoid(
|
| 114 |
out_reshaped
|
|
@@ -147,14 +152,16 @@ pub async fn start_gui(device: WgpuDevice) {
|
|
| 147 |
|
| 148 |
// Reconstruct Absolute Normalized Coord (0-1)
|
| 149 |
best_pt = [
|
| 150 |
-
(best_grid.0 as f32 + sx) /
|
| 151 |
-
(best_grid.1 as f32 + sy) /
|
| 152 |
];
|
| 153 |
}
|
| 154 |
}
|
| 155 |
|
| 156 |
final_points[(cls_idx - 1) * 2] = best_pt[0];
|
| 157 |
final_points[(cls_idx - 1) * 2 + 1] = best_pt[1];
|
|
|
|
|
|
|
| 158 |
if best_s > max_conf {
|
| 159 |
max_conf = best_s;
|
| 160 |
}
|
|
@@ -215,14 +222,14 @@ pub async fn start_gui(device: WgpuDevice) {
|
|
| 215 |
let prob = burn::tensor::activation::sigmoid(
|
| 216 |
out_reshaped.clone().narrow(1, anchor, 1).narrow(2, 5, 1),
|
| 217 |
);
|
| 218 |
-
let score = obj.mul(prob).reshape([
|
| 219 |
|
| 220 |
let (val, idx) = score.max_dim_with_indices(1);
|
| 221 |
let s = val.to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];
|
| 222 |
let f_idx = idx.to_data().convert::<i32>().as_slice::<i32>().unwrap()[0] as usize;
|
| 223 |
|
| 224 |
-
let gx = f_idx %
|
| 225 |
-
let gy = f_idx /
|
| 226 |
|
| 227 |
let dsx = burn::tensor::activation::sigmoid(
|
| 228 |
out_reshaped
|
|
@@ -247,8 +254,8 @@ pub async fn start_gui(device: WgpuDevice) {
|
|
| 247 |
.as_slice::<f32>()
|
| 248 |
.unwrap()[0];
|
| 249 |
|
| 250 |
-
let dx = (gx as f32 + dsx) /
|
| 251 |
-
let dy = (gy as f32 + dsy) /
|
| 252 |
|
| 253 |
if s > 0.005 {
|
| 254 |
println!(
|
|
@@ -266,6 +273,7 @@ pub async fn start_gui(device: WgpuDevice) {
|
|
| 266 |
if *s > 0.05 {
|
| 267 |
final_points.push(pt[0]);
|
| 268 |
final_points.push(pt[1]);
|
|
|
|
| 269 |
println!(
|
| 270 |
" ✅ Best Dart Picked: Conf: {:.2}%, Coord: {:?}",
|
| 271 |
s * 100.0,
|
|
@@ -317,6 +325,7 @@ pub async fn start_gui(device: WgpuDevice) {
|
|
| 317 |
let _ = req.response_tx.send(PredictResult {
|
| 318 |
confidence: max_conf,
|
| 319 |
keypoints: final_points,
|
|
|
|
| 320 |
scores: final_scores,
|
| 321 |
});
|
| 322 |
}
|
|
@@ -358,6 +367,7 @@ async fn predict_handler(
|
|
| 358 |
let result = res_rx.await.unwrap_or(PredictResult {
|
| 359 |
confidence: 0.0,
|
| 360 |
keypoints: vec![],
|
|
|
|
| 361 |
scores: vec![],
|
| 362 |
});
|
| 363 |
|
|
@@ -365,7 +375,9 @@ async fn predict_handler(
|
|
| 365 |
"status": "success",
|
| 366 |
"confidence": result.confidence,
|
| 367 |
"keypoints": result.keypoints,
|
|
|
|
| 368 |
"scores": result.scores,
|
|
|
|
| 369 |
"message": if result.confidence > 0.1 {
|
| 370 |
format!("✅ Found {} darts! High confidence: {:.1}%", result.scores.len(), result.confidence * 100.0)
|
| 371 |
} else {
|
|
|
|
| 19 |
struct PredictResult {
|
| 20 |
confidence: f32,
|
| 21 |
keypoints: Vec<f32>,
|
| 22 |
+
confidences: Vec<f32>, // Individual confidence for each point
|
| 23 |
scores: Vec<String>,
|
| 24 |
}
|
| 25 |
|
|
|
|
| 51 |
let start_time = std::time::Instant::now();
|
| 52 |
|
| 53 |
let img = image::load_from_memory(&req.image_bytes).unwrap();
|
| 54 |
+
let resized = img.resize_exact(800, 800, image::imageops::FilterType::Triangle);
|
| 55 |
let pixels: Vec<f32> = resized
|
| 56 |
.to_rgb8()
|
| 57 |
.pixels()
|
|
|
|
| 64 |
})
|
| 65 |
.collect();
|
| 66 |
|
| 67 |
+
let tensor_data = TensorData::new(pixels, [1, 800, 800, 3]);
|
| 68 |
let input =
|
| 69 |
Tensor::<Wgpu, 4>::from_data(tensor_data, &worker_device).permute([0, 3, 1, 2]);
|
| 70 |
|
| 71 |
let (out16, _) = model.forward(input);
|
| 72 |
|
| 73 |
+
// out16 shape: [1, 30, 50, 50] — 800/16 = 50
|
| 74 |
+
// Reshape to separate anchors: [1, 3, 10, 50, 50]
|
| 75 |
+
let out_reshaped = out16.reshape([1, 3, 10, 50, 50]);
|
| 76 |
+
let grid_size: usize = 50;
|
| 77 |
+
let num_cells: usize = grid_size * grid_size; // 2500
|
| 78 |
|
| 79 |
// 1.5 Debug: Raw Statistics
|
| 80 |
println!(
|
|
|
|
| 84 |
);
|
| 85 |
|
| 86 |
let mut final_points = vec![0.0f32; 8]; // 4 corners
|
| 87 |
+
let mut final_confs = vec![0.0f32; 4]; // 4 corner confs
|
| 88 |
let mut max_conf = 0.0f32;
|
| 89 |
|
| 90 |
// 2. Extract best calibration corner for each class 1 to 4
|
|
|
|
| 106 |
);
|
| 107 |
let score = obj.mul(prob);
|
| 108 |
|
| 109 |
+
let (val, idx) = score.reshape([1_usize, num_cells]).max_dim_with_indices(1);
|
| 110 |
let s = val.to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];
|
| 111 |
if s > best_s {
|
| 112 |
best_s = s;
|
| 113 |
best_anchor = anchor;
|
| 114 |
let f_idx =
|
| 115 |
idx.to_data().convert::<i32>().as_slice::<i32>().unwrap()[0] as usize;
|
| 116 |
+
best_grid = (f_idx % grid_size, f_idx / grid_size);
|
| 117 |
|
| 118 |
let sx = burn::tensor::activation::sigmoid(
|
| 119 |
out_reshaped
|
|
|
|
| 152 |
|
| 153 |
// Reconstruct Absolute Normalized Coord (0-1)
|
| 154 |
best_pt = [
|
| 155 |
+
(best_grid.0 as f32 + sx) / grid_size as f32,
|
| 156 |
+
(best_grid.1 as f32 + sy) / grid_size as f32,
|
| 157 |
];
|
| 158 |
}
|
| 159 |
}
|
| 160 |
|
| 161 |
final_points[(cls_idx - 1) * 2] = best_pt[0];
|
| 162 |
final_points[(cls_idx - 1) * 2 + 1] = best_pt[1];
|
| 163 |
+
final_confs[cls_idx - 1] = best_s;
|
| 164 |
+
|
| 165 |
if best_s > max_conf {
|
| 166 |
max_conf = best_s;
|
 }
...
             let prob = burn::tensor::activation::sigmoid(
                 out_reshaped.clone().narrow(1, anchor, 1).narrow(2, 5, 1),
             );
+            let score = obj.mul(prob).reshape([1_usize, num_cells]);

             let (val, idx) = score.max_dim_with_indices(1);
             let s = val.to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];
             let f_idx = idx.to_data().convert::<i32>().as_slice::<i32>().unwrap()[0] as usize;

+            let gx = f_idx % grid_size;
+            let gy = f_idx / grid_size;

             let dsx = burn::tensor::activation::sigmoid(
                 out_reshaped
...
                 .as_slice::<f32>()
                 .unwrap()[0];

+            let dx = (gx as f32 + dsx) / grid_size as f32;
+            let dy = (gy as f32 + dsy) / grid_size as f32;

             if s > 0.005 {
                 println!(
...
             if *s > 0.05 {
                 final_points.push(pt[0]);
                 final_points.push(pt[1]);
+                final_confs.push(*s);
                 println!(
                     "  ✅ Best Dart Picked: Conf: {:.2}%, Coord: {:?}",
                     s * 100.0,
...
         let _ = req.response_tx.send(PredictResult {
             confidence: max_conf,
             keypoints: final_points,
+            confidences: final_confs,
             scores: final_scores,
         });
     }
...
     let result = res_rx.await.unwrap_or(PredictResult {
         confidence: 0.0,
         keypoints: vec![],
+        confidences: vec![],
         scores: vec![],
     });
...
         "status": "success",
         "confidence": result.confidence,
         "keypoints": result.keypoints,
+        "confidences": result.confidences,
         "scores": result.scores,
+        "is_calibrated": result.confidences.iter().take(4).all(|&c| c > 0.05),
         "message": if result.confidence > 0.1 {
             format!("✅ Found {} darts! High confidence: {:.1}%", result.scores.len(), result.confidence * 100.0)
         } else {
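The new decode step above maps the argmax over the objectness grid back to normalized board coordinates: column is `f_idx % grid_size`, row is `f_idx / grid_size`, and the sigmoid sub-cell offsets refine the position inside the winning cell. A standalone sketch of that arithmetic, assuming the 50×50 grid implied by the 2500-cell reshape in src/tests.rs (plain Rust, no Burn dependency; `decode_cell` is an illustrative name, not part of the repo):

```rust
// Hypothetical helper mirroring the server's decode step: map a flattened
// grid index plus sigmoid sub-cell offsets to normalized [0, 1] coordinates.
fn decode_cell(f_idx: usize, grid_size: usize, dsx: f32, dsy: f32) -> (f32, f32) {
    let gx = f_idx % grid_size; // column of the winning cell
    let gy = f_idx / grid_size; // row of the winning cell
    let dx = (gx as f32 + dsx) / grid_size as f32;
    let dy = (gy as f32 + dsy) / grid_size as f32;
    (dx, dy)
}

fn main() {
    // Cell 1275 on a 50x50 grid sits at row 25, column 25 (near board center).
    let (dx, dy) = decode_cell(1275, 50, 0.5, 0.5);
    println!("dx = {:.3}, dy = {:.3}", dx, dy); // dx = 0.510, dy = 0.510
}
```

Dividing by `grid_size` keeps the output resolution-independent, which is why the GUI can scale keypoints by the rendered image size.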
src/tests.rs
CHANGED
@@ -20,9 +20,9 @@ pub fn test_model(device: WgpuDevice, img_path: &str) {

     let img = image::open(img_path).unwrap_or_else(|_| {
         println!("❌ Image not found at {}. Using random tensor.", img_path);
-        image::DynamicImage::new_rgb8(
+        image::DynamicImage::new_rgb8(800, 800)
     });
-    let resized = img.resize_exact(
+    let resized = img.resize_exact(800, 800, image::imageops::FilterType::Triangle);
     let pixels: Vec<f32> = resized
         .to_rgb8()
         .pixels()
@@ -35,12 +35,12 @@ pub fn test_model(device: WgpuDevice, img_path: &str) {
         })
         .collect();

-    let tensor_data = TensorData::new(pixels, [1,
+    let tensor_data = TensorData::new(pixels, [1, 800, 800, 3]);
     let input = Tensor::<Wgpu, 4>::from_data(tensor_data, &device).permute([0, 3, 1, 2]);
     let (out, _): (Tensor<Wgpu, 4>, _) = model.forward(input);

     let obj = burn::tensor::activation::sigmoid(out.clone().narrow(1, 4, 1));
-    let (max_val, _) = obj.reshape([
+    let (max_val, _) = obj.reshape([1_usize, 2500]).max_dim_with_indices(1);

     let score = max_val
         .to_data()
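The `permute([0, 3, 1, 2])` above reorders the `[1, H, W, C]` pixel tensor into the `[1, C, H, W]` layout convolution layers expect. The equivalent index shuffle on a flat buffer, as a self-contained sketch (the `hwc_to_chw` helper is illustrative, not part of the repo):

```rust
// Illustrative reimplementation of permute([0, 3, 1, 2]) on a flat buffer:
// HWC pixel order (interleaved RGB) -> CHW planes (all R, then G, then B).
fn hwc_to_chw(pixels: &[f32], h: usize, w: usize, c: usize) -> Vec<f32> {
    let mut out = vec![0.0; pixels.len()];
    for y in 0..h {
        for x in 0..w {
            for ch in 0..c {
                // Source index walks channels fastest; destination walks pixels fastest.
                out[ch * h * w + y * w + x] = pixels[(y * w + x) * c + ch];
            }
        }
    }
    out
}

fn main() {
    // A 2x2 RGB image: values 0..11 in HWC order.
    let hwc: Vec<f32> = (0..12).map(|v| v as f32).collect();
    let chw = hwc_to_chw(&hwc, 2, 2, 3);
    // The red plane is every third source value: 0, 3, 6, 9.
    println!("{:?}", &chw[0..4]); // [0.0, 3.0, 6.0, 9.0]
}
```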
src/train.rs
CHANGED
@@ -66,13 +66,14 @@ pub fn train<B: AutodiffBackend>(device: Device<B>, dataset_path: &str, config:
     let loss = diou_loss(out16, batch.targets);
     batch_count += 1;

-    // Print every
+    // Print every 20 batches — use detach() to avoid cloning the full autodiff graph
     if batch_count % 20 == 0 || batch_count == 1 {
+        let loss_val = loss.clone().detach().into_scalar();
         println!(
             "  [Epoch {}] Batch {: >3} | Loss: {:.6}",
             epoch,
             batch_count,
-
+            loss_val
         );
     }
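The logging guard in this hunk fires on the very first batch and then on every 20th, while `detach()` keeps the print from retaining the autodiff graph. The cadence itself is easy to verify in isolation (plain Rust sketch; `log_batches` is an illustrative name, not repo code):

```rust
// Which batch numbers would the training loop's guard print?
fn log_batches(total: usize) -> Vec<usize> {
    (1..=total).filter(|&b| b % 20 == 0 || b == 1).collect()
}

fn main() {
    // First batch for immediate feedback, then every 20th thereafter.
    println!("{:?}", log_batches(45)); // [1, 20, 40]
}
```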
static/index.html
CHANGED
@@ -3,7 +3,7 @@
 <head>
     <meta charset="UTF-8">
     <meta name="viewport" content="width=device-width, initial-scale=1.0">
-    <title>
+    <title>Rust AutoScore Engine - Smart Dashboard</title>
     <link rel="preconnect" href="https://fonts.googleapis.com">
     <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
     <link href="https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;600;800&display=swap" rel="stylesheet">
@@ -235,8 +235,8 @@
     }

     @keyframes pulse-marker {
-        0%, 100% { r:
-        50% { r:
+        0%, 100% { r: 8; opacity: 1; }
+        50% { r: 8; opacity: 0.8; }
     }

     .result-item {
@@ -269,7 +269,7 @@
 <body>
     <div class="bg-grid"></div>
     <header>
-        <h1>
+        <h1>RUST AUTO SCORE <span style="font-weight: 200; opacity: 0.4;">ENGINE</span></h1>
         <p class="subtitle">Neural Scoring & Board Analytics</p>
     </header>

@@ -296,6 +296,10 @@
             <div class="stat-label">Model Status</div>
             <div class="stat-value" id="status-text" style="font-size: 1.4rem; color: var(--primary);">System Ready</div>
         </div>
+        <div class="stat-card" id="cal-card">
+            <div class="stat-label">Calibration Sync</div>
+            <div class="stat-value" id="cal-status" style="font-size: 1.2rem; color: #8892b0;">Pending...</div>
+        </div>
         <div class="stat-card">
             <div class="stat-label">AI Confidence</div>
             <div class="stat-value" id="conf-val">0.0%</div>
@@ -358,7 +362,7 @@
             document.getElementById('status-text').innerText = 'Analysis Complete';
             document.getElementById('status-text').style.color = 'var(--primary)';
             updateUI(data);
-            drawKeypoints(data.keypoints);
+            drawKeypoints(data.keypoints, data.confidences);
         } else {
             document.getElementById('status-text').innerText = 'Analysis Failed';
             document.getElementById('status-text').style.color = 'var(--accent)';
@@ -375,6 +379,15 @@
         const conf = (data.confidence * 100).toFixed(1);
         document.getElementById('conf-val').innerText = `${conf}%`;
         document.getElementById('conf-fill').style.width = `${conf}%`;
+
+        const calStatus = document.getElementById('cal-status');
+        if (data.is_calibrated) {
+            calStatus.innerText = "VERIFIED ✅";
+            calStatus.style.color = "var(--primary)";
+        } else {
+            calStatus.innerText = "FAILED ❌";
+            calStatus.style.color = "var(--accent)";
+        }

         let resultHtml = `<div style="margin: 1.5rem 0 1rem 0; font-size: 0.95rem; line-height: 1.6; color: rgba(255,255,255,0.9);">${data.message}</div>`;
         if (data.keypoints && data.keypoints.length >= 8) {
@@ -387,12 +400,17 @@
         const name = names[classIdx] || `Dart ${Math.floor(classIdx - 3)}`;
         const x = data.keypoints[i].toFixed(3);
         const y = data.keypoints[i+1].toFixed(3);
+        const ptConf = ((data.confidences[classIdx] || 0) * 100).toFixed(0);
+        const isReliable = ptConf > 10;

         let scoreHtml = "";
         if (classIdx >= 4 && data.scores && data.scores[classIdx - 4]) {
-            scoreHtml = `<
+            scoreHtml = `<div style="display: flex; flex-direction: column; align-items: flex-end; gap: 4px;">
+                <span class="badge badge-dart">${data.scores[classIdx - 4]}</span>
+                <span style="font-size: 0.65rem; color: #8892b0; font-weight: 600;">CONF: ${ptConf}%</span>
+            </div>`;
         } else if (isCal) {
-            scoreHtml = `<span class="badge
+            scoreHtml = `<span class="badge" style="background: ${isReliable ? 'rgba(0,255,136,0.1)' : 'rgba(255,77,77,0.1)'}; color: ${isReliable ? 'var(--primary)' : 'var(--accent)'}; font-size: 0.6rem; border: 1px solid">${isReliable ? ptConf+'% OK' : ptConf+'% ERR'}</span>`;
         }

         resultHtml += `
@@ -412,7 +430,7 @@
         while (svgOverlay.firstChild) svgOverlay.removeChild(svgOverlay.firstChild);
     }

-    function drawKeypoints(pts) {
+    function drawKeypoints(pts, confs) {
         clearKeypoints();
         if (!pts || pts.length === 0) return;
@@ -425,20 +443,21 @@
         const width = rect.width;
         const height = rect.height;

-        const classNames = ["CALIBRATION CORNER 1", "CALIBRATION CORNER 2", "CALIBRATION CORNER 3", "CALIBRATION CORNER 4", "DART POINT"];
         for (let i = 0; i < pts.length; i += 2) {
             const classIdx = i / 2;
             const isCal = classIdx < 4;
             const x = pts[i] * width + offsetX;
             const y = pts[i+1] * height + offsetY;
-
+
+            const ptConf = ((confs[classIdx] || 0) * 100).toFixed(0);
+            const name = isCal ? `CAL ${classIdx+1} (${ptConf}%)` : `DART (${ptConf}%)`;

             const group = document.createElementNS("http://www.w3.org/2000/svg", "g");

             const circle = document.createElementNS("http://www.w3.org/2000/svg", "circle");
             circle.setAttribute("cx", x);
             circle.setAttribute("cy", y);
-            circle.setAttribute("r",
+            circle.setAttribute("r", 8);
             circle.setAttribute("class", "keypoint-marker");
             if (!isCal) {
                 circle.style.fill = "#ff4d4d";
@@ -448,7 +467,7 @@
             const labelBg = document.createElementNS("http://www.w3.org/2000/svg", "rect");
             labelBg.setAttribute("x", x + 15);
             labelBg.setAttribute("y", y - 25);
-            labelBg.setAttribute("width", name.length * 7 +
+            labelBg.setAttribute("width", name.length * 7 + 15);
             labelBg.setAttribute("height", "22");
             labelBg.setAttribute("rx", "11");
             labelBg.setAttribute("fill", "rgba(0,0,0,0.7)");
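The dashboard's calibration badge is driven by the server-side flag `result.confidences.iter().take(4).all(|&c| c > 0.05)`, which treats the first four keypoints as the board corners. A standalone sketch of that predicate, with one extra length guard the server expression does not have, since `take(4)` over fewer than four confidences would pass vacuously (`is_calibrated` here is an illustrative free function, not the repo's code):

```rust
// Calibration succeeds only when all four corner confidences clear 0.05.
// Unlike the server's inline expression, this also rejects short lists.
fn is_calibrated(confidences: &[f32]) -> bool {
    confidences.len() >= 4 && confidences.iter().take(4).all(|&c| c > 0.05)
}

fn main() {
    println!("{}", is_calibrated(&[0.9, 0.8, 0.3, 0.06, 0.01])); // true
    println!("{}", is_calibrated(&[0.9, 0.8, 0.3, 0.04]));       // false
    println!("{}", is_calibrated(&[0.9, 0.8]));                  // false
}
```

Only the first four entries matter, so low-confidence dart detections appended after the corners do not flip the flag.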