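# Hugging Face Hub page residue, kept as comments so the file parses:
#   uploader: jasonfan
#   commit message: "Upload folder using huggingface_hub"
#   commit: 09dd617 (verified)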
"""
Unified benchmark for DPA experiments.
Modes:
1. simulate — simulate DPA with attention masking on pretrained models (CPU/MPS OK)
2. train — train DPA from scratch (needs GPU)
3. finetune — finetune pretrained model with DPA wrapper (8xH100)
"""
import argparse
import json
import time
import torch
import torch.nn as nn
from pathlib import Path
from dataclasses import dataclass
import sys
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from src.models.baselines import build_model
from src.models.dpa_model import DPATransformer
from src.eval.metrics import compute_flops, compute_metrics
@dataclass
class BenchmarkResult:
    """One row of benchmark output for a single model configuration.

    Instances are turned into plain dicts with ``vars()`` and written to JSON
    by ``run_simulation``.
    """

    # Label for the row; DPA variants get an ``_r<ratio>`` suffix.
    model_type: str
    # Set to 0.0 by the simulation ("filled in real eval").
    accuracy: float
    # exp(loss) of the forward pass, or 0 when no loss is available.
    perplexity: float
    # Share of tokens that use full attention.
    decision_ratio: float
    # Estimated attention FLOPs relative to the full transformer.
    flops_ratio: float
    # Wall-clock time of one forward pass, in milliseconds.
    latency_ms: float
    # Set to 0 by the simulation.
    kv_cache_mb: float
    # Total element count over the model's parameters.
    num_params: int
def count_params(model):
    """Return the total number of elements over ``model.parameters()``.

    Args:
        model: Any object whose ``parameters()`` yields items with ``numel()``.

    Returns:
        The summed ``numel()`` of every parameter; 0 when there are none.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
def estimate_flops(model_type, seq_len, hidden_size, num_layers, num_heads, decision_ratio=1.0):
    """Estimate attention FLOPs for one of the benchmarked model families.

    Args:
        model_type: "full_transformer", "pure_linear", "uniform_hybrid", or
            any name starting with "dpa".
        seq_len: Sequence length.
        hidden_size: Model width.
        num_layers: Number of layers.
        num_heads: Number of attention heads (sets the per-head width).
        decision_ratio: Share of tokens that use full attention; only read
            for the "dpa*" types.

    Returns:
        The estimated FLOP count, or 0 for an unrecognised ``model_type``.
    """
    # Per-layer cost of quadratic attention: 4 * seq_len^2 * hidden_size.
    quadratic_cost = 4 * seq_len * seq_len * hidden_size
    # Per-layer cost of linear attention: 4 * seq_len * hidden_size * head width.
    linear_cost = 4 * seq_len * hidden_size * (hidden_size // num_heads)

    if model_type == "full_transformer":
        return quadratic_cost * num_layers

    if model_type == "pure_linear":
        return linear_cost * num_layers

    if model_type == "uniform_hybrid":
        # One layer in four is quadratic; the remainder are linear.
        quadratic_layers = num_layers // 4
        remaining_layers = num_layers - quadratic_layers
        return quadratic_cost * quadratic_layers + linear_cost * remaining_layers

    if model_type.startswith("dpa"):
        # Every token pays the linear cost; the decision-point share also
        # pays the quadratic cost.
        per_layer = linear_cost + decision_ratio * quadratic_cost
        return per_layer * num_layers

    return 0
def _sync_device(device):
    """Block until work queued on an asynchronous accelerator has finished."""
    if device == "cuda":
        torch.cuda.synchronize()
    elif device == "mps":
        # torch.mps.synchronize is missing from older PyTorch builds; skip it there.
        mps_sync = getattr(getattr(torch, "mps", None), "synchronize", None)
        if mps_sync is not None:
            mps_sync()


def _benchmark_model(model_type, label, cfg, batch, labels, device, full_flops, target_ratio=None):
    """Build one model, time a single forward pass and return its BenchmarkResult."""
    build_kwargs = dict(cfg)
    if target_ratio is not None:
        build_kwargs["target_ratio"] = target_ratio
    model = build_model(model_type, **build_kwargs).to(device)
    model.eval()

    with torch.no_grad():
        # GPU/MPS kernels launch asynchronously: synchronise on both sides of
        # the call so the timer covers the compute, not just the launch.
        _sync_device(device)
        t0 = time.perf_counter()
        outputs = model(batch, labels=labels)
        _sync_device(device)
        latency = (time.perf_counter() - t0) * 1000

    if target_ratio is not None:
        fallback_ratio = target_ratio
    else:
        fallback_ratio = 1.0 if model_type == "full_transformer" else 0.0
    measured_ratio = outputs.get("avg_decision_ratio", fallback_ratio)
    if measured_ratio is None:
        measured_ratio = fallback_ratio
    # The model may report a 0-dim tensor; a plain float keeps the result
    # JSON-serialisable.
    decision_ratio = float(measured_ratio)

    # DPA variants are costed at their target ratio; the other model types
    # ignore this argument inside estimate_flops.
    flops = estimate_flops(
        model_type, cfg["max_seq_len"], cfg["hidden_size"],
        cfg["num_layers"], cfg["num_heads"],
        target_ratio if target_ratio is not None else decision_ratio,
    )

    # Compare against None explicitly: tensor truthiness would treat a loss of
    # exactly 0 as "no loss".
    loss = outputs["loss"]
    perplexity = torch.exp(loss).item() if loss is not None else 0

    return BenchmarkResult(
        model_type=label,
        accuracy=0.0,  # filled in real eval
        perplexity=perplexity,
        decision_ratio=decision_ratio,
        flops_ratio=flops / full_flops,
        latency_ms=latency,
        kv_cache_mb=0,
        num_params=count_params(model),
    )


def _print_result(result):
    """Print the one-line summary for a finished benchmark row."""
    print(f" {result.model_type}: ppl={result.perplexity:.2f}, "
          f"flops_ratio={result.flops_ratio:.2%}, "
          f"latency={result.latency_ms:.1f}ms, "
          f"params={result.num_params:,}")


def run_simulation(args):
    """Run DPA simulation experiment (no training, attention masking only).

    Builds every baseline once and each DPA variant once per target ratio,
    runs one forward pass on random token data, and records the results.

    Args:
        args: Parsed options; reads ``hidden_size``, ``num_layers``,
            ``num_heads``, ``seq_len``, ``batch_size`` and ``output_dir``.

    Returns:
        The list of ``BenchmarkResult`` rows, which is also written to
        ``<output_dir>/simulation_results.json``.
    """
    print("=" * 60)
    print("DPA Simulation Experiment")
    print("=" * 60)
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Device: {device}")
    model_types = ["full_transformer", "pure_linear", "uniform_hybrid", "dpa", "dpa_fixed"]
    ratios = [0.05, 0.10, 0.15, 0.25, 0.50]
    results = []

    # Common config
    cfg = dict(
        vocab_size=32000, hidden_size=args.hidden_size,
        num_layers=args.num_layers, num_heads=args.num_heads,
        max_seq_len=args.seq_len,
    )

    # Generate random data for simulation
    batch = torch.randint(0, cfg["vocab_size"], (args.batch_size, args.seq_len), device=device)
    labels = torch.randint(0, cfg["vocab_size"], (args.batch_size, args.seq_len), device=device)

    # The full-transformer reference cost is the same for every row.
    full_flops = estimate_flops(
        "full_transformer", args.seq_len, cfg["hidden_size"],
        cfg["num_layers"], cfg["num_heads"],
    )

    for model_type in model_types:
        if model_type in ("dpa", "dpa_fixed"):
            for ratio in ratios:
                result = _benchmark_model(
                    model_type, f"{model_type}_r{ratio:.0%}", cfg,
                    batch, labels, device, full_flops, target_ratio=ratio,
                )
                results.append(result)
                _print_result(result)
        else:
            result = _benchmark_model(
                model_type, model_type, cfg, batch, labels, device, full_flops,
            )
            results.append(result)
            _print_result(result)

    # Save results
    output_path = Path(args.output_dir) / "simulation_results.json"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump([vars(r) for r in results], f, indent=2)
    print(f"\nResults saved to {output_path}")
    return results
def main():
    """Parse the command line and dispatch on ``--mode``.

    "simulate" runs ``run_simulation``; "train" and "finetune" only print a
    pointer to the cluster launch script.
    """
    parser = argparse.ArgumentParser(description="DPA Benchmark")
    parser.add_argument("--mode", choices=["simulate", "train", "finetune"], default="simulate")

    # Integer-valued options, registered in the original order.
    int_options = (
        ("--hidden-size", 512),
        ("--num-layers", 6),
        ("--num-heads", 8),
        ("--seq-len", 1024),
        ("--batch-size", 4),
    )
    for flag, default_value in int_options:
        parser.add_argument(flag, type=int, default=default_value)
    parser.add_argument("--output-dir", type=str, default="results")
    args = parser.parse_args()

    if args.mode == "simulate":
        run_simulation(args)
        return

    # The remaining modes are not run from this script.
    notices = {
        "train": "Training mode — requires GPU. Use scripts/run_dpa.sh on Merlin.",
        "finetune": "Finetune mode — requires 8xH100. Use scripts/run_dpa.sh on Merlin.",
    }
    notice = notices.get(args.mode)
    if notice is not None:
        print(notice)


if __name__ == "__main__":
    main()