| """ |
| Unified benchmark for DPA experiments. |
| |
| Modes: |
| 1. simulate — simulate DPA with attention masking on pretrained models (CPU/MPS OK) |
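| 2. train — train DPA from scratch (needs GPU); this script only prints a pointer to scripts/run_dpa.sh |
| 3. finetune — finetune pretrained model with DPA wrapper (8xH100); this script only prints a pointer to scripts/run_dpa.sh |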
| """ |
|
|
| import argparse |
| import json |
| import time |
| import torch |
| import torch.nn as nn |
| from pathlib import Path |
| from dataclasses import dataclass |
|
|
| import sys |
| sys.path.insert(0, str(Path(__file__).parent.parent.parent)) |
|
|
| from src.models.baselines import build_model |
| from src.models.dpa_model import DPATransformer |
| from src.eval.metrics import compute_flops, compute_metrics |
|
|
|
|
@dataclass
class BenchmarkResult:
    """Measurements recorded for one benchmarked model configuration.

    Instances are plain records: the benchmark fills one per model run and
    serialises them to JSON via ``vars()``.
    """

    # Label identifying the configuration that was measured.
    model_type: str

    # Quality metrics.
    accuracy: float
    perplexity: float

    # Cost metrics.
    decision_ratio: float
    flops_ratio: float
    latency_ms: float
    kv_cache_mb: float

    # Model size (count of parameter elements).
    num_params: int
|
|
|
|
def count_params(model):
    """Return the total element count over everything ``model.parameters()`` yields."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
|
|
|
|
def estimate_flops(model_type, seq_len, hidden_size, num_layers, num_heads, decision_ratio=1.0):
    """Estimate attention FLOPs for a stack of layers of the given model type.

    Two per-layer costs are combined depending on ``model_type``:

    * quadratic attention: ``4 * seq_len**2 * hidden_size``
    * linear attention: ``4 * seq_len * hidden_size * (hidden_size // num_heads)``

    Args:
        model_type: ``"full_transformer"``, ``"pure_linear"``,
            ``"uniform_hybrid"``, or any name starting with ``"dpa"``.
        seq_len: Sequence length.
        hidden_size: Model width.
        num_layers: Number of layers.
        num_heads: Number of attention heads (used for the per-head width).
        decision_ratio: Weight applied to the quadratic cost for ``dpa*``
            types; ignored for every other type.

    Returns:
        The estimated FLOP count, or ``0`` when ``model_type`` is not
        recognised.
    """
    head_dim = hidden_size // num_heads
    quadratic_cost = 4 * seq_len * seq_len * hidden_size
    linear_cost = 4 * seq_len * hidden_size * head_dim

    if model_type == "full_transformer":
        return quadratic_cost * num_layers

    if model_type == "pure_linear":
        return linear_cost * num_layers

    if model_type == "uniform_hybrid":
        # One layer in four (rounded down) uses quadratic attention.
        quadratic_layers = num_layers // 4
        remaining_layers = num_layers - quadratic_layers
        return quadratic_cost * quadratic_layers + linear_cost * remaining_layers

    if model_type.startswith("dpa"):
        # Every layer pays the linear cost plus a decision_ratio share of
        # the quadratic cost.
        per_layer = linear_cost + decision_ratio * quadratic_cost
        return per_layer * num_layers

    return 0
|
|
|
|
def _sync_device(device):
    """Block until work queued on *device* has finished.

    CUDA and MPS execute kernels asynchronously, so a wall-clock timer around
    a forward pass can otherwise stop before the computation has completed.
    """
    if device == "cuda":
        torch.cuda.synchronize()
    elif device == "mps" and hasattr(torch, "mps"):
        # torch.mps is only present on newer PyTorch builds.
        torch.mps.synchronize()


def _benchmark_model(model, model_type, label, batch, labels, device,
                     attn_dims, fallback_ratio, full_flops):
    """Time one no-grad forward pass of *model* and return a BenchmarkResult.

    Args:
        model: Module already placed on *device*. It is called as
            ``model(batch, labels=labels)`` and must return a dict-like
            object supporting ``.get``.
        model_type: Architecture key forwarded to ``estimate_flops``.
        label: Name stored in the result's ``model_type`` field.
        batch: Input token tensor.
        labels: Target token tensor.
        device: Device string, used to pick the right synchronisation call.
        attn_dims: ``(seq_len, hidden_size, num_layers, num_heads)``.
        fallback_ratio: Decision ratio used for the FLOPs estimate, and
            reported when the model does not return ``"avg_decision_ratio"``.
        full_flops: FLOPs of the full-transformer baseline (the denominator
            of ``flops_ratio``).
    """
    model.eval()
    with torch.no_grad():
        _sync_device(device)
        start = time.perf_counter()
        outputs = model(batch, labels=labels)
        _sync_device(device)
        latency_ms = (time.perf_counter() - start) * 1000

    # Compare against None explicitly: tensor truthiness would treat a loss
    # of exactly 0 as "missing" and raises for multi-element tensors.
    loss = outputs.get("loss")
    perplexity = torch.exp(loss).item() if loss is not None else 0.0

    # The model may hand back a tensor; coerce to a plain float so the result
    # can be serialised with json.dump.
    measured_ratio = outputs.get("avg_decision_ratio")
    decision_ratio = float(measured_ratio) if measured_ratio is not None else fallback_ratio

    # estimate_flops only consults the ratio for "dpa*" types, where the
    # estimate is based on the configured target ratio.
    flops = estimate_flops(model_type, *attn_dims, fallback_ratio)

    return BenchmarkResult(
        model_type=label,
        accuracy=0.0,       # not measured by the simulation
        perplexity=perplexity,
        decision_ratio=decision_ratio,
        flops_ratio=flops / full_flops,
        latency_ms=latency_ms,
        kv_cache_mb=0.0,    # not measured by the simulation
        num_params=count_params(model),
    )


def run_simulation(args):
    """Run the DPA simulation benchmark (single forward pass, no training).

    Builds every baseline model plus each DPA variant at several target
    ratios, runs one forward pass on a random token batch, prints a summary
    line per configuration, and writes all results to
    ``<output_dir>/simulation_results.json``.

    Args:
        args: Namespace providing ``hidden_size``, ``num_layers``,
            ``num_heads``, ``seq_len``, ``batch_size`` and ``output_dir``.

    Returns:
        List of ``BenchmarkResult``, in the order the models were run.
    """
    print("=" * 60)
    print("DPA Simulation Experiment")
    print("=" * 60)

    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Device: {device}")

    model_types = ["full_transformer", "pure_linear", "uniform_hybrid", "dpa", "dpa_fixed"]
    ratios = [0.05, 0.10, 0.15, 0.25, 0.50]

    cfg = dict(
        vocab_size=32000, hidden_size=args.hidden_size,
        num_layers=args.num_layers, num_heads=args.num_heads,
        max_seq_len=args.seq_len,
    )
    attn_dims = (args.seq_len, cfg["hidden_size"], cfg["num_layers"], cfg["num_heads"])

    # The baseline cost is the same for every run, so compute it once.
    full_flops = estimate_flops("full_transformer", *attn_dims)

    batch = torch.randint(0, cfg["vocab_size"], (args.batch_size, args.seq_len), device=device)
    labels = torch.randint(0, cfg["vocab_size"], (args.batch_size, args.seq_len), device=device)

    # Expand the model list into (label, model_type, extra build kwargs,
    # fallback decision ratio) so that one loop handles every configuration.
    runs = []
    for model_type in model_types:
        if model_type in ("dpa", "dpa_fixed"):
            runs.extend(
                (f"{model_type}_r{ratio:.0%}", model_type, {"target_ratio": ratio}, ratio)
                for ratio in ratios
            )
        else:
            fallback = 1.0 if model_type == "full_transformer" else 0.0
            runs.append((model_type, model_type, {}, fallback))

    results = []
    for label, model_type, build_kwargs, fallback_ratio in runs:
        model = build_model(model_type, **build_kwargs, **cfg).to(device)
        result = _benchmark_model(
            model, model_type, label, batch, labels, device,
            attn_dims, fallback_ratio, full_flops,
        )
        results.append(result)
        print(f"  {result.model_type}: ppl={result.perplexity:.2f}, "
              f"flops_ratio={result.flops_ratio:.2%}, "
              f"latency={result.latency_ms:.1f}ms, "
              f"params={result.num_params:,}")
        # Drop the reference before the next model is built to limit peak memory.
        del model

    output_path = Path(args.output_dir) / "simulation_results.json"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump([vars(r) for r in results], f, indent=2)
    print(f"\nResults saved to {output_path}")

    return results
|
|
|
|
def main():
    """Parse command-line options and dispatch on ``--mode``.

    ``simulate`` runs the simulation benchmark; ``train`` and ``finetune``
    only print a pointer to the external launch script.
    """
    cli = argparse.ArgumentParser(description="DPA Benchmark")
    cli.add_argument("--mode", choices=["simulate", "train", "finetune"], default="simulate")

    # Integer-valued model/data dimensions, registered in help-output order.
    int_options = (
        ("--hidden-size", 512),
        ("--num-layers", 6),
        ("--num-heads", 8),
        ("--seq-len", 1024),
        ("--batch-size", 4),
    )
    for flag, default_value in int_options:
        cli.add_argument(flag, type=int, default=default_value)

    cli.add_argument("--output-dir", type=str, default="results")
    options = cli.parse_args()

    if options.mode == "simulate":
        run_simulation(options)
        return

    # The remaining modes are not executed here; just tell the user where to go.
    remote_only_notice = {
        "train": "Training mode — requires GPU. Use scripts/run_dpa.sh on Merlin.",
        "finetune": "Finetune mode — requires 8xH100. Use scripts/run_dpa.sh on Merlin.",
    }
    notice = remote_only_notice.get(options.mode)
    if notice is not None:
        print(notice)


if __name__ == "__main__":
    main()
|
|