Spaces:
Build error
Build error
kapil committed on
Commit ·
850d827
1
Parent(s): 90dd6a4
feat: implement DIoU loss function for object detection and update dataset documentation
Browse files- README.md +2 -1
- src/inference.rs +6 -6
- src/loss.rs +27 -15
README.md
CHANGED
|
@@ -29,7 +29,8 @@ Using the **Burn Deep Learning Framework**, this project achieves sub-millisecon
|
|
| 29 |
## Dataset and Preparation
|
| 30 |
The model is trained on the primary dataset used for high-precision dart detection.
|
| 31 |
|
| 32 |
-
- **
|
|
|
|
| 33 |
- **Resolution**: 800x800 pre-cropped high-resolution images.
|
| 34 |
- **Structure**: Organize your data in the `dataset/800/` directory following the provided `labels.json` schema.
|
| 35 |
|
|
|
|
| 29 |
## Dataset and Preparation
|
| 30 |
The model is trained on the primary dataset used for high-precision dart detection.
|
| 31 |
|
| 32 |
+
- **Model Weights Link**: [Neural Weights & TFLite (Google Drive)](https://drive.google.com/file/d/1ZEvuzg9zYbPd1FdZgV6v1aT4sqbqmLqp/view?usp=sharing)
|
| 33 |
+
- **Dataset Source**: [DeepDarts (IEEE Dataport)](https://ieee-dataport.org/open-access/deepdarts-dataset)
|
| 34 |
- **Resolution**: 800x800 pre-cropped high-resolution images.
|
| 35 |
- **Structure**: Organize your data in the `dataset/800/` directory following the provided `labels.json` schema.
|
| 36 |
|
src/inference.rs
CHANGED
|
@@ -17,7 +17,7 @@ pub fn run_inference<B: Backend>(device: &B::Device, image_path: &str) {
|
|
| 17 |
|
| 18 |
println!("🖼️ Processing image: {}...", image_path);
|
| 19 |
let img = image::open(image_path).expect("Failed to open image");
|
| 20 |
-
let resized = img.resize_exact(
|
| 21 |
let pixels: Vec<f32> = resized
|
| 22 |
.to_rgb8()
|
| 23 |
.pixels()
|
|
@@ -30,7 +30,7 @@ pub fn run_inference<B: Backend>(device: &B::Device, image_path: &str) {
|
|
| 30 |
})
|
| 31 |
.collect();
|
| 32 |
|
| 33 |
-
let data = TensorData::new(pixels, [
|
| 34 |
let input = Tensor::<B, 3>::from_data(data, device)
|
| 35 |
.unsqueeze::<4>()
|
| 36 |
.permute([0, 3, 1, 2]);
|
|
@@ -38,12 +38,12 @@ pub fn run_inference<B: Backend>(device: &B::Device, image_path: &str) {
|
|
| 38 |
println!("🚀 Running MODEL Prediction...");
|
| 39 |
let (out16, _out32) = model.forward(input);
|
| 40 |
|
| 41 |
-
//
|
| 42 |
-
//
|
| 43 |
let obj = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 4, 1));
|
| 44 |
|
| 45 |
-
// Find highest confidence cell
|
| 46 |
-
let (max_val, _) = obj.reshape([1,
|
| 47 |
let confidence: f32 = max_val
|
| 48 |
.to_data()
|
| 49 |
.convert::<f32>()
|
|
|
|
| 17 |
|
| 18 |
println!("🖼️ Processing image: {}...", image_path);
|
| 19 |
let img = image::open(image_path).expect("Failed to open image");
|
| 20 |
+
let resized = img.resize_exact(416, 416, image::imageops::FilterType::Triangle);
|
| 21 |
let pixels: Vec<f32> = resized
|
| 22 |
.to_rgb8()
|
| 23 |
.pixels()
|
|
|
|
| 30 |
})
|
| 31 |
.collect();
|
| 32 |
|
| 33 |
+
let data = TensorData::new(pixels, [416, 416, 3]);
|
| 34 |
let input = Tensor::<B, 3>::from_data(data, device)
|
| 35 |
.unsqueeze::<4>()
|
| 36 |
.permute([0, 3, 1, 2]);
|
|
|
|
| 38 |
println!("🚀 Running MODEL Prediction...");
|
| 39 |
let (out16, _out32) = model.forward(input);
|
| 40 |
|
| 41 |
+
// out16 shape: [1, 30, 26, 26]
|
| 42 |
+
// 1. Extract Objectness (Channel 4 of first anchor)
|
| 43 |
let obj = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 4, 1));
|
| 44 |
|
| 45 |
+
// 2. Find highest confidence cell in 26x26 grid
|
| 46 |
+
let (max_val, _) = obj.reshape([1, 676]).max_dim_with_indices(1);
|
| 47 |
let confidence: f32 = max_val
|
| 48 |
.to_data()
|
| 49 |
.convert::<f32>()
|
src/loss.rs
CHANGED
|
@@ -10,36 +10,48 @@ pub fn diou_loss<B: Backend>(
|
|
| 10 |
let bp = bboxes_pred.reshape([batch, 3, 10, h, w]);
|
| 11 |
let t = target.reshape([batch, 3, 10, h, w]);
|
| 12 |
|
| 13 |
-
// 2.
|
|
|
|
|
|
|
|
|
|
| 14 |
let obj_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 4, 1));
|
| 15 |
let obj_target = t.clone().narrow(2, 4, 1);
|
| 16 |
|
| 17 |
-
let eps = 1e-7;
|
| 18 |
-
// Positive loss (where an object exists)
|
| 19 |
let pos_loss = obj_target.clone().mul(obj_pred.clone().add_scalar(eps).log()).neg();
|
| 20 |
-
|
| 21 |
-
|
| 22 |
|
| 23 |
-
// Weight positive samples
|
| 24 |
-
let obj_loss = pos_loss.mul_scalar(
|
| 25 |
|
| 26 |
-
//
|
|
|
|
| 27 |
let cls_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 5, 5));
|
| 28 |
let cls_target = t.clone().narrow(2, 5, 5);
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
.mean()
|
| 32 |
-
.mul_scalar(
|
| 33 |
|
| 34 |
-
//
|
| 35 |
-
// 2. Coordinate Loss (MSE on relative offsets) - Weighted x10 for precision
|
| 36 |
let b_xy_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 0, 2));
|
| 37 |
let b_xy_target = t.clone().narrow(2, 0, 2);
|
| 38 |
-
let xy_loss = b_xy_pred.sub(b_xy_target).powf_scalar(2.0)
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
let b_wh_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 2, 2));
|
| 41 |
let b_wh_target = t.clone().narrow(2, 2, 2);
|
| 42 |
-
let wh_loss = b_wh_pred.sub(b_wh_target).powf_scalar(2.0)
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
obj_loss.add(class_loss).add(xy_loss).add(wh_loss)
|
| 45 |
}
|
|
|
|
| 10 |
let bp = bboxes_pred.reshape([batch, 3, 10, h, w]);
|
| 11 |
let t = target.reshape([batch, 3, 10, h, w]);
|
| 12 |
|
| 13 |
+
// 2. Loss Constants
|
| 14 |
+
let eps = 1e-6;
|
| 15 |
+
|
| 16 |
+
// 3. Objectness Loss (BCE)
|
| 17 |
let obj_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 4, 1));
|
| 18 |
let obj_target = t.clone().narrow(2, 4, 1);
|
| 19 |
|
|
|
|
|
|
|
| 20 |
let pos_loss = obj_target.clone().mul(obj_pred.clone().add_scalar(eps).log()).neg();
|
| 21 |
+
let neg_loss = obj_target.clone().neg().add_scalar(1.0)
|
| 22 |
+
.mul(obj_pred.clone().neg().add_scalar(1.0 + eps).log()).neg();
|
| 23 |
|
| 24 |
+
// Weight positive samples heavily (sparsity)
|
| 25 |
+
let obj_loss = pos_loss.mul_scalar(40.0).add(neg_loss).mean();
|
| 26 |
|
| 27 |
+
// 4. Class Loss (Full BCE for all 5 channels)
|
| 28 |
+
// bp channels 5-9: Dart, Cal1, Cal2, Cal3, Cal4
|
| 29 |
let cls_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 5, 5));
|
| 30 |
let cls_target = t.clone().narrow(2, 5, 5);
|
| 31 |
+
|
| 32 |
+
let cls_pos_loss = cls_target.clone().mul(cls_pred.clone().add_scalar(eps).log()).neg();
|
| 33 |
+
let cls_neg_loss = cls_target.clone().neg().add_scalar(1.0)
|
| 34 |
+
.mul(cls_pred.clone().neg().add_scalar(1.0 + eps).log()).neg();
|
| 35 |
+
|
| 36 |
+
let class_loss = cls_pos_loss.add(cls_neg_loss)
|
| 37 |
+
.mul(obj_target.clone()) // Mask class loss where there is no object
|
| 38 |
.mean()
|
| 39 |
+
.mul_scalar(15.0);
|
| 40 |
|
| 41 |
+
// 5. Box (XYWH) Loss (MSE)
|
|
|
|
| 42 |
let b_xy_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 0, 2));
|
| 43 |
let b_xy_target = t.clone().narrow(2, 0, 2);
|
| 44 |
+
let xy_loss = b_xy_pred.sub(b_xy_target).powf_scalar(2.0)
|
| 45 |
+
.mul(obj_target.clone())
|
| 46 |
+
.mean()
|
| 47 |
+
.mul_scalar(30.0); // Increase weight for coordinate precision
|
| 48 |
|
| 49 |
let b_wh_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 2, 2));
|
| 50 |
let b_wh_target = t.clone().narrow(2, 2, 2);
|
| 51 |
+
let wh_loss = b_wh_pred.sub(b_wh_target).powf_scalar(2.0)
|
| 52 |
+
.mul(obj_target)
|
| 53 |
+
.mean()
|
| 54 |
+
.mul_scalar(5.0);
|
| 55 |
|
| 56 |
obj_loss.add(class_loss).add(xy_loss).add(wh_loss)
|
| 57 |
}
|