""" Smoke test for src/training/losses.py and src/training/metrics.py. Run from project root: .\\venv\\Scripts\\python.exe scripts\\test_training_utils.py """ from __future__ import annotations import sys from pathlib import Path import numpy as np import torch ROOT = Path(__file__).resolve().parent.parent if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) from src.training.losses import FocalLoss, WeightedBCEWithLogitsLoss, get_loss_function # noqa: E402 from src.training.metrics import ( # noqa: E402 compute_metrics, find_optimal_thresholds, format_metrics_table, ) def main() -> None: torch.manual_seed(7) np.random.seed(7) n_samples = 64 n_labels = 11 label_names = [ "Membrane", "Cytoplasm", "Nucleus", "Extracellular", "Cell membrane", "Mitochondrion", "Plastid", "Endoplasmic reticulum", "Lysosome/Vacuole", "Golgi apparatus", "Peroxisome", ] # Synthetic logits / multilabel targets logits = torch.randn(n_samples, n_labels, dtype=torch.float32) y_true = (torch.rand(n_samples, n_labels) > 0.7).float() pos_weights = torch.linspace(1.0, 3.0, n_labels) print("=== Losses smoke test ===") bce_loss_fn = WeightedBCEWithLogitsLoss(pos_weights) focal_loss_fn = FocalLoss(gamma=2.0, alpha=0.25) factory_bce = get_loss_function("bce", pos_weights=pos_weights, device="cpu") factory_focal = get_loss_function("focal", pos_weights=pos_weights, device="cpu") bce_loss = bce_loss_fn(logits, y_true) focal_loss = focal_loss_fn(logits, y_true) bce_loss2 = factory_bce(logits, y_true) focal_loss2 = factory_focal(logits, y_true) print(f"Weighted BCE loss: {float(bce_loss):.6f}") print(f"Focal loss: {float(focal_loss):.6f}") print(f"Factory BCE loss: {float(bce_loss2):.6f}") print(f"Factory focal: {float(focal_loss2):.6f}") # Metrics inputs y_proba = torch.sigmoid(logits).numpy() y_true_np = y_true.numpy().astype(np.int64) y_bin = (y_proba >= 0.5).astype(np.int64) print("\n=== Metrics smoke test ===") metrics = compute_metrics( y_true=y_true_np, y_pred_proba=y_proba, y_pred_binary=y_bin, label_names=label_names, ) print( f"macro_f1={metrics['macro_f1']:.4f}, micro_f1={metrics['micro_f1']:.4f}, " f"subset_accuracy={metrics['subset_accuracy']:.4f}, hamming_loss={metrics['hamming_loss']:.4f}" ) thresholds = find_optimal_thresholds( y_true=y_true_np, y_pred_proba=y_proba, label_names=label_names, ) print(f"Found thresholds for {len(thresholds)} labels. Example: Membrane={thresholds['Membrane']:.2f}") table = format_metrics_table(metrics, label_names) print("\n" + table) print("\nOK - losses/metrics smoke test passed.") if __name__ == "__main__": main()