""" Accuracy and per-feature metric tracking (F1, AP) for training and evaluation. """ from typing import Dict, List, Optional import torch def compute_accuracy( logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor, ) -> torch.Tensor: """Compute masked accuracy.""" predictions = logits.argmax(dim=-1) # (B, N) correct = (predictions == targets) & mask accuracy = correct.sum().float() / mask.sum().float() return accuracy try: from torchmetrics.classification import ( MulticlassAccuracy, MulticlassF1Score, MulticlassAveragePrecision, ) TORCHMETRICS_AVAILABLE = True except ImportError: TORCHMETRICS_AVAILABLE = False class FeatureMetrics: """ Tracks precision, recall, F1, and accuracy for each feature using torchmetrics. Uses macro-averaging across classes, which is better for imbalanced data as it gives equal weight to all classes regardless of frequency. """ # Skip AP for features with vocab > this threshold (too memory-intensive) AP_MAX_VOCAB_SIZE = 100 def __init__( self, feature_configs, device: torch.device, top_k: int = 3, ): self.feature_names = [f.name for f in feature_configs] self.vocab_sizes = {f.name: f.vocab_size for f in feature_configs} self.device = device self.top_k = top_k if not TORCHMETRICS_AVAILABLE: self.metrics = None return # Create metrics for each feature # AP is memory-heavy (stores all predictions), so skip for large vocab features self.metrics = {} for feat in feature_configs: num_classes = feat.vocab_size self.metrics[feat.name] = { 'accuracy': MulticlassAccuracy( num_classes=num_classes, average='micro' ).to(device), 'f1_macro': MulticlassF1Score( num_classes=num_classes, average='macro' ).to(device), } # Only add AP for small vocab features (memory constraint) # AP stores all predictions which is ~O(N * C) memory if num_classes <= self.AP_MAX_VOCAB_SIZE: self.metrics[feat.name]['ap_macro'] = MulticlassAveragePrecision( num_classes=num_classes, average='macro' ).to(device) # Add top-k accuracy for large vocab features (duration, etc.) if num_classes > 10: k = min(top_k, num_classes) self.metrics[feat.name]['accuracy_top5'] = MulticlassAccuracy( num_classes=num_classes, average='micro', top_k=k ).to(device) def reset(self): """Reset all metrics for new epoch.""" if self.metrics is None: return for feat_metrics in self.metrics.values(): for metric in feat_metrics.values(): metric.reset() def update( self, outputs: Dict[str, torch.Tensor], targets: Dict[str, torch.Tensor], mask: torch.Tensor, ): """Update metrics with batch predictions.""" if self.metrics is None: return # Flatten mask mask_flat = mask.view(-1).bool() for feat_name in self.feature_names: if feat_name not in outputs: continue # Get predictions and targets, apply mask logits = outputs[feat_name] # (B, N, C) tgt = targets[feat_name] # (B, N) # Flatten and mask B, N, C = logits.shape logits_flat = logits.view(-1, C) # (B*N, C) tgt_flat = tgt.view(-1) # (B*N,) # Select only valid (unmasked) positions valid_logits = logits_flat[mask_flat] # (num_valid, C) valid_targets = tgt_flat[mask_flat] # (num_valid,) if valid_targets.numel() == 0: continue # Update all metrics for this feature for metric in self.metrics[feat_name].values(): metric.update(valid_logits, valid_targets) def compute(self) -> Dict[str, Dict[str, float]]: """Compute final metrics for all features.""" if self.metrics is None: return {} results = {} for feat_name, feat_metrics in self.metrics.items(): results[feat_name] = {} for metric_name, metric in feat_metrics.items(): try: results[feat_name][metric_name] = metric.compute().item() except Exception: results[feat_name][metric_name] = 0.0 return results def compute_summary(self) -> Dict[str, float]: """Compute summary metrics averaged across features.""" if self.metrics is None: return {} per_feature = self.compute() if not per_feature: return {} summary = {} metric_names = ['accuracy', 'f1_macro', 'ap_macro'] for metric_name in metric_names: values = [per_feature[f].get(metric_name, 0.0) for f in self.feature_names if metric_name in per_feature.get(f, {})] if values: summary[f'avg_{metric_name}'] = sum(values) / len(values) return summary