File size: 2,265 Bytes
9874885
 
 
 
 
 
 
 
 
 
 
 
850d827
 
 
 
9874885
 
 
 
850d827
 
9874885
850d827
 
9874885
850d827
 
9874885
 
850d827
 
 
 
 
 
 
9874885
850d827
9874885
850d827
9874885
 
850d827
 
 
 
9874885
 
 
850d827
 
 
 
9874885
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

/// Composite detection loss: objectness BCE + class BCE + masked box-coordinate MSE.
///
/// Despite its name, no Distance-IoU term is computed: box coordinates are regressed with a
/// plain squared error, masked by the objectness target. The name is kept for caller
/// compatibility.
///
/// # Layout
///
/// Both tensors are `[batch, 30, h, w]` and are viewed as `[batch, 3, 10, h, w]` — three
/// anchors, ten values per anchor — with the value axis laid out as:
///
/// | index  | meaning                                      |
/// |--------|----------------------------------------------|
/// | 0..2   | box x, y                                     |
/// | 2..4   | box w, h                                     |
/// | 4      | objectness                                   |
/// | 5..10  | class scores: Dart, Cal1, Cal2, Cal3, Cal4   |
///
/// `bboxes_pred` holds raw logits (box terms go through a sigmoid, objectness/class terms
/// through a logit-space BCE). `target` holds values in `[0, 1]`; its objectness channel is
/// also used as the mask for the class and box terms (presumably 0/1 — TODO confirm against
/// the target builder).
///
/// # Returns
///
/// A single-element tensor with the weighted sum of the four terms. Each term is a mean over
/// *all* anchors and cells, so the masked terms are diluted by the fraction of empty cells;
/// the fixed weights below are what compensate for that sparsity.
///
/// # Panics
///
/// If `bboxes_pred` does not have `3 * 10 = 30` channels, or if `target` does not have
/// exactly the same shape as `bboxes_pred`.
pub fn diou_loss<B: Backend>(
    bboxes_pred: Tensor<B, 4>,
    target: Tensor<B, 4>,
) -> Tensor<B, 1> {
    const NUM_ANCHORS: usize = 3;
    const VALUES_PER_ANCHOR: usize = 10;
    // Positive objectness cells are rare, so they are up-weighted heavily.
    const OBJ_POS_WEIGHT: f64 = 40.0;
    const CLASS_WEIGHT: f64 = 15.0;
    // Coordinate precision matters more than size precision.
    const XY_WEIGHT: f64 = 30.0;
    const WH_WEIGHT: f64 = 5.0;

    // 1. Validate shapes, then split the channel axis into [anchor, value].
    let [batch, channels, h, w] = bboxes_pred.dims();
    assert_eq!(
        channels,
        NUM_ANCHORS * VALUES_PER_ANCHOR,
        "diou_loss: expected {} channels ({} anchors x {} values), got {}",
        NUM_ANCHORS * VALUES_PER_ANCHOR,
        NUM_ANCHORS,
        VALUES_PER_ANCHOR,
        channels
    );
    // Without this check a target with a different layout but the same element count
    // would be silently reshaped into misaligned data.
    assert_eq!(
        target.dims(),
        [batch, channels, h, w],
        "diou_loss: target shape must match bboxes_pred shape"
    );
    let bp = bboxes_pred.reshape([batch, NUM_ANCHORS, VALUES_PER_ANCHOR, h, w]);
    let t = target.reshape([batch, NUM_ANCHORS, VALUES_PER_ANCHOR, h, w]);

    // 2. Objectness loss (BCE in logit space).
    // Computing sigmoid first and then log(p + eps) saturates: once sigmoid rounds to 0/1
    // the loss is stuck at -ln(eps) and its gradient vanishes. Using the identities
    //   -ln(sigmoid(x))     = softplus(x) - x
    //   -ln(1 - sigmoid(x)) = softplus(x)
    // keeps both the value and the gradient well-behaved for any logit.
    let obj_logit = bp.clone().narrow(2, 4, 1);
    let obj_target = t.clone().narrow(2, 4, 1);
    let obj_softplus = stable_softplus(obj_logit.clone());
    let pos_loss = obj_target.clone().mul(obj_softplus.clone().sub(obj_logit));
    let neg_loss = obj_target.clone().neg().add_scalar(1.0).mul(obj_softplus);
    let obj_loss = pos_loss.mul_scalar(OBJ_POS_WEIGHT).add(neg_loss).mean();

    // 3. Class loss: independent BCE per class channel, masked to cells with an object.
    // For the combined BCE, t*(sp - x) + (1 - t)*sp simplifies to sp - x*t.
    let cls_logit = bp.clone().narrow(2, 5, 5);
    let cls_target = t.clone().narrow(2, 5, 5);
    let cls_bce = stable_softplus(cls_logit.clone()).sub(cls_logit.mul(cls_target));
    // obj_target is [.., 1, ..] on the value axis and broadcasts over the 5 class channels.
    let class_loss = cls_bce
        .mul(obj_target.clone())
        .mean()
        .mul_scalar(CLASS_WEIGHT);

    // 4. Box loss: masked squared error on sigmoid-squashed coordinates.
    let xy_pred = burn::tensor::activation::sigmoid(bp.clone().narrow(2, 0, 2));
    let xy_target = t.clone().narrow(2, 0, 2);
    let xy_loss = xy_pred
        .sub(xy_target)
        .powf_scalar(2.0)
        .mul(obj_target.clone())
        .mean()
        .mul_scalar(XY_WEIGHT);

    let wh_pred = burn::tensor::activation::sigmoid(bp.narrow(2, 2, 2));
    let wh_target = t.narrow(2, 2, 2);
    let wh_loss = wh_pred
        .sub(wh_target)
        .powf_scalar(2.0)
        .mul(obj_target)
        .mean()
        .mul_scalar(WH_WEIGHT);

    obj_loss.add(class_loss).add(xy_loss).add(wh_loss)
}

/// Numerically stable softplus `ln(1 + e^x)`, evaluated as `max(x, 0) + ln(1 + e^{-|x|})`
/// so the exponential never overflows for large positive `x`.
fn stable_softplus<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
    x.clone().clamp_min(0.0).add(x.abs().neg().exp().log1p())
}