kapil committed on
Commit
850d827
·
1 Parent(s): 90dd6a4

feat: implement DIoU loss function for object detection and update dataset documentation

Browse files
Files changed (3) hide show
  1. README.md +2 -1
  2. src/inference.rs +6 -6
  3. src/loss.rs +27 -15
README.md CHANGED
@@ -29,7 +29,8 @@ Using the **Burn Deep Learning Framework**, this project achieves sub-millisecon
29
  ## Dataset and Preparation
30
  The model is trained on the primary dataset used for high-precision dart detection.
31
 
32
- - **Download Link**: [Dataset Resources (Google Drive)](https://drive.google.com/file/d/1ZEvuzg9zYbPd1FdZgV6v1aT4sqbqmLqp/view?usp=sharing)
 
33
  - **Resolution**: 800x800 pre-cropped high-resolution images.
34
  - **Structure**: Organize your data in the `dataset/800/` directory following the provided `labels.json` schema.
35
 
 
29
  ## Dataset and Preparation
30
  The model is trained on the primary dataset used for high-precision dart detection.
31
 
32
+ - **Model Weights Link**: [Neural Weights & TFLite (Google Drive)](https://drive.google.com/file/d/1ZEvuzg9zYbPd1FdZgV6v1aT4sqbqmLqp/view?usp=sharing)
33
+ - **Dataset Source**: [DeepDarts (IEEE Dataport)](https://ieee-dataport.org/open-access/deepdarts-dataset)
34
  - **Resolution**: 800x800 pre-cropped high-resolution images.
35
  - **Structure**: Organize your data in the `dataset/800/` directory following the provided `labels.json` schema.
36
 
src/inference.rs CHANGED
@@ -17,7 +17,7 @@ pub fn run_inference<B: Backend>(device: &B::Device, image_path: &str) {
17
 
18
  println!("🖼️ Processing image: {}...", image_path);
19
  let img = image::open(image_path).expect("Failed to open image");
20
- let resized = img.resize_exact(800, 800, image::imageops::FilterType::Triangle);
21
  let pixels: Vec<f32> = resized
22
  .to_rgb8()
23
  .pixels()
@@ -30,7 +30,7 @@ pub fn run_inference<B: Backend>(device: &B::Device, image_path: &str) {
30
  })
31
  .collect();
32
 
33
- let data = TensorData::new(pixels, [800, 800, 3]);
34
  let input = Tensor::<B, 3>::from_data(data, device)
35
  .unsqueeze::<4>()
36
  .permute([0, 3, 1, 2]);
@@ -38,12 +38,12 @@ pub fn run_inference<B: Backend>(device: &B::Device, image_path: &str) {
38
  println!("🚀 Running MODEL Prediction...");
39
  let (out16, _out32) = model.forward(input);
40
 
41
- // Post-process out16 (size [1, 30, 100, 100])
42
- // Decode objectness part (Channel 4 for Anchor 0)
43
  let obj = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 4, 1));
44
 
45
- // Find highest confidence cell
46
- let (max_val, _) = obj.reshape([1, 10000]).max_dim_with_indices(1);
47
  let confidence: f32 = max_val
48
  .to_data()
49
  .convert::<f32>()
 
17
 
18
  println!("🖼️ Processing image: {}...", image_path);
19
  let img = image::open(image_path).expect("Failed to open image");
20
+ let resized = img.resize_exact(416, 416, image::imageops::FilterType::Triangle);
21
  let pixels: Vec<f32> = resized
22
  .to_rgb8()
23
  .pixels()
 
30
  })
31
  .collect();
32
 
33
+ let data = TensorData::new(pixels, [416, 416, 3]);
34
  let input = Tensor::<B, 3>::from_data(data, device)
35
  .unsqueeze::<4>()
36
  .permute([0, 3, 1, 2]);
 
38
  println!("🚀 Running MODEL Prediction...");
39
  let (out16, _out32) = model.forward(input);
40
 
41
+ // out16 shape: [1, 30, 26, 26]
42
+ // 1. Extract Objectness (Channel 4 of first anchor)
43
  let obj = burn::tensor::activation::sigmoid(out16.clone().narrow(1, 4, 1));
44
 
45
+ // 2. Find highest confidence cell in 26x26 grid
46
+ let (max_val, _) = obj.reshape([1, 676]).max_dim_with_indices(1);
47
  let confidence: f32 = max_val
48
  .to_data()
49
  .convert::<f32>()
src/loss.rs CHANGED
@@ -10,36 +10,48 @@ pub fn diou_loss<B: Backend>(
10
  let bp = bboxes_pred.reshape([batch, 3, 10, h, w]);
11
  let t = target.reshape([batch, 3, 10, h, w]);
12
 
13
- // 2. Objectness (Channel 4)
 
 
 
14
  let obj_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 4, 1));
