alverciito
upload safetensors and refactor research files
dbd79bd
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# #
# This file was created by: Alberto Palomo Alonso #
# Universidad de Alcalá - Escuela Politécnica Superior #
# #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import torch
from src.model import WindowDiffLoss
def precision_recall_f1_wd(
        y_hat: torch.Tensor,
        y: torch.Tensor,
        y_mask: torch.Tensor,
        wd_object: object = None,
) -> tuple[float, float, float, float]:
    """
    Computes precision, recall, F1 score and window diff for binary
    predictions with masking.

    Args:
        y_hat: Binary predictions (same shape as y). Bool tensors are used
            as-is; numeric tensors are converted with ``.bool()``, so any
            nonzero value counts as a positive prediction. Raw logits or
            probabilities must be thresholded by the caller beforehand.
        y: Ground-truth binary labels (same shape as y_hat); nonzero values
            count as positive.
        y_mask: Boolean or {0,1} mask indicating valid elements. Elements
            where the mask is False/0 are excluded from tp/fp/fn counts.
        wd_object: A callable invoked as ``wd_object(y_hat, y, y_mask)``
            whose result must support ``.item()`` (e.g. a one-element
            tensor). When None, ``WindowDiffLoss(k=6)`` is used.

    Returns:
        (precision, recall, f1, window_diff). Precision, recall and F1 are
        0.0 whenever their denominator is zero.
    """
    # Convert every input to bool so the bitwise operators below act as
    # logical operators. y_hat is treated exactly like y: using it raw
    # fails for floating-point tensors and miscounts integer values other
    # than 0/1 (e.g. 2 & True == 0).
    mask = y_mask.bool()
    preds = y_hat.bool() & mask
    targets = y.bool() & mask

    tp = (preds & targets).sum().item()
    fp = (preds & ~targets).sum().item()
    # ~preds is True at masked-out positions as well, but targets is
    # already False there, so those positions never count as misses.
    fn = (~preds & targets).sum().item()

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = (
        2 * precision * recall / (precision + recall)
        if (precision + recall) > 0
        else 0.0
    )

    # Window diff: the default metric is only built when the caller does
    # not supply one.
    if wd_object is None:
        wd_object = WindowDiffLoss(k=6)
    # The raw y_hat (not the bool-converted preds) is forwarded, as before.
    wd = wd_object(y_hat, y, y_mask).item()
    return precision, recall, f1, wd
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# END OF FILE #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #