""" Unified benchmark for DPA experiments. Modes: 1. simulate — simulate DPA with attention masking on pretrained models (CPU/MPS OK) 2. train — train DPA from scratch (needs GPU) 3. finetune — finetune pretrained model with DPA wrapper (8xH100) """ import argparse import json import time import torch import torch.nn as nn from pathlib import Path from dataclasses import dataclass import sys sys.path.insert(0, str(Path(__file__).parent.parent.parent)) from src.models.baselines import build_model from src.models.dpa_model import DPATransformer from src.eval.metrics import compute_flops, compute_metrics @dataclass class BenchmarkResult: model_type: str accuracy: float perplexity: float decision_ratio: float flops_ratio: float # relative to full transformer latency_ms: float kv_cache_mb: float num_params: int def count_params(model): return sum(p.numel() for p in model.parameters()) def estimate_flops(model_type, seq_len, hidden_size, num_layers, num_heads, decision_ratio=1.0): """Estimate FLOPs for different model types.""" # Full attention FLOPs per layer: 4 * seq_len^2 * hidden_size (QKV + output) full_attn_flops = 4 * seq_len * seq_len * hidden_size # Linear attention FLOPs per layer: 4 * seq_len * hidden_size * (hidden_size / num_heads) linear_attn_flops = 4 * seq_len * hidden_size * (hidden_size // num_heads) if model_type == "full_transformer": return full_attn_flops * num_layers elif model_type == "pure_linear": return linear_attn_flops * num_layers elif model_type == "uniform_hybrid": full_layers = num_layers // 4 linear_layers = num_layers - full_layers return full_attn_flops * full_layers + linear_attn_flops * linear_layers elif model_type.startswith("dpa"): # DPA: decision_ratio of tokens use full attention dpa_flops_per_layer = ( linear_attn_flops + # linear for all tokens decision_ratio * full_attn_flops # full only for decision points ) return dpa_flops_per_layer * num_layers return 0 def run_simulation(args): """Run DPA simulation experiment (no training, attention masking only).""" print("=" * 60) print("DPA Simulation Experiment") print("=" * 60) device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" print(f"Device: {device}") model_types = ["full_transformer", "pure_linear", "uniform_hybrid", "dpa", "dpa_fixed"] ratios = [0.05, 0.10, 0.15, 0.25, 0.50] results = [] # Common config cfg = dict( vocab_size=32000, hidden_size=args.hidden_size, num_layers=args.num_layers, num_heads=args.num_heads, max_seq_len=args.seq_len, ) # Generate random data for simulation batch = torch.randint(0, cfg["vocab_size"], (args.batch_size, args.seq_len), device=device) labels = torch.randint(0, cfg["vocab_size"], (args.batch_size, args.seq_len), device=device) for model_type in model_types: if model_type in ("dpa", "dpa_fixed"): for ratio in ratios: model = build_model(model_type, target_ratio=ratio, **cfg).to(device) model.eval() with torch.no_grad(): t0 = time.time() outputs = model(batch, labels=labels) latency = (time.time() - t0) * 1000 flops = estimate_flops( model_type, args.seq_len, cfg["hidden_size"], cfg["num_layers"], cfg["num_heads"], ratio, ) full_flops = estimate_flops( "full_transformer", args.seq_len, cfg["hidden_size"], cfg["num_layers"], cfg["num_heads"], ) result = BenchmarkResult( model_type=f"{model_type}_r{ratio:.0%}", accuracy=0.0, # filled in real eval perplexity=torch.exp(outputs["loss"]).item() if outputs["loss"] else 0, decision_ratio=outputs.get("avg_decision_ratio", ratio), flops_ratio=flops / full_flops, latency_ms=latency, kv_cache_mb=0, num_params=count_params(model), ) results.append(result) print(f" {result.model_type}: ppl={result.perplexity:.2f}, " f"flops_ratio={result.flops_ratio:.2%}, " f"latency={result.latency_ms:.1f}ms, " f"params={result.num_params:,}") del model else: model = build_model(model_type, **cfg).to(device) model.eval() with torch.no_grad(): t0 = time.time() outputs = model(batch, labels=labels) latency = (time.time() - t0) * 1000 dr = outputs.get("avg_decision_ratio", 1.0 if model_type == "full_transformer" else 0.0) flops = estimate_flops( model_type, args.seq_len, cfg["hidden_size"], cfg["num_layers"], cfg["num_heads"], dr, ) full_flops = estimate_flops( "full_transformer", args.seq_len, cfg["hidden_size"], cfg["num_layers"], cfg["num_heads"], ) result = BenchmarkResult( model_type=model_type, accuracy=0.0, perplexity=torch.exp(outputs["loss"]).item() if outputs["loss"] else 0, decision_ratio=dr, flops_ratio=flops / full_flops, latency_ms=latency, kv_cache_mb=0, num_params=count_params(model), ) results.append(result) print(f" {result.model_type}: ppl={result.perplexity:.2f}, " f"flops_ratio={result.flops_ratio:.2%}, " f"latency={result.latency_ms:.1f}ms, " f"params={result.num_params:,}") del model # Save results output_path = Path(args.output_dir) / "simulation_results.json" output_path.parent.mkdir(parents=True, exist_ok=True) with open(output_path, "w") as f: json.dump([vars(r) for r in results], f, indent=2) print(f"\nResults saved to {output_path}") return results def main(): parser = argparse.ArgumentParser(description="DPA Benchmark") parser.add_argument("--mode", choices=["simulate", "train", "finetune"], default="simulate") parser.add_argument("--hidden-size", type=int, default=512) parser.add_argument("--num-layers", type=int, default=6) parser.add_argument("--num-heads", type=int, default=8) parser.add_argument("--seq-len", type=int, default=1024) parser.add_argument("--batch-size", type=int, default=4) parser.add_argument("--output-dir", type=str, default="results") args = parser.parse_args() if args.mode == "simulate": run_simulation(args) elif args.mode == "train": print("Training mode — requires GPU. Use scripts/run_dpa.sh on Merlin.") elif args.mode == "finetune": print("Finetune mode — requires 8xH100. Use scripts/run_dpa.sh on Merlin.") if __name__ == "__main__": main()