Spaces:
Sleeping
Sleeping
| """ | |
| Accuracy and per-feature metric tracking (F1, AP) for training and evaluation. | |
| """ | |
| from typing import Dict, List, Optional | |
| import torch | |
def compute_accuracy(
    logits: torch.Tensor,
    targets: torch.Tensor,
    mask: torch.Tensor,
) -> torch.Tensor:
    """Compute accuracy over masked (valid) positions only.

    Args:
        logits: ``(B, N, C)`` class scores.
        targets: ``(B, N)`` integer class labels.
        mask: ``(B, N)`` boolean mask; ``True`` marks positions to score.

    Returns:
        Scalar tensor with the fraction of correct predictions among the
        masked positions, or ``0.0`` when the mask selects no positions.
    """
    predictions = logits.argmax(dim=-1)  # (B, N)
    correct = (predictions == targets) & mask
    num_valid = mask.sum()
    if num_valid == 0:
        # Guard against 0/0 -> NaN when a batch has no valid positions.
        return torch.zeros((), device=logits.device)
    return correct.sum().float() / num_valid.float()
# torchmetrics is an optional dependency: FeatureMetrics degrades to a
# no-op when it is missing, so training can run without it installed.
try:
    from torchmetrics.classification import (
        MulticlassAccuracy,
        MulticlassF1Score,
        MulticlassAveragePrecision,
    )
    TORCHMETRICS_AVAILABLE = True
except ImportError:
    TORCHMETRICS_AVAILABLE = False
class FeatureMetrics:
    """
    Tracks accuracy, macro-F1, and (for small vocabularies) macro average
    precision for each feature using torchmetrics.

    Macro-averaging gives equal weight to every class regardless of
    frequency, which is better for imbalanced data. If torchmetrics is
    not installed, all methods are no-ops and ``compute*`` returns ``{}``.
    """

    # Skip AP for features with vocab > this threshold: AP stores every
    # prediction (~O(N * C) memory), which is too costly for large vocabs.
    AP_MAX_VOCAB_SIZE = 100

    def __init__(
        self,
        feature_configs,
        device: torch.device,
        top_k: int = 3,
    ):
        """
        Args:
            feature_configs: iterable of objects exposing ``.name`` and
                ``.vocab_size`` (project feature specs).
            device: device on which to keep the metric state.
            top_k: k used for the top-k accuracy added to features with
                more than 10 classes (capped at the vocab size).
        """
        # Materialize once: a generator input would otherwise be exhausted
        # by the comprehensions below before the metric-building loop runs.
        feature_configs = list(feature_configs)
        self.feature_names = [cfg.name for cfg in feature_configs]
        self.vocab_sizes = {cfg.name: cfg.vocab_size for cfg in feature_configs}
        self.device = device
        self.top_k = top_k

        if not TORCHMETRICS_AVAILABLE:
            # Sentinel checked by every public method below.
            self.metrics = None
            return

        self.metrics = {}
        for cfg in feature_configs:
            num_classes = cfg.vocab_size
            self.metrics[cfg.name] = {
                'accuracy': MulticlassAccuracy(
                    num_classes=num_classes, average='micro'
                ).to(device),
                'f1_macro': MulticlassF1Score(
                    num_classes=num_classes, average='macro'
                ).to(device),
            }
            # AP is memory-heavy (stores all predictions), so only add it
            # for small-vocab features.
            if num_classes <= self.AP_MAX_VOCAB_SIZE:
                self.metrics[cfg.name]['ap_macro'] = MulticlassAveragePrecision(
                    num_classes=num_classes, average='macro'
                ).to(device)
            # Top-k accuracy for large-vocab features (duration, etc.).
            if num_classes > 10:
                k = min(top_k, num_classes)
                # BUG FIX: the key was hard-coded 'accuracy_top5' even
                # though top_k defaults to 3; label with the actual k.
                self.metrics[cfg.name][f'accuracy_top{k}'] = MulticlassAccuracy(
                    num_classes=num_classes, average='micro', top_k=k
                ).to(device)

    def reset(self):
        """Reset all metric state (call at the start of each epoch)."""
        if self.metrics is None:
            return
        for feat_metrics in self.metrics.values():
            for metric in feat_metrics.values():
                metric.reset()

    def update(
        self,
        outputs: Dict[str, torch.Tensor],
        targets: Dict[str, torch.Tensor],
        mask: torch.Tensor,
    ):
        """Accumulate one batch of predictions.

        Args:
            outputs: feature name -> ``(B, N, C)`` logits; features absent
                from this dict are skipped.
            targets: feature name -> ``(B, N)`` integer labels.
            mask: ``(B, N)`` mask; nonzero marks valid positions.
        """
        if self.metrics is None:
            return
        mask_flat = mask.view(-1).bool()
        for feat_name in self.feature_names:
            if feat_name not in outputs:
                continue
            logits = outputs[feat_name]  # (B, N, C)
            tgt = targets[feat_name]     # (B, N)
            num_classes = logits.shape[-1]
            # Flatten and keep only the valid (unmasked) positions.
            valid_logits = logits.view(-1, num_classes)[mask_flat]
            valid_targets = tgt.view(-1)[mask_flat]
            if valid_targets.numel() == 0:
                continue
            for metric in self.metrics[feat_name].values():
                metric.update(valid_logits, valid_targets)

    def compute(self) -> Dict[str, Dict[str, float]]:
        """Return ``{feature: {metric_name: value}}`` for all metrics.

        A metric whose ``compute()`` raises (e.g. no updates were made
        this epoch) reports 0.0 instead of propagating the error.
        """
        if self.metrics is None:
            return {}
        results = {}
        for feat_name, feat_metrics in self.metrics.items():
            results[feat_name] = {}
            for metric_name, metric in feat_metrics.items():
                try:
                    results[feat_name][metric_name] = metric.compute().item()
                except Exception:
                    # torchmetrics raises when computing with no state;
                    # treat that as a zero score rather than a crash.
                    results[feat_name][metric_name] = 0.0
        return results

    def compute_summary(self) -> Dict[str, float]:
        """Average accuracy / f1_macro / ap_macro across features.

        Returns ``avg_<metric>`` keys; a metric tracked by no feature
        (e.g. AP when every vocab is large) is omitted entirely.
        """
        if self.metrics is None:
            return {}
        per_feature = self.compute()
        if not per_feature:
            return {}
        summary = {}
        for metric_name in ('accuracy', 'f1_macro', 'ap_macro'):
            values = [
                per_feature[name][metric_name]
                for name in self.feature_names
                if metric_name in per_feature.get(name, {})
            ]
            if values:
                summary[f'avg_{metric_name}'] = sum(values) / len(values)
        return summary