| | """ |
| | Training profiler utilities for identifying bottlenecks. |
| | |
| | Uses PyTorch profiler to analyze training performance. |
| | """ |
| |
|
| | import logging |
| | from pathlib import Path |
| | from typing import Any, Dict, Optional |
| | import torch |
| | from torch.profiler import ( |
| | ProfilerActivity, |
| | profile, |
| | record_function, |
| | schedule, |
| | tensorboard_trace_handler, |
| | ) |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
class TrainingProfiler:
    """
    Profiler for training loops.

    Thin wrapper around ``torch.profiler.profile`` driven by a step schedule.
    Use it either explicitly (``start()`` / ``step()`` / ``stop()``) or as a
    context manager, calling ``step()`` once per training iteration.

    When ``output_dir`` is given, traces are emitted through
    ``tensorboard_trace_handler`` and a plain-text operator summary
    (``profiler_summary.txt``) is written on ``stop()``.
    """

    def __init__(
        self,
        output_dir: Optional[Path] = None,
        activities: Optional[list] = None,
        record_shapes: bool = True,
        profile_memory: bool = True,
        with_stack: bool = False,
        wait: int = 1,
        warmup: int = 1,
        active: int = 3,
        repeat: int = 2,
    ):
        """
        Args:
            output_dir: Directory to save profiling results (created if missing).
                When omitted, nothing is written to disk.
            activities: Activities to profile (default: CPU, plus CUDA when
                ``torch.cuda.is_available()``)
            record_shapes: Record tensor shapes
            profile_memory: Profile memory usage
            with_stack: Record stack traces
            wait: Schedule steps to skip at the start of each cycle
            warmup: Schedule warm-up steps per cycle
            active: Schedule steps recorded per cycle
            repeat: Number of schedule cycles
        """
        self.output_dir = Path(output_dir) if output_dir else None
        if self.output_dir:
            self.output_dir.mkdir(parents=True, exist_ok=True)

        if activities is None:
            # Only request CUDA when it can actually be profiled, matching the
            # availability guard used by profile_training_step.
            activities = [ProfilerActivity.CPU]
            if torch.cuda.is_available():
                activities.insert(0, ProfilerActivity.CUDA)

        self.activities = activities
        self.record_shapes = record_shapes
        self.profile_memory = profile_memory
        self.with_stack = with_stack

        # Schedule parameters, forwarded to torch.profiler.schedule in start().
        self.wait = wait
        self.warmup = warmup
        self.active = active
        self.repeat = repeat

        self.profiler = None
        # True between a successful start() and the matching stop().
        self._running = False

        self.trace_handler = (
            tensorboard_trace_handler(str(self.output_dir)) if self.output_dir else None
        )

    def start(self):
        """Start profiling. A second call while already running is ignored."""
        if self._running:
            logger.warning("Profiling already started; ignoring start()")
            return

        self.profiler = profile(
            activities=self.activities,
            schedule=schedule(
                wait=self.wait,
                warmup=self.warmup,
                active=self.active,
                repeat=self.repeat,
            ),
            record_shapes=self.record_shapes,
            profile_memory=self.profile_memory,
            with_stack=self.with_stack,
            on_trace_ready=self.trace_handler,
        )

        self.profiler.start()
        self._running = True
        logger.info("Profiling started")

    def stop(self):
        """
        Stop profiling and generate report.

        Safe to call when profiling was never started or was already stopped;
        in both cases it does nothing. ``self.profiler`` is kept after stopping
        so callers can still inspect the collected results.
        """
        if self.profiler is None or not self._running:
            return

        # Clear the flag first so a failure below cannot cause a second stop.
        self._running = False
        self.profiler.stop()

        if self.output_dir:
            self._write_summary()

        logger.info("Profiling stopped")

    def _write_summary(self):
        """Write the aggregated operator table to ``profiler_summary.txt``."""
        # Sort by CUDA time only when CUDA was actually among the profiled
        # activities; otherwise every CUDA total would be zero.
        cuda_profiled = (
            ProfilerActivity.CUDA in self.activities and torch.cuda.is_available()
        )
        sort_key = "cuda_time_total" if cuda_profiled else "cpu_time_total"

        try:
            table = self.profiler.key_averages().table(sort_by=sort_key, row_limit=100)
        except (AssertionError, RuntimeError) as err:
            # NOTE(review): key_averages() appears to fail this way when no
            # trace was collected (e.g. stop() before the schedule reached a
            # recording step). Treat it as "nothing to summarize" so that
            # stop()/__exit__ does not raise.
            logger.warning("No profiler summary written: %s", err)
            return

        summary_path = self.output_dir / "profiler_summary.txt"
        summary_path.write_text(table, encoding="utf-8")
        logger.info("Profiler summary saved to %s", summary_path)

    def step(self):
        """Step profiler (call at each training step). No-op unless running."""
        if self.profiler and self._running:
            self.profiler.step()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returns None, so an exception raised inside the ``with`` body propagates.
        self.stop()
| |
|
| |
|
def _event_time_us(event, use_cuda: bool) -> float:
    """Return an aggregated profiler event's total time in microseconds."""
    if not use_cuda:
        return event.cpu_time_total
    # Prefer ``device_time_total`` where the installed PyTorch provides it
    # (it supersedes ``cuda_time_total`` in recent releases); fall back to the
    # older attribute otherwise.
    device_time = getattr(event, "device_time_total", None)
    return event.cuda_time_total if device_time is None else device_time


