Spaces:
Build error
Build error
File size: 2,397 Bytes
a6578e7 9874885 90dd6a4 9874885 a6578e7 9874885 8e03aff a6578e7 9874885 8e03aff a6578e7 9874885 8e03aff 9874885 a6578e7 8e03aff a6578e7 9874885 a6578e7 9874885 a6578e7 9874885 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 | use crate::model::DartVisionModel;
use burn::module::Module;
use burn::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
use burn::tensor::backend::Backend;
use burn::tensor::{Tensor, TensorData};
use image;
pub fn run_inference<B: Backend>(device: &B::Device, image_path: &str) {
println!("๐ Loading model for inference...");
let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
let model: DartVisionModel<B> = DartVisionModel::new(device);
// Load weights
let record = Recorder::load(&recorder, "model_weights".into(), device)
.expect("Failed to load weights. Make sure model_weights.bin exists.");
let model = model.load_record(record);
println!("๐ผ๏ธ Processing image: {}...", image_path);
let img = image::open(image_path).expect("Failed to open image");
let resized = img.resize_exact(800, 800, image::imageops::FilterType::Triangle);
let pixels: Vec<f32> = resized
.to_rgb8()
.pixels()
.flat_map(|p| {
vec![
p[0] as f32 / 255.0,
p[1] as f32 / 255.0,
p[2] as f32 / 255.0,
]
})
.collect();
let data = TensorData::new(pixels, [800, 800, 3]);
let input = Tensor::<B, 3>::from_data(data, device)
.unsqueeze::<4>()
.permute([0, 3, 1, 2]);
println!("๐ Running MODEL Prediction...");
let (out16, _out32) = model.forward(input);
// out16 shape: [1, 30, 50, 50] โ 800/16 = 50
// Extract Objectness (Channel 4 of first anchor)
let obj = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 4, 1));
// Find highest confidence cell in 50x50 grid
let (max_val, _) = obj.reshape([1_usize, 2500]).max_dim_with_indices(1);
let confidence: f32 = max_val
.to_data()
.convert::<f32>()
.as_slice::<f32>()
.unwrap()[0];
println!("--------------------------------------------------");
println!("๐ RESULTS FOR: {}", image_path);
println!("โจ Max Objectness: {:.2}%", confidence * 100.0);
if confidence > 0.05 {
println!(
"โ
Model found something! Confidence Score: {:.4}",
confidence
);
} else {
println!("โ ๏ธ Model confidence is too low. Training incomplete?");
}
println!("--------------------------------------------------");
}
|