use burn::tensor::activation::sigmoid;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

/// Composite YOLO-style detection loss: weighted BCE objectness, object-masked
/// BCE class loss, and object-masked MSE on the box coordinates (despite the
/// name, no IoU term is computed).
pub fn diou_loss<B: Backend>(
    bboxes_pred: Tensor<B, 4>,
    target: Tensor<B, 4>,
) -> Tensor<B, 1> {
    // 1. Reshape to separate anchors: [Batch, 3, 10, H, W]
    let [batch, _channels, h, w] = bboxes_pred.dims();
    let bp = bboxes_pred.reshape([batch, 3, 10, h, w]);
    let t = target.reshape([batch, 3, 10, h, w]);

    // 2. Loss constants
    let eps = 1e-6;

    // 3. Objectness loss (BCE), channel 4
    let obj_pred = sigmoid(bp.clone().narrow(2, 4, 1));
    let obj_target = t.clone().narrow(2, 4, 1);
    let pos_loss = obj_target
        .clone()
        .mul(obj_pred.clone().add_scalar(eps).log())
        .neg();
    let neg_loss = obj_target
        .clone()
        .neg()
        .add_scalar(1.0)
        .mul(obj_pred.clone().neg().add_scalar(1.0 + eps).log())
        .neg();
    // Weight positive samples heavily, since object cells are sparse
    let obj_loss = pos_loss.mul_scalar(40.0).add(neg_loss).mean();

    // 4. Class loss (full BCE for all 5 channels)
    // bp channels 5-9: Dart, Cal1, Cal2, Cal3, Cal4
    let cls_pred = sigmoid(bp.clone().narrow(2, 5, 5));
    let cls_target = t.clone().narrow(2, 5, 5);
    let cls_pos_loss = cls_target
        .clone()
        .mul(cls_pred.clone().add_scalar(eps).log())
        .neg();
    let cls_neg_loss = cls_target
        .clone()
        .neg()
        .add_scalar(1.0)
        .mul(cls_pred.clone().neg().add_scalar(1.0 + eps).log())
        .neg();
    let class_loss = cls_pos_loss
        .add(cls_neg_loss)
        .mul(obj_target.clone()) // Mask class loss where there is no object
        .mean()
        .mul_scalar(15.0);

    // 5. Box (XYWH) loss (MSE), masked by objectness
    let b_xy_pred = sigmoid(bp.clone().narrow(2, 0, 2));
    let b_xy_target = t.clone().narrow(2, 0, 2);
    let xy_loss = b_xy_pred
        .sub(b_xy_target)
        .powf_scalar(2.0)
        .mul(obj_target.clone())
        .mean()
        .mul_scalar(30.0); // Increased weight for coordinate precision

    let b_wh_pred = sigmoid(bp.clone().narrow(2, 2, 2));
    let b_wh_target = t.clone().narrow(2, 2, 2);
    let wh_loss = b_wh_pred
        .sub(b_wh_target)
        .powf_scalar(2.0)
        .mul(obj_target)
        .mean()
        .mul_scalar(5.0);

    obj_loss.add(class_loss).add(xy_loss).add(wh_loss)
}
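
// Minimal usage sketch, not part of the original source: it assumes the
// NdArray backend (the `ndarray` feature of the `burn` crate) is available and
// simply checks that the loss on random inputs is finite. Swap in whatever
// backend your project actually uses.
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use burn::tensor::Distribution;

    #[test]
    fn loss_is_finite_on_random_input() {
        let device = Default::default();
        // 3 anchors * 10 channels (x, y, w, h, obj, 5 classes) = 30 channels.
        let pred = Tensor::<NdArray, 4>::random([2, 30, 13, 13], Distribution::Default, &device);
        let target =
            Tensor::<NdArray, 4>::random([2, 30, 13, 13], Distribution::Uniform(0.0, 1.0), &device);

        let loss = diou_loss(pred, target);
        assert!(loss.into_scalar().is_finite());
    }
}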