File size: 7,355 Bytes
"""
Unified benchmark for DPA experiments.

Modes:
    1. simulate — simulate DPA with attention masking on pretrained models (CPU/MPS OK)
    2. train — train DPA from scratch (needs GPU)
    3. finetune — finetune pretrained model with DPA wrapper (8xH100)

Only ``simulate`` is executed by this script; ``train`` and ``finetune``
print a pointer to ``scripts/run_dpa.sh``.
"""
import argparse
import json
import time
import torch
import torch.nn as nn
from pathlib import Path
from dataclasses import dataclass
import sys
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from src.models.baselines import build_model
from src.models.dpa_model import DPATransformer
from src.eval.metrics import compute_flops, compute_metrics
@dataclass
class BenchmarkResult:
    """One row of benchmark output describing a single evaluated model variant."""

    # Label of the variant; DPA runs are suffixed with the target ratio
    # (e.g. "dpa_r5%").
    model_type: str
    # Task accuracy; the simulation run stores a 0.0 placeholder here.
    accuracy: float
    # exp(loss) of the forward pass, or 0 when the model returned no loss.
    perplexity: float
    # Decision ratio reported by the model, or a fallback chosen by the caller.
    decision_ratio: float
    # Estimated FLOPs relative to the full transformer.
    flops_ratio: float
    # Wall-clock duration of one forward pass, in milliseconds.
    latency_ms: float
    # KV-cache footprint in MB; the simulation run stores 0 here.
    kv_cache_mb: float
    # Total number of parameter elements in the model.
    num_params: int
def count_params(model):
    """Return the total element count over everything ``model.parameters()`` yields."""
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total
def estimate_flops(model_type, seq_len, hidden_size, num_layers, num_heads, decision_ratio=1.0):
    """Estimate attention FLOPs for one of the benchmarked model families.

    Per-layer costs used:
      * full attention:   4 * seq_len^2 * hidden_size
      * linear attention: 4 * seq_len * hidden_size * (hidden_size // num_heads)

    ``decision_ratio`` is only consulted for model types whose name starts
    with "dpa"; there it scales the full-attention term added on top of the
    linear term. Unrecognised model types yield 0.
    """
    quadratic_cost = 4 * seq_len * seq_len * hidden_size
    head_dim = hidden_size // num_heads  # floor division, as in the original estimate
    linear_cost = 4 * seq_len * hidden_size * head_dim

    if model_type == "full_transformer":
        return quadratic_cost * num_layers

    if model_type == "pure_linear":
        return linear_cost * num_layers

    if model_type == "uniform_hybrid":
        # One layer in four (rounded down) uses full attention, the rest linear.
        quadratic_layers = num_layers // 4
        remaining_layers = num_layers - quadratic_layers
        return quadratic_cost * quadratic_layers + linear_cost * remaining_layers

    if model_type.startswith("dpa"):
        # Linear attention for all tokens plus full attention for the
        # decision-point fraction.
        per_layer = linear_cost + decision_ratio * quadratic_cost
        return per_layer * num_layers

    return 0
def _synchronize(device):
    """Wait for queued accelerator work on ``device`` so wall-clock timing is meaningful."""
    if device == "cuda":
        torch.cuda.synchronize()
    elif device == "mps" and hasattr(torch, "mps"):
        torch.mps.synchronize()


def _perplexity_from_loss(loss):
    """Return ``exp(loss)`` as a float, or 0.0 when the model reported no loss."""
    # Compare against None explicitly: truth-testing a tensor would treat a
    # loss of exactly 0.0 as "missing" and report perplexity 0 instead of 1.
    if loss is None:
        return 0.0
    return torch.exp(loss).item()


