"""
Smoke test for src/training/losses.py and src/training/metrics.py.
Run from project root:
.\\venv\\Scripts\\python.exe scripts\\test_training_utils.py
"""
from __future__ import annotations
import sys
from pathlib import Path
import numpy as np
import torch
def _ensure_on_sys_path(directory: Path) -> None:
    """Prepend *directory* to ``sys.path`` unless its string form is already listed."""
    entry = str(directory)
    if entry not in sys.path:
        sys.path.insert(0, entry)


# Project root: two levels above this file (scripts/<this file> -> root).
# Putting it on sys.path lets the ``src.*`` imports below resolve when the
# script is run directly.
ROOT = Path(__file__).resolve().parent.parent
_ensure_on_sys_path(ROOT)
from src.training.losses import FocalLoss, WeightedBCEWithLogitsLoss, get_loss_function # noqa: E402
from src.training.metrics import ( # noqa: E402
compute_metrics,
find_optimal_thresholds,
format_metrics_table,
)
def _check(condition: bool, message: str) -> None:
    """Raise ``AssertionError(message)`` when *condition* is false.

    An explicit raise is used instead of ``assert`` so the sanity checks
    still run when the interpreter is started with ``-O`` (which strips
    ``assert`` statements).
    """
    if not condition:
        raise AssertionError(message)


def _check_loss(name: str, loss: torch.Tensor) -> None:
    """Fail unless *loss* converts to a finite, non-negative float."""
    value = float(loss)
    _check(bool(np.isfinite(value)), f"{name} is not finite: {value}")
    # Standard weighted-BCE and focal losses are non-negative, so a negative
    # value here points at a bug in the loss implementation.
    _check(value >= 0.0, f"{name} is negative: {value}")


def main() -> None:
    """Smoke-test the training losses and metrics on seeded synthetic data.

    Computes each loss, both constructed directly and built through
    ``get_loss_function``. Then runs ``compute_metrics``,
    ``find_optimal_thresholds`` and ``format_metrics_table`` on the same
    predictions. Every result is printed and given a cheap plausibility
    check. ``AssertionError`` is raised on the first failed check, so the
    final "passed" line only appears when all checks hold.
    """
    torch.manual_seed(7)
    np.random.seed(7)

    label_names = [
        "Membrane",
        "Cytoplasm",
        "Nucleus",
        "Extracellular",
        "Cell membrane",
        "Mitochondrion",
        "Plastid",
        "Endoplasmic reticulum",
        "Lysosome/Vacuole",
        "Golgi apparatus",
        "Peroxisome",
    ]
    n_samples = 64
    # Derived from label_names so the label count and the names cannot drift apart.
    n_labels = len(label_names)

    # Synthetic logits / multilabel targets. Each target cell is positive
    # with probability ~0.3 (uniform sample > 0.7).
    logits = torch.randn(n_samples, n_labels, dtype=torch.float32)
    y_true = (torch.rand(n_samples, n_labels) > 0.7).float()
    # Per-label weights passed as ``pos_weights``, linearly spaced 1.0 -> 3.0.
    pos_weights = torch.linspace(1.0, 3.0, n_labels)

    print("=== Losses smoke test ===")
    bce_loss_fn = WeightedBCEWithLogitsLoss(pos_weights)
    focal_loss_fn = FocalLoss(gamma=2.0, alpha=0.25)
    factory_bce = get_loss_function("bce", pos_weights=pos_weights, device="cpu")
    factory_focal = get_loss_function("focal", pos_weights=pos_weights, device="cpu")

    # Evaluate every variant first, then report them in a fixed order.
    losses = [
        ("Weighted BCE loss", bce_loss_fn(logits, y_true)),
        ("Focal loss", focal_loss_fn(logits, y_true)),
        ("Factory BCE loss", factory_bce(logits, y_true)),
        ("Factory focal", factory_focal(logits, y_true)),
    ]
    for loss_name, loss in losses:
        print(f"{loss_name}: {float(loss):.6f}")
        _check_loss(loss_name, loss)

    # Metrics inputs: probabilities, integer targets, and predictions
    # binarized at a fixed 0.5 threshold.
    y_proba = torch.sigmoid(logits).numpy()
    y_true_np = y_true.numpy().astype(np.int64)
    y_bin = (y_proba >= 0.5).astype(np.int64)

    print("\n=== Metrics smoke test ===")
    metrics = compute_metrics(
        y_true=y_true_np,
        y_pred_proba=y_proba,
        y_pred_binary=y_bin,
        label_names=label_names,
    )
    print(
        f"macro_f1={metrics['macro_f1']:.4f}, micro_f1={metrics['micro_f1']:.4f}, "
        f"subset_accuracy={metrics['subset_accuracy']:.4f}, hamming_loss={metrics['hamming_loss']:.4f}"
    )
    # F1 scores, subset accuracy and Hamming loss are all fractions in [0, 1].
    # The chained comparison is also False for NaN, so NaN fails too.
    for key in ("macro_f1", "micro_f1", "subset_accuracy", "hamming_loss"):
        value = float(metrics[key])
        _check(0.0 <= value <= 1.0, f"metrics[{key!r}] = {value} is outside [0, 1]")

    thresholds = find_optimal_thresholds(
        y_true=y_true_np,
        y_pred_proba=y_proba,
        label_names=label_names,
    )
    # NOTE(review): assumes one threshold per label, keyed by label name.
    # The report line below already indexes thresholds['Membrane'].
    missing = [name for name in label_names if name not in thresholds]
    _check(not missing, f"find_optimal_thresholds returned no threshold for: {missing}")
    print(f"Found thresholds for {len(thresholds)} labels. Example: Membrane={thresholds['Membrane']:.2f}")

    table = format_metrics_table(metrics, label_names)
    print("\n" + table)
    print("\nOK - losses/metrics smoke test passed.")


if __name__ == "__main__":
    main()