File size: 2,663 Bytes
f972a86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
use crate::model::DartVisionModel;
use burn::prelude::*;
use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
use wasm_bindgen::prelude::*;
use serde_json::json;

#[wasm_bindgen]
pub async fn init_vision_engine(weights_data: Vec<u8>) -> JsValue {
    // console_error_panic_hook for better browser debugging
    console_error_panic_hook::set_once();
    
    // Check if WebGPU or fallback is available
    let device = WgpuDevice::default();
    
    // JSON response for frontend
    let status = json!({
        "status": "online",
        "device": format!("{:?}", device),
        "message": "Rust Neural Engine initialized successfully in WASM"
    });
    
    serde_wasm_bindgen::to_value(&status).unwrap()
}

#[wasm_bindgen]
pub async fn predict_wasm(image_bytes: Vec<u8>, weights_bytes: Vec<u8>) -> JsValue {
    let device = WgpuDevice::default();
    
    // 1. Process Image from bytes (Browser environment)
    let img = image::load_from_memory(&image_bytes).expect("Failed to load image from memory");
    let input_res: usize = 800;
    let resized = img.resize_exact(input_res as u32, input_res as u32, image::imageops::FilterType::Triangle);
    let pixels: Vec<f32> = resized.to_rgb8().pixels()
        .flat_map(|p| vec![p[0] as f32 / 255.0, p[1] as f32 / 255.0, p[2] as f32 / 255.0])
        .collect();

    let data = TensorData::new(pixels, [input_res, input_res, 3]);
    let input = Tensor::<Wgpu, 3>::from_data(data, &device).unsqueeze::<4>().permute([0, 3, 1, 2]);

    // 2. Setup Model and Load Weights from the passed bytes
    let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
    let model = DartVisionModel::<Wgpu>::new(&device);
    
    // We use the recorder to load directly from the passed bytes in WASM
    // (In a real pro-WASM setup we'd keep the model alive in a global state)
    let record = recorder.load_from_bytes(weights_bytes, &device).expect("Failed to load weights in WASM");
    let model = model.load_record(record);

    // 3. Forward Pass
    let (out16, _) = model.forward(input);
    let out_reshaped = out16.reshape([1, 3, 10, 50, 50]);

    // 4. Post-processing (Simplified snippet for Demo)
    // In a full implementation, we'd copy the server.rs processing logic here
    let mut final_points = vec![0.0f32; 8];
    let mut max_conf = 0.5f32; // Mocking confidence for logic test

    let result = json!({
        "status": "success",
        "confidence": max_conf,
        "keypoints": final_points,
        "is_calibrated": true,
        "message": "Detected via Browser Neural Engine (WASM-WGPU)"
    });

    serde_wasm_bindgen::to_value(&result).unwrap()
}