| |
| |
| |
| |
| |
| |
| |
| import torch |
| from src.model import WindowDiffLoss |
|
|
|
|
def precision_recall_f1_wd(
    y_hat: torch.Tensor,
    y: torch.Tensor,
    y_mask: torch.Tensor,
    wd_object: object = None,
) -> tuple[float, float, float, float]:
    """
    Compute masked precision, recall, F1 and window diff for binary labels.

    Only positions where ``y_mask`` is truthy contribute to the
    true-positive / false-positive / false-negative counts.

    Args:
        y_hat: Binary predictions, combined element-wise with ``y`` and
            ``y_mask``. Any nonzero value counts as a positive prediction,
            so raw logits or probabilities must be thresholded by the
            caller first.
        y: Ground-truth binary labels; nonzero values count as positives.
        y_mask: Boolean or {0,1} mask indicating valid elements.
        wd_object: Callable invoked as ``wd_object(y_hat, y, y_mask)``; its
            result must support ``.item()``. When None, a
            ``WindowDiffLoss(k=6)`` instance is created for this call.

    Returns:
        (precision, recall, f1, window_diff). Precision, recall and F1 are
        Python floats and fall back to 0.0 when their denominator is zero;
        window_diff is ``.item()`` of the ``wd_object`` result.
    """
    mask = y_mask.bool()
    # Convert predictions to bool exactly like the targets. With a raw
    # float tensor `&` raises; with an int tensor it is bitwise (2 & 1 == 0
    # drops a positive) and `~` yields -1/-2 instead of a logical NOT.
    preds = y_hat.bool() & mask
    targets = y.bool() & mask

    tp = (preds & targets).sum().item()
    fp = (preds & ~targets).sum().item()
    fn = (~preds & targets).sum().item()

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = (
        2 * precision * recall / (precision + recall)
        if (precision + recall) > 0
        else 0.0
    )

    if wd_object is None:
        wd_object = WindowDiffLoss(k=6)
    # The window-diff callable receives the original, unconverted inputs;
    # y_mask is passed through for it to handle.
    wd = wd_object(y_hat, y, y_mask).item()

    return precision, recall, f1, wd
| |
| |
| |
|
|