""" Profile LexiMind training with PyTorch Profiler. Runs a few training steps under torch.profiler to capture: - CUDA kernel timing (per-operator breakdown) - GPU memory usage (peak allocations, memory timeline) - CPU/GPU overlap and idle time - Chrome trace (viewable in chrome://tracing or Perfetto UI) Outputs: outputs/profile/ -- Chrome trace + stacks stdout -- Summary table of top CUDA operations Usage: python scripts/profile_training.py # default: 20 steps python scripts/profile_training.py training=full # use full config PROFILE_STEPS=40 python scripts/profile_training.py # custom step count Author: Oliver Perrin """ from __future__ import annotations import os import sys from pathlib import Path import hydra import torch from omegaconf import DictConfig PROJECT_ROOT = Path(__file__).resolve().parents[1] if str(PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(PROJECT_ROOT)) from src.data.dataloader import ( build_emotion_dataloader, build_summarization_dataloader, build_topic_dataloader, ) from src.data.dataset import ( EmotionDataset, SummarizationDataset, TopicDataset, load_emotion_jsonl, load_summarization_jsonl, load_topic_jsonl, ) from src.data.tokenization import Tokenizer, TokenizerConfig from src.models.factory import ModelConfig, build_multitask_model def load_splits(data_dir: Path, loader_fn): splits = {} for name, aliases in [("train", ["train"]), ("val", ["val", "validation"])]: for alias in aliases: path = data_dir / f"{alias}.jsonl" if path.exists(): splits[name] = loader_fn(str(path)) break return splits @hydra.main(version_base=None, config_path="../configs", config_name="config") def main(cfg: DictConfig) -> None: profile_steps = int(os.environ.get("PROFILE_STEPS", 20)) warmup_steps = 3 # let CUDA graphs / torch.compile settle active_steps = profile_steps - warmup_steps device = torch.device(cfg.device) if device.type != "cuda": print("Profiler requires CUDA. Set device=cuda.") return print(f"Profiling {profile_steps} steps ({warmup_steps} warmup + {active_steps} active)") print(f"GPU: {torch.cuda.get_device_name()}") # ---------- Setup (mirrors train.py) ---------- torch.backends.cudnn.benchmark = True if torch.cuda.get_device_capability()[0] >= 8: torch.set_float32_matmul_precision("high") torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True data_cfg = cfg.data trainer_cfg = cfg.training.get("trainer", {}) # Load small subsets -- profiling doesn't need the full dataset max_samples = max(200, profile_steps * 10 * 3) summ_splits = load_splits(Path(data_cfg.processed.summarization), load_summarization_jsonl) emot_splits = load_splits(Path(data_cfg.processed.emotion), load_emotion_jsonl) topic_splits = load_splits(Path(data_cfg.processed.topic), load_topic_jsonl) for splits in [summ_splits, emot_splits, topic_splits]: splits["train"] = splits["train"][:max_samples] tok_cfg = data_cfg.get("tokenizer", {}) max_len = int(cfg.training.get("tokenizer_max_length") or tok_cfg.get("max_length", 512)) tokenizer = Tokenizer( TokenizerConfig( pretrained_model_name=tok_cfg.get("pretrained_model_name", "google/flan-t5-base"), max_length=max_len, ) ) summ_train = SummarizationDataset(summ_splits["train"]) emot_train = EmotionDataset(emot_splits["train"]) topic_train = TopicDataset(topic_splits["train"]) dl_cfg = cfg.training.get("dataloader", {}) batch_size = int(dl_cfg.get("batch_size", 8)) num_workers = int(dl_cfg.get("num_workers", 4)) classification_max_len = min(256, max_len) train_loaders = { "summarization": build_summarization_dataloader( summ_train, tokenizer, shuffle=True, max_source_length=max_len, max_target_length=max_len, batch_size=batch_size, num_workers=num_workers, pin_memory=True, ), "emotion": build_emotion_dataloader( emot_train, tokenizer, shuffle=True, max_length=classification_max_len, batch_size=batch_size, num_workers=num_workers, pin_memory=True, ), "topic": build_topic_dataloader( topic_train, tokenizer, shuffle=True, max_length=classification_max_len, batch_size=batch_size, num_workers=num_workers, pin_memory=True, ), } # Build model grad_ckpt = cfg.training.get( "gradient_checkpointing", cfg.model.get("gradient_checkpointing", False) ) use_rel_pos = cfg.training.get( "use_relative_position_bias", cfg.model.get("use_relative_position_bias", False) ) model_cfg = ModelConfig( d_model=cfg.model.d_model, vocab_size=getattr(cfg.model, "vocab_size", None), num_encoder_layers=cfg.model.num_encoder_layers, num_decoder_layers=cfg.model.num_decoder_layers, num_attention_heads=cfg.model.num_attention_heads, ffn_dim=cfg.model.ffn_dim, dropout=cfg.model.dropout, use_pretrained=cfg.model.use_pretrained, pretrained_model_name=cfg.model.pretrained_model_name, activation=getattr(cfg.model, "activation", "gelu"), use_relative_position_bias=use_rel_pos, gradient_checkpointing=grad_ckpt, ) model = build_multitask_model( tokenizer, num_emotions=len(emot_train.emotion_classes), num_topics=len(topic_train.topic_classes), config=model_cfg, ).to(device) # Freeze layers (same as train.py) freeze_layers = cfg.training.get("freeze_encoder_layers", 0) if freeze_layers > 0: if hasattr(model.encoder, "embed_tokens"): for p in model.encoder.embed_tokens.parameters(): p.requires_grad = False if hasattr(model.encoder, "layers"): for i, layer in enumerate(model.encoder.layers): if i < freeze_layers: for p in layer.parameters(): p.requires_grad = False # Compile (same as train.py) compile_mode = "default" if grad_ckpt else "reduce-overhead" if cfg.training.get("compile_encoder", True): model.encoder = torch.compile(model.encoder, mode=compile_mode) if cfg.training.get("compile_decoder", True): model.decoder = torch.compile(model.decoder, mode=compile_mode) # Optimizer opt_cfg = cfg.training.get("optimizer", {}) use_fused = "fused" in torch.optim.AdamW.__init__.__code__.co_varnames optimizer = torch.optim.AdamW( model.parameters(), lr=float(opt_cfg.get("lr", 3e-5)), weight_decay=float(opt_cfg.get("weight_decay", 0.01)), fused=use_fused, ) # ---------- Profile loop ---------- out_dir = PROJECT_ROOT / "outputs" / "profile" out_dir.mkdir(parents=True, exist_ok=True) model.train() iterators = {task: iter(loader) for task, loader in train_loaders.items()} task_names = list(train_loaders.keys()) accum = int(trainer_cfg.get("gradient_accumulation_steps", 4)) use_bf16 = torch.cuda.is_bf16_supported() task_weights = trainer_cfg.get("task_weights") or {} emotion_loss_fn = torch.nn.BCEWithLogitsLoss() topic_loss_fn = torch.nn.CrossEntropyLoss() def get_batch(task): try: batch = next(iterators[task]) except StopIteration: iterators[task] = iter(train_loaders[task]) batch = next(iterators[task]) return { k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v for k, v in batch.items() } def training_step(step): """One training step across all tasks.""" for task in task_names: batch = get_batch(task) dtype = torch.bfloat16 if use_bf16 else torch.float16 with torch.autocast("cuda", dtype=dtype): if task == "summarization": inputs = {"src_ids": batch["src_ids"], "tgt_ids": batch["tgt_ids"]} if "src_mask" in batch: inputs["src_mask"] = batch["src_mask"] logits = model.forward("summarization", inputs) loss = torch.nn.functional.cross_entropy( logits.view(-1, logits.size(-1)), batch["labels"].view(-1), ignore_index=-100, label_smoothing=0.1, ) elif task == "emotion": inputs = {"input_ids": batch["input_ids"]} if "attention_mask" in batch: inputs["attention_mask"] = batch["attention_mask"] logits = model.forward("emotion", inputs) loss = emotion_loss_fn(logits, batch["labels"].float()) elif task == "topic": inputs = {"input_ids": batch["input_ids"]} if "attention_mask" in batch: inputs["attention_mask"] = batch["attention_mask"] logits = model.forward("topic", inputs) loss = topic_loss_fn(logits, batch["labels"]) else: continue weight = task_weights.get(task, 1.0) scaled = (loss * weight) / accum scaled.backward() if (step + 1) % accum == 0: torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() optimizer.zero_grad() # Warmup outside profiler to let torch.compile finish print(f"\nWarmup ({warmup_steps} steps)...") for s in range(warmup_steps): training_step(s) optimizer.zero_grad() torch.cuda.synchronize() # Profile print(f"Profiling ({active_steps} steps)...") trace_path = str(out_dir / "trace") with torch.profiler.profile( activities=[ torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA, ], schedule=torch.profiler.schedule( wait=1, warmup=2, active=active_steps - 3, repeat=1, ), on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_path), record_shapes=True, profile_memory=True, with_stack=True, with_flops=True, ) as prof: for s in range(active_steps): training_step(warmup_steps + s) prof.step() torch.cuda.synchronize() # ---------- Summary ---------- print("\n" + "=" * 80) print("TOP CUDA OPERATIONS (by total CUDA time)") print("=" * 80) print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=25)) print("\n" + "=" * 80) print("TOP CUDA OPERATIONS (by GPU memory)") print("=" * 80) print(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=15)) # Memory summary print("\n" + "=" * 80) print("GPU MEMORY SUMMARY") print("=" * 80) print(torch.cuda.memory_summary(abbreviated=True)) # Export Chrome trace chrome_trace = out_dir / "chrome_trace.json" prof.export_chrome_trace(str(chrome_trace)) print(f"\nChrome trace: {chrome_trace}") print(" Open in: chrome://tracing or https://ui.perfetto.dev") # Export stacks for flamegraph stacks_path = out_dir / "profiler_stacks.txt" prof.export_stacks(str(stacks_path), "self_cuda_time_total") print(f"CUDA stacks: {stacks_path}") print(f" Generate flamegraph: flamegraph.pl {stacks_path} > flamegraph.svg") print(f"\nTensorBoard traces: {trace_path}/") print(f" View with: tensorboard --logdir={trace_path}") if __name__ == "__main__": main()