# score-ae/src/utils/metrics.py
"""
Accuracy and per-feature metric tracking (F1, AP) for training and evaluation.
"""
from typing import Dict

import torch

try:
    from torchmetrics.classification import (
        MulticlassAccuracy,
        MulticlassF1Score,
        MulticlassAveragePrecision,
    )

    TORCHMETRICS_AVAILABLE = True
except ImportError:
    TORCHMETRICS_AVAILABLE = False


def compute_accuracy(
    logits: torch.Tensor,
    targets: torch.Tensor,
    mask: torch.Tensor,
) -> torch.Tensor:
    """Compute accuracy over the positions selected by ``mask``.

    Args:
        logits: (B, N, C) class logits.
        targets: (B, N) integer class indices.
        mask: (B, N) boolean mask; only True positions are counted.
    """
    predictions = logits.argmax(dim=-1)  # (B, N)
    correct = (predictions == targets) & mask
    accuracy = correct.sum().float() / mask.sum().float()
    return accuracy
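
# ---------------------------------------------------------------------------
# Illustrative sketch only (not used anywhere in the project): a tiny worked
# example of the masking semantics of compute_accuracy above. The helper name
# `_demo_masked_accuracy` is ours, not part of the original API.
# ---------------------------------------------------------------------------
def _demo_masked_accuracy() -> None:
    """Masked positions are excluded from both numerator and denominator."""
    logits = torch.tensor([[[2.0, 0.0], [0.0, 2.0], [2.0, 0.0]]])  # (B=1, N=3, C=2)
    targets = torch.tensor([[0, 1, 1]])                            # (1, 3)
    mask = torch.tensor([[True, True, False]])                     # last position is padding
    # argmax predictions are [0, 1, 0]; positions 0 and 1 match, position 2 is
    # wrong but masked out, so accuracy = 2 correct / 2 valid = 1.0.
    acc = compute_accuracy(logits, targets, mask)
    assert torch.isclose(acc, torch.tensor(1.0))
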

class FeatureMetrics:
    """
    Tracks accuracy, macro F1, macro average precision (small vocabularies
    only), and top-k accuracy (large vocabularies only) for each feature,
    using torchmetrics.

    Macro-averaging across classes suits imbalanced data: every class gets
    equal weight regardless of how frequently it occurs.
    """
# Skip AP for features with vocab > this threshold (too memory-intensive)
AP_MAX_VOCAB_SIZE = 100

    def __init__(
        self,
        feature_configs,
        device: torch.device,
        top_k: int = 3,
    ):
        """
        Args:
            feature_configs: sequence of feature configs; each must expose
                ``name`` and ``vocab_size`` attributes.
            device: device the torchmetrics objects are moved to.
            top_k: k used for top-k accuracy on large-vocabulary features.
        """
        self.feature_names = [f.name for f in feature_configs]
        self.vocab_sizes = {f.name: f.vocab_size for f in feature_configs}
        self.device = device
        self.top_k = top_k
        if not TORCHMETRICS_AVAILABLE:
            # Without torchmetrics installed, tracking degrades to a no-op:
            # update() does nothing and compute() returns empty dicts.
            self.metrics = None
            return
# Create metrics for each feature
# AP is memory-heavy (stores all predictions), so skip for large vocab features
self.metrics = {}
for feat in feature_configs:
num_classes = feat.vocab_size
self.metrics[feat.name] = {
'accuracy': MulticlassAccuracy(
num_classes=num_classes, average='micro'
).to(device),
'f1_macro': MulticlassF1Score(
num_classes=num_classes, average='macro'
).to(device),
}
# Only add AP for small vocab features (memory constraint)
# AP stores all predictions which is ~O(N * C) memory
if num_classes <= self.AP_MAX_VOCAB_SIZE:
self.metrics[feat.name]['ap_macro'] = MulticlassAveragePrecision(
num_classes=num_classes, average='macro'
).to(device)
            # Add top-k accuracy for large-vocabulary features (duration, etc.);
            # the key records the actual k used (min(top_k, num_classes)).
            if num_classes > 10:
                k = min(top_k, num_classes)
                self.metrics[feat.name][f'accuracy_top{k}'] = MulticlassAccuracy(
                    num_classes=num_classes, average='micro', top_k=k
                ).to(device)

    def reset(self):
        """Reset all metrics for a new epoch."""
if self.metrics is None:
return
for feat_metrics in self.metrics.values():
for metric in feat_metrics.values():
metric.reset()

    def update(
        self,
        outputs: Dict[str, torch.Tensor],
        targets: Dict[str, torch.Tensor],
        mask: torch.Tensor,
    ):
        """Update metrics with one batch of predictions.

        Args:
            outputs: per-feature logits, each of shape (B, N, vocab_size).
            targets: per-feature class indices, each of shape (B, N).
            mask: (B, N) validity mask; only unmasked positions are scored.
        """
if self.metrics is None:
return
# Flatten mask
mask_flat = mask.view(-1).bool()
for feat_name in self.feature_names:
if feat_name not in outputs:
continue
# Get predictions and targets, apply mask
logits = outputs[feat_name] # (B, N, C)
tgt = targets[feat_name] # (B, N)
# Flatten and mask
            # Flatten and mask
            C = logits.shape[-1]
            logits_flat = logits.view(-1, C)  # (B*N, C)
            tgt_flat = tgt.view(-1)           # (B*N,)
# Select only valid (unmasked) positions
valid_logits = logits_flat[mask_flat] # (num_valid, C)
valid_targets = tgt_flat[mask_flat] # (num_valid,)
if valid_targets.numel() == 0:
continue
# Update all metrics for this feature
for metric in self.metrics[feat_name].values():
metric.update(valid_logits, valid_targets)
def compute(self) -> Dict[str, Dict[str, float]]:
"""Compute final metrics for all features."""
if self.metrics is None:
return {}
results = {}
for feat_name, feat_metrics in self.metrics.items():
results[feat_name] = {}
for metric_name, metric in feat_metrics.items():
                try:
                    results[feat_name][metric_name] = metric.compute().item()
                except Exception:
                    # compute() can fail (e.g. a metric that never received
                    # updates); report 0.0 rather than aborting evaluation.
                    results[feat_name][metric_name] = 0.0
return results
def compute_summary(self) -> Dict[str, float]:
"""Compute summary metrics averaged across features."""
if self.metrics is None:
return {}
per_feature = self.compute()
if not per_feature:
return {}
summary = {}
metric_names = ['accuracy', 'f1_macro', 'ap_macro']
for metric_name in metric_names:
            values = [per_feature[f][metric_name] for f in self.feature_names
                      if metric_name in per_feature.get(f, {})]
if values:
summary[f'avg_{metric_name}'] = sum(values) / len(values)
return summary
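
# ---------------------------------------------------------------------------
# Minimal usage sketch, run only when this module is executed directly.
# `_FeatureConfig` is a stand-in for the project's real feature config class;
# FeatureMetrics only relies on its `name` and `vocab_size` attributes, and
# the feature names and sizes below are made up for illustration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from collections import namedtuple

    _FeatureConfig = namedtuple("_FeatureConfig", ["name", "vocab_size"])
    configs = [
        _FeatureConfig("pitch", 90),      # vocab <= 100, so AP is tracked
        _FeatureConfig("duration", 200),  # vocab > 100, so AP is skipped
    ]

    tracker = FeatureMetrics(configs, device=torch.device("cpu"), top_k=3)
    tracker.reset()

    # One fake batch: B sequences of N positions, with roughly 20% padding.
    B, N = 2, 8
    mask = torch.rand(B, N) > 0.2
    outputs = {f.name: torch.randn(B, N, f.vocab_size) for f in configs}
    targets = {f.name: torch.randint(0, f.vocab_size, (B, N)) for f in configs}

    tracker.update(outputs, targets, mask)
    print(tracker.compute())          # per-feature metrics
    print(tracker.compute_summary())  # averaged across features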