kapil
feat: implement DIoU loss function for object detection and update dataset documentation
850d827
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
/// Combined objectness / class / box loss for a grid of anchor predictions.
///
/// Both inputs are reshaped from `[batch, channels, h, w]` to
/// `[batch, 3, 10, h, w]`, so the channel dimension is expected to hold
/// `3 * 10 = 30` values. Along the size-10 axis the fields are read as:
///
/// * `0..2` – box centre (xy)
/// * `2..4` – box size (wh)
/// * `4`    – objectness
/// * `5..10` – five class channels (Dart, Cal1, Cal2, Cal3, Cal4)
///
/// Every predicted field is passed through a sigmoid before being compared
/// with `target`. The returned one-element tensor is the sum of:
///
/// * objectness binary cross-entropy, positive term weighted by 40;
/// * class binary cross-entropy masked by the target objectness, weighted by 15;
/// * squared xy error masked by the target objectness, weighted by 30;
/// * squared wh error masked by the target objectness, weighted by 5.
///
/// Each term is averaged over all of its elements (masked-out cells count as
/// zeros in the mean). Note that, despite the function name, no IoU-based term
/// is computed in this function; box regression is a plain masked squared error.
pub fn diou_loss<B: Backend>(
    bboxes_pred: Tensor<B, 4>,
    target: Tensor<B, 4>,
) -> Tensor<B, 1> {
    // Layout of the reshaped tensors.
    const ANCHORS: usize = 3;
    const FIELDS: usize = 10;
    const FIELD_AXIS: usize = 2;

    // Guards the logarithms against log(0).
    const EPS: f64 = 1e-6;

    // Relative weights of the individual loss terms.
    const OBJ_POSITIVE_WEIGHT: f64 = 40.0;
    const CLASS_WEIGHT: f64 = 15.0;
    const XY_WEIGHT: f64 = 30.0;
    const WH_WEIGHT: f64 = 5.0;

    let [batch, _channels, h, w] = bboxes_pred.dims();
    let grid_shape = [batch, ANCHORS, FIELDS, h, w];
    let pred = bboxes_pred.reshape(grid_shape);
    let truth = target.reshape(grid_shape);

    // Cut `len` consecutive fields starting at `start` out of the field axis.
    let fields = |src: &Tensor<B, 5>, start: usize, len: usize| -> Tensor<B, 5> {
        src.clone().narrow(FIELD_AXIS, start, len)
    };

    // Element-wise binary cross-entropy, returned as its two halves
    // (`-y * log(p + eps)`, `-(1 - y) * log(1 - p + eps)`) so that callers
    // can weight them independently.
    let bce_terms = |prob: Tensor<B, 5>, label: Tensor<B, 5>| -> (Tensor<B, 5>, Tensor<B, 5>) {
        let log_prob = prob.clone().add_scalar(EPS).log();
        let positive = label.clone().mul(log_prob).neg();

        let log_complement = prob.neg().add_scalar(1.0 + EPS).log();
        let negative = label.neg().add_scalar(1.0).mul(log_complement).neg();

        (positive, negative)
    };

    // Mean of the squared difference, zeroed wherever `mask` is zero.
    let masked_squared_error =
        |guess: Tensor<B, 5>, wanted: Tensor<B, 5>, mask: Tensor<B, 5>| -> Tensor<B, 1> {
            guess.sub(wanted).powf_scalar(2.0).mul(mask).mean()
        };

    // Target objectness doubles as the mask for every per-object term below.
    let has_object = fields(&truth, 4, 1);

    // Objectness: positives are rare, so their term is weighted up heavily.
    let objectness_prob = burn::tensor::activation::sigmoid(fields(&pred, 4, 1));
    let (obj_positive, obj_negative) = bce_terms(objectness_prob, has_object.clone());
    let objectness_term = obj_positive
        .mul_scalar(OBJ_POSITIVE_WEIGHT)
        .add(obj_negative)
        .mean();

    // Classes: full BCE over all five channels, counted only where an object exists.
    let class_prob = burn::tensor::activation::sigmoid(fields(&pred, 5, 5));
    let (cls_positive, cls_negative) = bce_terms(class_prob, fields(&truth, 5, 5));
    let class_term = cls_positive
        .add(cls_negative)
        .mul(has_object.clone())
        .mean()
        .mul_scalar(CLASS_WEIGHT);

    // Box centre: weighted higher than size for coordinate precision.
    let centre_pred = burn::tensor::activation::sigmoid(fields(&pred, 0, 2));
    let centre_term = masked_squared_error(centre_pred, fields(&truth, 0, 2), has_object.clone())
        .mul_scalar(XY_WEIGHT);

    // Box size.
    let size_pred = burn::tensor::activation::sigmoid(fields(&pred, 2, 2));
    let size_term = masked_squared_error(size_pred, fields(&truth, 2, 2), has_object)
        .mul_scalar(WH_WEIGHT);

    objectness_term
        .add(class_term)
        .add(centre_term)
        .add(size_term)
}