def _evaluate_model(model, label, model_type, target_ratio, batch, labels, cfg, full_flops, device):
    """Run one timed forward pass and package the measurements.

    Args:
        model: Already-built model, placed on ``device``.
        label: Value stored in ``BenchmarkResult.model_type``.
        model_type: Family name passed to ``estimate_flops``.
        target_ratio: Configured decision ratio for DPA variants, or ``None``
            for baselines.
        batch, labels: Input token ids and labels passed to the model.
        cfg: Model configuration dict (``hidden_size``, ``num_layers``,
            ``num_heads``, ``max_seq_len``).
        full_flops: FLOPs estimate of the full transformer, used as the
            denominator of ``flops_ratio``.
        device: Device string the model and data live on.

    Returns:
        A populated ``BenchmarkResult``.
    """
    model.eval()
    with torch.no_grad():
        # Accelerator kernels are launched asynchronously; synchronise on
        # both sides of the call so the timer covers the actual compute.
        _synchronize(device)
        start = time.perf_counter()
        outputs = model(batch, labels=labels)
        _synchronize(device)
        latency_ms = (time.perf_counter() - start) * 1000

    if target_ratio is not None:
        fallback_ratio = target_ratio
    else:
        fallback_ratio = 1.0 if model_type == "full_transformer" else 0.0
    reported_ratio = outputs.get("avg_decision_ratio")
    # float() turns a 0-dim tensor into a plain number so the result stays
    # JSON-serialisable.
    decision_ratio = float(fallback_ratio if reported_ratio is None else reported_ratio)

    # DPA variants are costed at their configured target ratio; baselines
    # pass the measured/fallback ratio, which estimate_flops ignores for them.
    flops = estimate_flops(
        model_type, cfg["max_seq_len"], cfg["hidden_size"],
        cfg["num_layers"], cfg["num_heads"],
        target_ratio if target_ratio is not None else decision_ratio,
    )

    result = BenchmarkResult(
        model_type=label,
        accuracy=0.0,  # filled in real eval
        perplexity=_perplexity_from_loss(outputs["loss"]),
        decision_ratio=decision_ratio,
        flops_ratio=flops / full_flops if full_flops else 0.0,
        latency_ms=latency_ms,
        kv_cache_mb=0,  # not measured in simulation mode
        num_params=count_params(model),
    )
    print(f" {result.model_type}: ppl={result.perplexity:.2f}, "
          f"flops_ratio={result.flops_ratio:.2%}, "
          f"latency={result.latency_ms:.1f}ms, "
          f"params={result.num_params:,}")
    return result


def run_simulation(args):
    """Run DPA simulation experiment (no training, attention masking only).

    Builds every baseline once and every DPA variant once per target ratio,
    runs a single forward pass of each on random token ids, and records
    perplexity, estimated FLOPs relative to the full transformer, latency and
    parameter count.

    Args:
        args: Parsed CLI namespace; reads ``hidden_size``, ``num_layers``,
            ``num_heads``, ``seq_len``, ``batch_size`` and ``output_dir``.

    Returns:
        List of ``BenchmarkResult``, also written to
        ``<output_dir>/simulation_results.json``.
    """
    print("=" * 60)
    print("DPA Simulation Experiment")
    print("=" * 60)

    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Device: {device}")

    model_types = ["full_transformer", "pure_linear", "uniform_hybrid", "dpa", "dpa_fixed"]
    ratios = [0.05, 0.10, 0.15, 0.25, 0.50]
    results = []

    # Common config
    cfg = dict(
        vocab_size=32000, hidden_size=args.hidden_size,
        num_layers=args.num_layers, num_heads=args.num_heads,
        max_seq_len=args.seq_len,
    )

    # Generate random data for simulation
    batch = torch.randint(0, cfg["vocab_size"], (args.batch_size, args.seq_len), device=device)
    labels = torch.randint(0, cfg["vocab_size"], (args.batch_size, args.seq_len), device=device)

    # The full-transformer baseline cost is the same for every model, so
    # compute it once instead of per iteration.
    full_flops = estimate_flops(
        "full_transformer", args.seq_len, cfg["hidden_size"],
        cfg["num_layers"], cfg["num_heads"],
    )

    for model_type in model_types:
        if model_type in ("dpa", "dpa_fixed"):
            for ratio in ratios:
                model = build_model(model_type, target_ratio=ratio, **cfg).to(device)
                results.append(_evaluate_model(
                    model, f"{model_type}_r{ratio:.0%}", model_type, ratio,
                    batch, labels, cfg, full_flops, device,
                ))
                del model  # drop the reference before building the next variant
        else:
            model = build_model(model_type, **cfg).to(device)
            results.append(_evaluate_model(
                model, model_type, model_type, None,
                batch, labels, cfg, full_flops, device,
            ))
            del model

    # Save results
    output_path = Path(args.output_dir) / "simulation_results.json"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump([vars(r) for r in results], f, indent=2)
    print(f"\nResults saved to {output_path}")
    return results
def main():
    """Parse command-line options and dispatch on ``--mode``.

    ``simulate`` calls ``run_simulation``; ``train`` and ``finetune`` only
    print a hint pointing at the cluster launch script.
    """
    cli = argparse.ArgumentParser(description="DPA Benchmark")
    cli.add_argument("--mode", choices=["simulate", "train", "finetune"], default="simulate")

    # Integer-valued model/data options, registered in their original order.
    int_options = (
        ("--hidden-size", 512),
        ("--num-layers", 6),
        ("--num-heads", 8),
        ("--seq-len", 1024),
        ("--batch-size", 4),
    )
    for flag, default_value in int_options:
        cli.add_argument(flag, type=int, default=default_value)
    cli.add_argument("--output-dir", type=str, default="results")

    args = cli.parse_args()

    if args.mode == "simulate":
        run_simulation(args)
        return

    hints = {
        "train": "Training mode — requires GPU. Use scripts/run_dpa.sh on Merlin.",
        "finetune": "Finetune mode — requires 8xH100. Use scripts/run_dpa.sh on Merlin.",
    }
    hint = hints.get(args.mode)
    if hint is not None:
        print(hint)


if __name__ == "__main__":
    main()
|