// RustAutoScoreEngine / src/server.rs
// kapil
// Add Dockerfile and HF Space config
// 6b285c0
use crate::model::DartVisionModel;
use axum::{
extract::{DefaultBodyLimit, Multipart, State},
response::{Html, Json},
routing::{get, post},
Router,
};
use burn::backend::wgpu::WgpuDevice;
use burn::backend::Wgpu;
use burn::prelude::*;
use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
use serde_json::json;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot};
use tower_http::cors::CorsLayer;
/// Outcome of one inference pass, sent from the worker thread back to the HTTP handler.
#[derive(Debug)]
struct PredictResult {
    /// Highest score among the four calibration corners (0.0 when nothing was found).
    confidence: f32,
    /// Flat `[x0, y0, x1, y1, ...]` list of normalized (0-1) coordinates: the four
    /// calibration corners first, followed by the dart when one was accepted.
    keypoints: Vec<f32>,
    /// One score per keypoint, in the same order as `keypoints`.
    confidences: Vec<f32>,
    /// Score label for each detected dart.
    scores: Vec<String>,
}
/// A unit of work queued for the inference worker thread.
struct PredictRequest {
    /// Raw, still-encoded bytes of the uploaded image.
    image_bytes: Vec<u8>,
    /// One-shot channel on which the worker delivers the result for this request.
    response_tx: oneshot::Sender<PredictResult>,
}
pub async fn start_gui(device: WgpuDevice) {
let port = std::env::var("PORT")
.unwrap_or_else(|_| "8080".to_string())
.parse::<u16>()
.unwrap_or(8080);
let addr = SocketAddr::from(([0, 0, 0, 0], port));
println!("🚀 [DartVision-GUI] Starting on http://0.0.0.0:{}", port);
let (tx, mut rx) = mpsc::channel::<PredictRequest>(10);
let worker_device = device.clone();
std::thread::spawn(move || {
let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
let model = DartVisionModel::<Wgpu>::new(&worker_device);
let record = match recorder.load("model_weights".into(), &worker_device) {
Ok(r) => r,
Err(_) => {
println!("⚠️ [DartVision] No 'model_weights.bin' yet. Using initial weights...");
model.clone().into_record()
}
};
let model = model.load_record(record);
while let Some(req) = rx.blocking_recv() {
let start_time = std::time::Instant::now();
let img = image::load_from_memory(&req.image_bytes).unwrap();
let resized = img.resize_exact(800, 800, image::imageops::FilterType::Triangle);
let pixels: Vec<f32> = resized
.to_rgb8()
.pixels()
.flat_map(|p| {
vec![
p[0] as f32 / 255.0,
p[1] as f32 / 255.0,
p[2] as f32 / 255.0,
]
})
.collect();
let tensor_data = TensorData::new(pixels, [1, 800, 800, 3]);
let input =
Tensor::<Wgpu, 4>::from_data(tensor_data, &worker_device).permute([0, 3, 1, 2]);
let (out16, _) = model.forward(input);
// out16 shape: [1, 30, 50, 50] — 800/16 = 50
// Reshape to separate anchors: [1, 3, 10, 50, 50]
let out_reshaped = out16.reshape([1, 3, 10, 50, 50]);
let grid_size: usize = 50;
let num_cells: usize = grid_size * grid_size; // 2500
// 1.5 Debug: Raw Statistics
println!(
"🔍 [Model Stats] Raw Min: {:.4}, Max: {:.4}",
out_reshaped.clone().min().into_scalar(),
out_reshaped.clone().max().into_scalar()
);
let mut final_points = vec![0.0f32; 8]; // 4 corners
let mut final_confs = vec![0.0f32; 4]; // 4 corner confs
let mut max_conf = 0.0f32;
// 2. Extract best calibration corner for each class 1 to 4
for cls_idx in 1..=4 {
let mut best_s = -1.0f32;
let mut best_pt = [0.0f32; 2];
let mut best_anchor = 0;
let mut best_grid = (0, 0);
for anchor in 0..3 {
let obj = burn::tensor::activation::sigmoid(
out_reshaped.clone().narrow(1, anchor, 1).narrow(2, 4, 1),
);
let prob = burn::tensor::activation::sigmoid(
out_reshaped
.clone()
.narrow(1, anchor, 1)
.narrow(2, 5 + cls_idx, 1),
);
let score = obj.mul(prob);
let (val, idx) = score.reshape([1_usize, num_cells]).max_dim_with_indices(1);
let s = val.to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];
if s > best_s {
best_s = s;
best_anchor = anchor;
let f_idx =
idx.to_data().convert::<i32>().as_slice::<i32>().unwrap()[0] as usize;
best_grid = (f_idx % grid_size, f_idx / grid_size);
let sx = burn::tensor::activation::sigmoid(
out_reshaped
.clone()
.narrow(1, anchor, 1)
.narrow(2, 0, 1)
.slice([
0..1,
0..1,
0..1,
best_grid.1..best_grid.1 + 1,
best_grid.0..best_grid.0 + 1,
]),
)
.to_data()
.convert::<f32>()
.as_slice::<f32>()
.unwrap()[0];
let sy = burn::tensor::activation::sigmoid(
out_reshaped
.clone()
.narrow(1, anchor, 1)
.narrow(2, 1, 1)
.slice([
0..1,
0..1,
0..1,
best_grid.1..best_grid.1 + 1,
best_grid.0..best_grid.0 + 1,
]),
)
.to_data()
.convert::<f32>()
.as_slice::<f32>()
.unwrap()[0];
// Reconstruct Absolute Normalized Coord (0-1)
best_pt = [
(best_grid.0 as f32 + sx) / grid_size as f32,
(best_grid.1 as f32 + sy) / grid_size as f32,
];
}
}
final_points[(cls_idx - 1) * 2] = best_pt[0];
final_points[(cls_idx - 1) * 2 + 1] = best_pt[1];
final_confs[cls_idx - 1] = best_s;
if best_s > max_conf {
max_conf = best_s;
}
println!(
" [Debug Cal{}] Anchor: {}, Conf: {:.4}, Cell: {:?}, Coord: [{:.3}, {:.3}]",
cls_idx, best_anchor, best_s, best_grid, best_pt[0], best_pt[1]
);
}
// 3. Calibration Estimation (Python logic: est_cal_pts)
// If one calibration point is missing, estimate it using symmetry
let mut valid_cal_count = 0;
let mut missing_idx = -1;
for i in 0..4 {
if final_points[i*2] > 0.01 || final_points[i*2+1] > 0.01 {
valid_cal_count += 1;
} else {
missing_idx = i as i32;
}
}
if valid_cal_count == 3 {
println!("⚠️ [Calibration Recovery] Estimating missing point Cal{}...", missing_idx + 1);
match missing_idx {
0 | 1 => { // Top points missing, use bottom points center
let cx = (final_points[4] + final_points[6]) / 2.0;
let cy = (final_points[5] + final_points[7]) / 2.0;
if missing_idx == 0 {
final_points[0] = 2.0 * cx - final_points[2];
final_points[1] = 2.0 * cy - final_points[3];
} else {
final_points[2] = 2.0 * cx - final_points[0];
final_points[3] = 2.0 * cy - final_points[1];
}
},
2 | 3 => { // Bottom points missing, use top points center
let cx = (final_points[0] + final_points[2]) / 2.0;
let cy = (final_points[1] + final_points[3]) / 2.0;
if missing_idx == 2 {
final_points[4] = 2.0 * cx - final_points[6];
final_points[5] = 2.0 * cy - final_points[7];
} else {
final_points[6] = 2.0 * cx - final_points[4];
final_points[7] = 2.0 * cy - final_points[5];
}
},
_ => {}
}
}
// 4. Extract best dart (Class 0) - Find candidates across all anchors
println!(" [Debug Dart] Searching for Candidates...");
let mut dart_candidates = vec![];
for anchor in 0..3 {
let obj = burn::tensor::activation::sigmoid(
out_reshaped.clone().narrow(1, anchor, 1).narrow(2, 4, 1),
);
let prob = burn::tensor::activation::sigmoid(
out_reshaped.clone().narrow(1, anchor, 1).narrow(2, 5, 1),
);
let score = obj.mul(prob).reshape([1_usize, num_cells]);
let (val, idx) = score.max_dim_with_indices(1);
let s = val.to_data().convert::<f32>().as_slice::<f32>().unwrap()[0];
let f_idx = idx.to_data().convert::<i32>().as_slice::<i32>().unwrap()[0] as usize;
let gx = f_idx % grid_size;
let gy = f_idx / grid_size;
let dsx = burn::tensor::activation::sigmoid(
out_reshaped
.clone()
.narrow(1, anchor, 1)
.narrow(2, 0, 1)
.slice([0..1, 0..1, 0..1, gy..gy + 1, gx..gx + 1]),
)
.to_data()
.convert::<f32>()
.as_slice::<f32>()
.unwrap()[0];
let dsy = burn::tensor::activation::sigmoid(
out_reshaped
.clone()
.narrow(1, anchor, 1)
.narrow(2, 1, 1)
.slice([0..1, 0..1, 0..1, gy..gy + 1, gx..gx + 1]),
)
.to_data()
.convert::<f32>()
.as_slice::<f32>()
.unwrap()[0];
let dx = (gx as f32 + dsx) / grid_size as f32;
let dy = (gy as f32 + dsy) / grid_size as f32;
if s > 0.005 {
println!(
" - A{} Best Cell: ({},{}), Conf: {:.4}, Coord: [{:.3}, {:.3}]",
anchor, gx, gy, s, dx, dy
);
dart_candidates.push((s, [dx, dy]));
}
}
// Pick the best dart candidate among all anchors
dart_candidates
.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
if let Some((s, pt)) = dart_candidates.first() {
if *s > 0.05 {
final_points.push(pt[0]);
final_points.push(pt[1]);
final_confs.push(*s);
println!(
" ✅ Best Dart Picked: Conf: {:.2}%, Coord: {:?}",
s * 100.0,
pt
);
}
}
let mut final_scores = vec![];
// Calculate scores if we have calibration points and at least one dart
if final_points.len() >= 10 {
use crate::scoring::{calculate_dart_score, ScoringConfig};
let config = ScoringConfig::default();
let cal_pts = [
[final_points[0], final_points[1]],
[final_points[2], final_points[3]],
[final_points[4], final_points[5]],
[final_points[6], final_points[7]],
];
for dart_chunk in final_points[8..].chunks(2) {
if dart_chunk.len() == 2 {
let dart_pt = [dart_chunk[0], dart_chunk[1]];
let (s_val, label) = calculate_dart_score(&cal_pts, &dart_pt, &config);
final_scores.push(label.clone());
println!(" [Debug Score] Label: {} (Val: {})", label, s_val);
}
}
}
let duration = start_time.elapsed();
println!("⚡ [Inference Performance] Total Latency: {:.2?}", duration);
println!("🎯 [Final Result] Top Confidence: {:.2}%", max_conf * 100.0);
let class_names = ["Cal1", "Cal2", "Cal3", "Cal4", "Dart"];
for (i, pts) in final_points.chunks(2).enumerate() {
let name = class_names.get(i).unwrap_or(&"Dart");
let label = final_scores
.get(i.saturating_sub(4))
.cloned()
.unwrap_or_default();
println!(
" - {}: [x: {:.3}, y: {:.3}] {}",
name, pts[0], pts[1], label
);
}
let _ = req.response_tx.send(PredictResult {
confidence: max_conf,
keypoints: final_points,
confidences: final_confs,
scores: final_scores,
});
}
});
let state = Arc::new(tx);
let app = Router::new()
.route(
"/",
get(|| async { Html(include_str!("../static/index.html")) }),
)
.route("/api/predict", post(predict_handler))
.with_state(state)
.layer(DefaultBodyLimit::max(10 * 1024 * 1024))
.layer(CorsLayer::permissive());
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
axum::serve(listener, app).await.unwrap();
}
/// Handles `POST /api/predict`: forwards the uploaded `image` multipart field to the
/// inference worker and returns its result as JSON.
///
/// Responses:
/// - `"status": "success"` with `confidence`, `keypoints`, `confidences`, `scores`,
///   `is_calibrated` and a human-readable `message` when inference completed.
/// - `"status": "error"` with a `message` when no `image` field was found, the worker
///   is not reachable, or the worker dropped the request without answering.
///
/// `is_calibrated` is only `true` when four calibration confidences are present and
/// every one of them exceeds `0.05`.
async fn predict_handler(
    State(tx): State<Arc<mpsc::Sender<PredictRequest>>>,
    mut multipart: Multipart,
) -> Json<serde_json::Value> {
    while let Ok(Some(field)) = multipart.next_field().await {
        if field.name() != Some("image") {
            continue;
        }
        let bytes = match field.bytes().await {
            Ok(b) => b.to_vec(),
            Err(err) => {
                eprintln!("❌ [DartVision] Failed to read uploaded image field: {}", err);
                continue;
            }
        };

        let (res_tx, res_rx) = oneshot::channel();
        let request = PredictRequest {
            image_bytes: bytes,
            response_tx: res_tx,
        };
        // A send error means the worker's receiver is gone; do not report success.
        if tx.send(request).await.is_err() {
            return Json(json!({
                "status": "error",
                "message": "Inference worker is not available"
            }));
        }
        // The worker drops the response channel when it could not produce a result.
        let result = match res_rx.await {
            Ok(result) => result,
            Err(_) => {
                return Json(json!({
                    "status": "error",
                    "message": "Inference failed: the image could not be processed"
                }));
            }
        };

        // `all` is vacuously true on short input, so require all four corners first.
        let is_calibrated = result.confidences.len() >= 4
            && result.confidences.iter().take(4).all(|&c| c > 0.05);
        let message = if result.confidence > 0.1 {
            format!(
                "✅ Found {} darts! High confidence: {:.1}%",
                result.scores.len(),
                result.confidence * 100.0
            )
        } else {
            "⚠️ Low confidence detection - no dart score could be verified.".to_string()
        };
        return Json(json!({
            "status": "success",
            "confidence": result.confidence,
            "keypoints": result.keypoints,
            "confidences": result.confidences,
            "scores": result.scores,
            "is_calibrated": is_calibrated,
            "message": message
        }));
    }
    Json(json!({"status": "error", "message": "No image field found"}))
}