File size: 2,397 Bytes
a6578e7
9874885
 
 
 
90dd6a4
9874885
 
 
 
 
a6578e7
9874885
 
 
 
 
 
 
8e03aff
a6578e7
 
 
 
 
 
 
 
 
 
9874885
 
8e03aff
a6578e7
 
 
9874885
 
 
 
8e03aff
 
9874885
a6578e7
8e03aff
 
a6578e7
 
 
 
 
 
9874885
 
 
a6578e7
9874885
a6578e7
 
 
 
9874885
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
use crate::model::DartVisionModel;
use burn::module::Module;
use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
use burn::tensor::backend::Backend;
use burn::tensor::{Tensor, TensorData};
use image;

pub fn run_inference<B: Backend>(device: &B::Device, image_path: &str) {
    println!("๐Ÿ” Loading model for inference...");
    let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
    let model: DartVisionModel<B> = DartVisionModel::new(device);

    // Load weights
    let record = Recorder::load(&recorder, "model_weights".into(), device)
        .expect("Failed to load weights. Make sure model_weights.bin exists.");
    let model = model.load_record(record);

    println!("๐Ÿ–ผ๏ธ Processing image: {}...", image_path);
    let img = image::open(image_path).expect("Failed to open image");
    let resized = img.resize_exact(800, 800, image::imageops::FilterType::Triangle);
    let pixels: Vec<f32> = resized
        .to_rgb8()
        .pixels()
        .flat_map(|p| {
            vec![
                p[0] as f32 / 255.0,
                p[1] as f32 / 255.0,
                p[2] as f32 / 255.0,
            ]
        })
        .collect();

    let data = TensorData::new(pixels, [800, 800, 3]);
    let input = Tensor::<B, 3>::from_data(data, device)
        .unsqueeze::<4>()
        .permute([0, 3, 1, 2]);

    println!("๐Ÿš€ Running MODEL Prediction...");
    let (out16, _out32) = model.forward(input);

    // out16 shape: [1, 30, 50, 50] โ€” 800/16 = 50
    // Extract Objectness (Channel 4 of first anchor)
    let obj = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 4, 1));

    // Find highest confidence cell in 50x50 grid
    let (max_val, _) = obj.reshape([1_usize, 2500]).max_dim_with_indices(1);
    let confidence: f32 = max_val
        .to_data()
        .convert::<f32>()
        .as_slice::<f32>()
        .unwrap()[0];

    println!("--------------------------------------------------");
    println!("๐Ÿ“Š RESULTS FOR: {}", image_path);
    println!("โœจ Max Objectness: {:.2}%", confidence * 100.0);

    if confidence > 0.05 {
        println!(
            "โœ… Model found something! Confidence Score: {:.4}",
            confidence
        );
    } else {
        println!("โš ๏ธ Model confidence is too low. Training incomplete?");
    }
    println!("--------------------------------------------------");
}