def profile_training_step(
    model: torch.nn.Module,
    loss_fn: callable,
    optimizer: torch.optim.Optimizer,
    sample_batch: Dict,
    device: str = "cuda",
    output_dir: Optional[Path] = None,
) -> Dict[str, Any]:
    """
    Profile a single training step.

    Runs one forward pass, backward pass and optimizer step under
    ``torch.profiler.profile``. This is a real training step: the model's
    parameters are updated and gradients are cleared afterwards.

    Args:
        model: Model to profile
        loss_fn: Loss function, called as ``loss_fn(output, targets)``
        optimizer: Optimizer
        sample_batch: Sample batch of data; must provide ``"images"`` and
            ``"targets"`` entries that support ``.to(device)``
        device: Device to run on (e.g. ``"cpu"``, ``"cuda"``, ``"cuda:0"``)
        output_dir: Directory to save results (created if missing). When given,
            the operator table is written to ``profiler_table.txt``.

    Returns:
        Dict with profiling results: ``forward_time_ms``, ``backward_time_ms``,
        ``optimizer_time_ms`` and their sum ``total_time_ms`` (milliseconds),
        plus ``memory_allocated_mb`` / ``memory_reserved_mb`` (current CUDA
        allocator figures after the step; 0 when CUDA is not used).
    """
    # Single source of truth for "are we measuring CUDA": the requested device
    # is a CUDA device (with or without an index) AND CUDA is usable.
    use_cuda = torch.device(device).type == "cuda" and torch.cuda.is_available()

    activities = [ProfilerActivity.CPU]
    if use_cuda:
        activities.append(ProfilerActivity.CUDA)

    with profile(
        activities=activities,
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    ) as prof:
        with record_function("forward"):
            output = model(sample_batch["images"].to(device))
            loss = loss_fn(output, sample_batch["targets"].to(device))

        with record_function("backward"):
            loss.backward()

        with record_function("optimizer_step"):
            optimizer.step()
            optimizer.zero_grad()

    results = {
        "forward_time_ms": 0,
        "backward_time_ms": 0,
        "optimizer_time_ms": 0,
        "total_time_ms": 0,
        "memory_allocated_mb": 0,
        "memory_reserved_mb": 0,
    }

    # Match the record_function labels exactly. Substring matching would also
    # pick up operator events such as ``aten::*_backward`` and count time that
    # is already inside the labelled range a second time.
    label_to_result = {
        "forward": "forward_time_ms",
        "backward": "backward_time_ms",
        "optimizer_step": "optimizer_time_ms",
    }

    key_averages = prof.key_averages()
    for event in key_averages:
        result_key = label_to_result.get(event.key)
        if result_key is not None:
            # Profiler totals are in microseconds on every device; convert to
            # milliseconds unconditionally.
            results[result_key] += _event_time_us(event, use_cuda) / 1000

    results["total_time_ms"] = (
        results["forward_time_ms"]
        + results["backward_time_ms"]
        + results["optimizer_time_ms"]
    )

    if use_cuda:
        cuda_device = torch.device(device)
        results["memory_allocated_mb"] = (
            torch.cuda.memory_allocated(cuda_device) / 1024 / 1024
        )
        results["memory_reserved_mb"] = (
            torch.cuda.memory_reserved(cuda_device) / 1024 / 1024
        )

    if output_dir:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        table_path = output_dir / "profiler_table.txt"
        table_path.write_text(
            key_averages.table(
                sort_by="cuda_time_total" if use_cuda else "cpu_time_total",
                row_limit=50,
            ),
            encoding="utf-8",
        )

        logger.info("Profiling results saved to %s", output_dir)

    return results
| |
|
| |
|
def analyze_bottlenecks(profiler_output: str) -> Dict[str, Any]:
    """
    Analyze profiler output to identify bottlenecks.

    Scans the text line by line with case-insensitive keyword heuristics and
    collects a recommendation for each heuristic that matches. Each
    recommendation appears at most once, in order of first match, no matter
    how many lines trigger it.

    Args:
        profiler_output: Profiler table output as string

    Returns:
        Dict with bottleneck analysis. Only ``"recommendations"`` is populated
        here; ``"slowest_operations"`` and ``"memory_hotspots"`` are returned
        as empty lists.
    """
    bottlenecks: Dict[str, Any] = {
        "slowest_operations": [],
        "memory_hotspots": [],
        "recommendations": [],
    }
    recommendations = bottlenecks["recommendations"]

    def recommend(message: str) -> None:
        # A profiler table typically has many rows matching the same keyword;
        # report each piece of advice only once.
        if message not in recommendations:
            recommendations.append(message)

    for raw_line in profiler_output.splitlines():
        line = raw_line.lower()  # lower-case once per line, not once per test

        if "forward" in line and "backward" not in line:
            recommend("Consider gradient checkpointing for forward pass")
        if "data_loader" in line or "dataloader" in line:
            recommend("Data loading may be a bottleneck - increase num_workers")
        if "memory" in line and "high" in line:
            recommend(
                "High memory usage - consider gradient checkpointing or smaller batch size"
            )

    return bottlenecks
| |
|