use crate::model::DartVisionModel; use axum::{ extract::{DefaultBodyLimit, Multipart, State}, response::{Html, Json}, routing::{get, post}, Router, }; use burn::backend::wgpu::WgpuDevice; use burn::backend::Wgpu; use burn::prelude::*; use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder}; use serde_json::json; use std::net::SocketAddr; use std::sync::Arc; use tokio::sync::{mpsc, oneshot}; use tower_http::cors::CorsLayer; #[derive(Debug)] struct PredictResult { confidence: f32, keypoints: Vec, confidences: Vec, // Individual confidence for each point scores: Vec, } struct PredictRequest { image_bytes: Vec, response_tx: oneshot::Sender, } pub async fn start_gui(device: WgpuDevice) { let port = std::env::var("PORT") .unwrap_or_else(|_| "8080".to_string()) .parse::() .unwrap_or(8080); let addr = SocketAddr::from(([0, 0, 0, 0], port)); println!("🚀 [DartVision-GUI] Starting on http://0.0.0.0:{}", port); let (tx, mut rx) = mpsc::channel::(10); let worker_device = device.clone(); std::thread::spawn(move || { let recorder = BinFileRecorder::::default(); let model = DartVisionModel::::new(&worker_device); let record = match recorder.load("model_weights".into(), &worker_device) { Ok(r) => r, Err(_) => { println!("⚠️ [DartVision] No 'model_weights.bin' yet. Using initial weights..."); model.clone().into_record() } }; let model = model.load_record(record); while let Some(req) = rx.blocking_recv() { let start_time = std::time::Instant::now(); let img = image::load_from_memory(&req.image_bytes).unwrap(); let resized = img.resize_exact(800, 800, image::imageops::FilterType::Triangle); let pixels: Vec = resized .to_rgb8() .pixels() .flat_map(|p| { vec![ p[0] as f32 / 255.0, p[1] as f32 / 255.0, p[2] as f32 / 255.0, ] }) .collect(); let tensor_data = TensorData::new(pixels, [1, 800, 800, 3]); let input = Tensor::::from_data(tensor_data, &worker_device).permute([0, 3, 1, 2]); let (out16, _) = model.forward(input); // out16 shape: [1, 30, 50, 50] — 800/16 = 50 // Reshape to separate anchors: [1, 3, 10, 50, 50] let out_reshaped = out16.reshape([1, 3, 10, 50, 50]); let grid_size: usize = 50; let num_cells: usize = grid_size * grid_size; // 2500 // 1.5 Debug: Raw Statistics println!( "🔍 [Model Stats] Raw Min: {:.4}, Max: {:.4}", out_reshaped.clone().min().into_scalar(), out_reshaped.clone().max().into_scalar() ); let mut final_points = vec![0.0f32; 8]; // 4 corners let mut final_confs = vec![0.0f32; 4]; // 4 corner confs let mut max_conf = 0.0f32; // 2. Extract best calibration corner for each class 1 to 4 for cls_idx in 1..=4 { let mut best_s = -1.0f32; let mut best_pt = [0.0f32; 2]; let mut best_anchor = 0; let mut best_grid = (0, 0); for anchor in 0..3 { let obj = burn::tensor::activation::sigmoid( out_reshaped.clone().narrow(1, anchor, 1).narrow(2, 4, 1), ); let prob = burn::tensor::activation::sigmoid( out_reshaped .clone() .narrow(1, anchor, 1) .narrow(2, 5 + cls_idx, 1), ); let score = obj.mul(prob); let (val, idx) = score.reshape([1_usize, num_cells]).max_dim_with_indices(1); let s = val.to_data().convert::().as_slice::().unwrap()[0]; if s > best_s { best_s = s; best_anchor = anchor; let f_idx = idx.to_data().convert::().as_slice::().unwrap()[0] as usize; best_grid = (f_idx % grid_size, f_idx / grid_size); let sx = burn::tensor::activation::sigmoid( out_reshaped .clone() .narrow(1, anchor, 1) .narrow(2, 0, 1) .slice([ 0..1, 0..1, 0..1, best_grid.1..best_grid.1 + 1, best_grid.0..best_grid.0 + 1, ]), ) .to_data() .convert::() .as_slice::() .unwrap()[0]; let sy = burn::tensor::activation::sigmoid( out_reshaped .clone() .narrow(1, anchor, 1) .narrow(2, 1, 1) .slice([ 0..1, 0..1, 0..1, best_grid.1..best_grid.1 + 1, best_grid.0..best_grid.0 + 1, ]), ) .to_data() .convert::() .as_slice::() .unwrap()[0]; // Reconstruct Absolute Normalized Coord (0-1) best_pt = [ (best_grid.0 as f32 + sx) / grid_size as f32, (best_grid.1 as f32 + sy) / grid_size as f32, ]; } } final_points[(cls_idx - 1) * 2] = best_pt[0]; final_points[(cls_idx - 1) * 2 + 1] = best_pt[1]; final_confs[cls_idx - 1] = best_s; if best_s > max_conf { max_conf = best_s; } println!( " [Debug Cal{}] Anchor: {}, Conf: {:.4}, Cell: {:?}, Coord: [{:.3}, {:.3}]", cls_idx, best_anchor, best_s, best_grid, best_pt[0], best_pt[1] ); } // 3. Calibration Estimation (Python logic: est_cal_pts) // If one calibration point is missing, estimate it using symmetry let mut valid_cal_count = 0; let mut missing_idx = -1; for i in 0..4 { if final_points[i*2] > 0.01 || final_points[i*2+1] > 0.01 { valid_cal_count += 1; } else { missing_idx = i as i32; } } if valid_cal_count == 3 { println!("⚠️ [Calibration Recovery] Estimating missing point Cal{}...", missing_idx + 1); match missing_idx { 0 | 1 => { // Top points missing, use bottom points center let cx = (final_points[4] + final_points[6]) / 2.0; let cy = (final_points[5] + final_points[7]) / 2.0; if missing_idx == 0 { final_points[0] = 2.0 * cx - final_points[2]; final_points[1] = 2.0 * cy - final_points[3]; } else { final_points[2] = 2.0 * cx - final_points[0]; final_points[3] = 2.0 * cy - final_points[1]; } }, 2 | 3 => { // Bottom points missing, use top points center let cx = (final_points[0] + final_points[2]) / 2.0; let cy = (final_points[1] + final_points[3]) / 2.0; if missing_idx == 2 { final_points[4] = 2.0 * cx - final_points[6]; final_points[5] = 2.0 * cy - final_points[7]; } else { final_points[6] = 2.0 * cx - final_points[4]; final_points[7] = 2.0 * cy - final_points[5]; } }, _ => {} } } // 4. Extract best dart (Class 0) - Find candidates across all anchors println!(" [Debug Dart] Searching for Candidates..."); let mut dart_candidates = vec![]; for anchor in 0..3 { let obj = burn::tensor::activation::sigmoid( out_reshaped.clone().narrow(1, anchor, 1).narrow(2, 4, 1), ); let prob = burn::tensor::activation::sigmoid( out_reshaped.clone().narrow(1, anchor, 1).narrow(2, 5, 1), ); let score = obj.mul(prob).reshape([1_usize, num_cells]); let (val, idx) = score.max_dim_with_indices(1); let s = val.to_data().convert::().as_slice::().unwrap()[0]; let f_idx = idx.to_data().convert::().as_slice::().unwrap()[0] as usize; let gx = f_idx % grid_size; let gy = f_idx / grid_size; let dsx = burn::tensor::activation::sigmoid( out_reshaped .clone() .narrow(1, anchor, 1) .narrow(2, 0, 1) .slice([0..1, 0..1, 0..1, gy..gy + 1, gx..gx + 1]), ) .to_data() .convert::() .as_slice::() .unwrap()[0]; let dsy = burn::tensor::activation::sigmoid( out_reshaped .clone() .narrow(1, anchor, 1) .narrow(2, 1, 1) .slice([0..1, 0..1, 0..1, gy..gy + 1, gx..gx + 1]), ) .to_data() .convert::() .as_slice::() .unwrap()[0]; let dx = (gx as f32 + dsx) / grid_size as f32; let dy = (gy as f32 + dsy) / grid_size as f32; if s > 0.005 { println!( " - A{} Best Cell: ({},{}), Conf: {:.4}, Coord: [{:.3}, {:.3}]", anchor, gx, gy, s, dx, dy ); dart_candidates.push((s, [dx, dy])); } } // Pick the best dart candidate among all anchors dart_candidates .sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)); if let Some((s, pt)) = dart_candidates.first() { if *s > 0.05 { final_points.push(pt[0]); final_points.push(pt[1]); final_confs.push(*s); println!( " ✅ Best Dart Picked: Conf: {:.2}%, Coord: {:?}", s * 100.0, pt ); } } let mut final_scores = vec![]; // Calculate scores if we have calibration points and at least one dart if final_points.len() >= 10 { use crate::scoring::{calculate_dart_score, ScoringConfig}; let config = ScoringConfig::default(); let cal_pts = [ [final_points[0], final_points[1]], [final_points[2], final_points[3]], [final_points[4], final_points[5]], [final_points[6], final_points[7]], ]; for dart_chunk in final_points[8..].chunks(2) { if dart_chunk.len() == 2 { let dart_pt = [dart_chunk[0], dart_chunk[1]]; let (s_val, label) = calculate_dart_score(&cal_pts, &dart_pt, &config); final_scores.push(label.clone()); println!(" [Debug Score] Label: {} (Val: {})", label, s_val); } } } let duration = start_time.elapsed(); println!("⚡ [Inference Performance] Total Latency: {:.2?}", duration); println!("🎯 [Final Result] Top Confidence: {:.2}%", max_conf * 100.0); let class_names = ["Cal1", "Cal2", "Cal3", "Cal4", "Dart"]; for (i, pts) in final_points.chunks(2).enumerate() { let name = class_names.get(i).unwrap_or(&"Dart"); let label = final_scores .get(i.saturating_sub(4)) .cloned() .unwrap_or_default(); println!( " - {}: [x: {:.3}, y: {:.3}] {}", name, pts[0], pts[1], label ); } let _ = req.response_tx.send(PredictResult { confidence: max_conf, keypoints: final_points, confidences: final_confs, scores: final_scores, }); } }); let state = Arc::new(tx); let app = Router::new() .route( "/", get(|| async { Html(include_str!("../static/index.html")) }), ) .route("/api/predict", post(predict_handler)) .with_state(state) .layer(DefaultBodyLimit::max(10 * 1024 * 1024)) .layer(CorsLayer::permissive()); let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); axum::serve(listener, app).await.unwrap(); } async fn predict_handler( State(tx): State>>, mut multipart: Multipart, ) -> Json { while let Ok(Some(field)) = multipart.next_field().await { if field.name() == Some("image") { let bytes = match field.bytes().await { Ok(b) => b.to_vec(), Err(_) => continue, }; let (res_tx, res_rx) = oneshot::channel(); let _ = tx .send(PredictRequest { image_bytes: bytes, response_tx: res_tx, }) .await; let result = res_rx.await.unwrap_or(PredictResult { confidence: 0.0, keypoints: vec![], confidences: vec![], scores: vec![], }); return Json(json!({ "status": "success", "confidence": result.confidence, "keypoints": result.keypoints, "confidences": result.confidences, "scores": result.scores, "is_calibrated": result.confidences.iter().take(4).all(|&c| c > 0.05), "message": if result.confidence > 0.1 { format!("✅ Found {} darts! High confidence: {:.1}%", result.scores.len(), result.confidence * 100.0) } else { "⚠️ Low confidence detection - no dart score could be verified.".to_string() } })); } } Json(json!({"status": "error", "message": "No image field found"})) }