15
  let obj_target = t.clone().narrow(2, 4, 1);
16
 
17
- let eps = 1e-7;
18
- // Positive loss (where an object exists)
19
  let pos_loss = obj_target.clone().mul(obj_pred.clone().add_scalar(eps).log()).neg();
20
- // Negative loss (where no object exists)
21
- let neg_loss = obj_target.clone().neg().add_scalar(1.0).mul(obj_pred.clone().neg().add_scalar(1.0 + eps).log()).neg();
22
 
23
- // Weight positive samples 10x more to fight imbalance (typical YOLO trick)
24
- let obj_loss = pos_loss.mul_scalar(20.0).add(neg_loss).mean();
25
 
26
- // 3. Class (Channels 5-9) - Only learn when object exists
 
27
  let cls_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 5, 5));
28
  let cls_target = t.clone().narrow(2, 5, 5);
29
- let class_loss = cls_target.clone().mul(cls_pred.clone().add_scalar(eps).log()).neg()
30
- .mul(obj_target.clone()) // Only count where object exists
 
 
 
 
 
31
  .mean()
32
- .mul_scalar(5.0); // Boost class learning
33
 
34
- // 4. Coordinates (Channels 0-3) - Only learn when object exists
35
- // 2. Coordinate Loss (MSE on relative offsets) - Weighted x10 for precision
36
  let b_xy_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 0, 2));
37
  let b_xy_target = t.clone().narrow(2, 0, 2);
38
- let xy_loss = b_xy_pred.sub(b_xy_target).powf_scalar(2.0).mul(obj_target.clone()).mean().mul_scalar(10.0);
 
 
 
39
 
40
  let b_wh_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 2, 2));
41
  let b_wh_target = t.clone().narrow(2, 2, 2);
42
- let wh_loss = b_wh_pred.sub(b_wh_target).powf_scalar(2.0).mul(obj_target).mean().mul_scalar(5.0);
 
 
 
43
 
44
  obj_loss.add(class_loss).add(xy_loss).add(wh_loss)
45
  }
 
10
  let bp = bboxes_pred.reshape([batch, 3, 10, h, w]);
11
  let t = target.reshape([batch, 3, 10, h, w]);
12
 
13
+ // 2. Loss Constants
14
+ let eps = 1e-6;
15
+
16
+ // 3. Objectness Loss (BCE)
17
  let obj_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 4, 1));
18
  let obj_target = t.clone().narrow(2, 4, 1);
19
 
 
 
20
  let pos_loss = obj_target.clone().mul(obj_pred.clone().add_scalar(eps).log()).neg();
21
+ let neg_loss = obj_target.clone().neg().add_scalar(1.0)
22
+ .mul(obj_pred.clone().neg().add_scalar(1.0 + eps).log()).neg();
23
 
24
+ // Weight positive samples heavily (sparsity)
25
+ let obj_loss = pos_loss.mul_scalar(40.0).add(neg_loss).mean();
26
 
27
+ // 4. Class Loss (Full BCE for all 5 channels)
28
+ // bp channels 5-9: Dart, Cal1, Cal2, Cal3, Cal4
29
  let cls_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 5, 5));
30
  let cls_target = t.clone().narrow(2, 5, 5);
31
+
32
+ let cls_pos_loss = cls_target.clone().mul(cls_pred.clone().add_scalar(eps).log()).neg();
33
+ let cls_neg_loss = cls_target.clone().neg().add_scalar(1.0)
34
+ .mul(cls_pred.clone().neg().add_scalar(1.0 + eps).log()).neg();
35
+
36
+ let class_loss = cls_pos_loss.add(cls_neg_loss)
37
+ .mul(obj_target.clone()) // Mask class loss where there is no object
38
  .mean()
39
+ .mul_scalar(15.0);
40
 
41
+ // 5. Box (XYWH) Loss (MSE)
 
42
  let b_xy_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 0, 2));
43
  let b_xy_target = t.clone().narrow(2, 0, 2);
44
+ let xy_loss = b_xy_pred.sub(b_xy_target).powf_scalar(2.0)
45
+ .mul(obj_target.clone())
46
+ .mean()
47
+ .mul_scalar(30.0); // Increase weight for coordinate precision
48
 
49
  let b_wh_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 2, 2));
50
  let b_wh_target = t.clone().narrow(2, 2, 2);
51
+ let wh_loss = b_wh_pred.sub(b_wh_target).powf_scalar(2.0)
52
+ .mul(obj_target)
53
+ .mean()
54
+ .mul_scalar(5.0);
55
 
56
  obj_loss.add(class_loss).add(xy_loss).add(wh_loss)
57
  }