| """Timing and profiling utilities for training.""" |
|
|
| import time |
| from dataclasses import dataclass |
| from typing import Dict |
|
|
|
|
@dataclass
class TimingStats:
    """Record of three timing measurements, each defaulting to ``0.0``.

    Fields are plain floats and can be set positionally or by keyword through
    the dataclass-generated constructor.
    """

    train_time: float = 0.0
    val_time: float = 0.0
    total_epoch_time: float = 0.0

    def get_summary_dict(self) -> Dict[str, float]:
        """Return the timing fields as a ``{field_name: value}`` dictionary.

        Keys appear in declaration order: ``train_time``, ``val_time``,
        ``total_epoch_time``.
        """
        field_names = ('train_time', 'val_time', 'total_epoch_time')
        return {field: getattr(self, field) for field in field_names}
|
|
|
|
class TimingProfiler:
    """Accumulate elapsed-time measurements for named operations.

    Call :meth:`start` before an operation and :meth:`end` after it. Every
    completed start/end pair adds its duration to a per-name running total
    and increments a per-name measurement count.

    Attributes:
        timers: Start reading for each operation currently being timed. The
            values are ``time.perf_counter()`` readings; they are only
            meaningful as differences, not as wall-clock timestamps.
        accumulated: Total elapsed seconds recorded so far, keyed by name.
        counts: Number of completed measurements, keyed by name.
    """

    def __init__(self):
        self.timers: Dict[str, float] = {}
        self.accumulated: Dict[str, float] = {}
        self.counts: Dict[str, int] = {}

    def start(self, name: str):
        """Start timing the operation ``name``.

        Calling this again for a name that is already running replaces the
        earlier start reading.
        """
        # perf_counter() is monotonic; time.time() can jump backwards or
        # forwards when the system clock is adjusted, corrupting durations.
        self.timers[name] = time.perf_counter()

    def end(self, name: str) -> float:
        """Stop timing ``name`` and record the measurement.

        Args:
            name: Operation name previously passed to :meth:`start`.

        Returns:
            Elapsed seconds since the matching :meth:`start` call, or ``0.0``
            when ``name`` is not currently being timed. In the latter case
            nothing is recorded.
        """
        # Pop the start reading so that a second end() without a fresh
        # start() cannot re-measure from the stale start and double-count.
        started = self.timers.pop(name, None)
        if started is None:
            return 0.0

        elapsed = time.perf_counter() - started
        self.accumulated[name] = self.accumulated.get(name, 0.0) + elapsed
        self.counts[name] = self.counts.get(name, 0) + 1
        return elapsed

    def get_average(self, name: str) -> float:
        """Return the mean recorded duration for ``name``.

        Returns ``0.0`` when no measurement has been recorded for ``name``.
        """
        count = self.counts.get(name, 0)
        if name not in self.accumulated or count == 0:
            return 0.0
        return self.accumulated[name] / count

    def reset(self):
        """Discard all running timers and all recorded measurements."""
        self.timers.clear()
        self.accumulated.clear()
        self.counts.clear()

    def get_summary(self) -> Dict[str, float]:
        """Return a flat summary of every recorded operation.

        For each operation ``name`` the result holds three entries:
        ``"<name>_total"`` (accumulated seconds), ``"<name>_avg"`` (mean
        seconds per measurement) and ``"<name>_count"`` (number of
        measurements, an ``int``).
        """
        summary = {}
        for name, total in self.accumulated.items():
            summary[f"{name}_total"] = total
            summary[f"{name}_avg"] = self.get_average(name)
            summary[f"{name}_count"] = self.counts[name]
        return summary
|
|
|
|
def format_time(seconds: float) -> str:
    """Format a duration in seconds as ``'Y sec'`` or ``'X min Y sec'``.

    Seconds are shown with one decimal place. Durations that display as less
    than 60 seconds (including negative values) use the short ``'Y sec'``
    form; everything else is split into whole minutes and remaining seconds.

    Args:
        seconds: Duration in seconds.

    Returns:
        The formatted duration, e.g. ``'12.3 sec'`` or ``'2 min 5.5 sec'``.
    """
    # Round to the displayed precision BEFORE splitting into minutes and
    # seconds. Otherwise a value such as 59.97 passes the "< 60" test yet
    # prints as "60.0 sec", and 119.97 prints as "1 min 60.0 sec".
    rounded = round(seconds, 1)
    if rounded < 60:
        return f"{rounded:.1f} sec"
    minutes, secs = divmod(rounded, 60)
    return f"{int(minutes)} min {secs:.1f} sec"