File size: 7,355 Bytes
09dd617
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
"""
Unified benchmark for DPA experiments.

Modes:
  1. simulate — simulate DPA with attention masking on pretrained models (CPU/MPS OK)
  2. train — train DPA from scratch (needs GPU)
  3. finetune — finetune pretrained model with DPA wrapper (8xH100)
"""

import argparse
import json
import time
import torch
import torch.nn as nn
from pathlib import Path
from dataclasses import dataclass

import sys
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from src.models.baselines import build_model
from src.models.dpa_model import DPATransformer
from src.eval.metrics import compute_flops, compute_metrics


@dataclass()
class BenchmarkResult:
    """One benchmarked model configuration.

    Field order matters: it is the positional order of the generated
    ``__init__`` and the key order when an instance is dumped via ``vars()``.
    """

    model_type: str  # label, e.g. "full_transformer" or "dpa_r5%"
    accuracy: float  # left at 0.0 by the simulation run (no task eval)
    perplexity: float  # exp(loss) of the timed forward pass
    decision_ratio: float  # fraction of tokens routed to full attention
    flops_ratio: float  # relative to full transformer
    latency_ms: float  # wall-clock time of the forward pass
    kv_cache_mb: float  # not measured by the simulation run (recorded as 0)
    num_params: int  # total parameter element count


def count_params(model):
    """Return how many scalar values live in ``model.parameters()``."""
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total


def estimate_flops(model_type, seq_len, hidden_size, num_layers, num_heads, decision_ratio=1.0):
    """Estimate sequence-mixing (attention) FLOPs for different model types.

    Only the token-mixing cost is modelled: the QK^T and attention @ V
    matmuls for softmax attention, and the head_dim x head_dim state work
    for linear attention. Q/K/V/output projections and MLPs are not counted.

    Args:
        model_type: "full_transformer", "pure_linear", "uniform_hybrid", or
            any name starting with "dpa".
        seq_len: number of tokens per sequence.
        hidden_size: model width.
        num_layers: number of transformer layers.
        num_heads: attention heads; head_dim is ``hidden_size // num_heads``.
        decision_ratio: fraction of tokens that use full attention. Only read
            for "dpa*" model types.

    Returns:
        Estimated FLOPs summed over all layers.

    Raises:
        ValueError: if ``model_type`` is not one of the types above. This
            replaces silently returning 0, which produced a bogus ratio.
    """
    # Softmax attention per layer: QK^T and A @ V are each
    # seq_len^2 * hidden_size multiply-adds -> 2 matmuls * 2 FLOPs/MAC.
    full_attn_flops = 4 * seq_len * seq_len * hidden_size
    # Linear attention per layer: each head keeps a head_dim x head_dim state,
    # so cost is linear in seq_len: 4 * seq_len * hidden_size * head_dim.
    head_dim = hidden_size // num_heads
    linear_attn_flops = 4 * seq_len * hidden_size * head_dim

    if model_type == "full_transformer":
        return full_attn_flops * num_layers
    if model_type == "pure_linear":
        return linear_attn_flops * num_layers
    if model_type == "uniform_hybrid":
        # num_layers // 4 layers are full attention (none when num_layers < 4).
        full_layers = num_layers // 4
        linear_layers = num_layers - full_layers
        return full_attn_flops * full_layers + linear_attn_flops * linear_layers
    if model_type.startswith("dpa"):
        # Every token pays linear attention; the decision_ratio fraction of
        # tokens additionally attends with full softmax attention.
        dpa_flops_per_layer = linear_attn_flops + decision_ratio * full_attn_flops
        return dpa_flops_per_layer * num_layers
    raise ValueError(f"Unknown model_type for FLOPs estimate: {model_type!r}")


def _synchronize(device):
    """Wait for pending kernels on *device* so host-side timers see finished work."""
    if device == "cuda":
        torch.cuda.synchronize()
    elif device == "mps" and hasattr(torch, "mps"):
        # Guarded in case the installed torch predates the torch.mps module.
        torch.mps.synchronize()


def _timed_forward(model, batch, labels, device, warmup=1):
    """Run ``model(batch, labels=labels)`` without gradients; return (outputs, latency_ms).

    GPU backends execute asynchronously, so the device is synchronized before
    the timer starts and again before it stops; otherwise only kernel launch
    time would be measured. ``warmup`` untimed passes absorb one-off costs
    (kernel selection, allocator growth) that would otherwise be billed to
    whichever model happens to run first.
    """
    with torch.no_grad():
        for _ in range(warmup):
            model(batch, labels=labels)
        _synchronize(device)
        start = time.perf_counter()
        outputs = model(batch, labels=labels)
        _synchronize(device)
        latency_ms = (time.perf_counter() - start) * 1000
    return outputs, latency_ms


def _perplexity(outputs):
    """Return exp(loss) as a Python float, or 0.0 when no loss was reported.

    Tests ``is None`` rather than truthiness: a loss of exactly 0 is a valid
    perplexity of 1, and truth-testing a tensor is fragile.
    """
    loss = outputs.get("loss")
    if loss is None:
        return 0.0
    # float32 (not float64) so this also works on MPS.
    return torch.exp(torch.as_tensor(loss).float()).item()


def _benchmark_model(model_type, label, cfg, batch, labels, device, full_flops, target_ratio=None):
    """Build one model, time a forward pass, print and return its BenchmarkResult.

    Args:
        model_type: name passed to ``build_model`` and ``estimate_flops``.
        label: value stored in ``BenchmarkResult.model_type``.
        cfg: keyword arguments shared by every ``build_model`` call.
        batch: input token ids; its last dimension is the sequence length.
        labels: target ids for the loss.
        device: torch device string the model is moved to.
        full_flops: full-transformer FLOPs, the denominator of ``flops_ratio``.
        target_ratio: DPA routing ratio forwarded to ``build_model``; None for
            the non-DPA baselines.
    """
    build_kwargs = dict(cfg)
    if target_ratio is not None:
        build_kwargs["target_ratio"] = target_ratio
    model = build_model(model_type, **build_kwargs).to(device)
    model.eval()

    outputs, latency_ms = _timed_forward(model, batch, labels, device)

    # float() also unwraps a 0-dim tensor, which json.dump cannot serialize.
    # NOTE(review): assumes avg_decision_ratio, when reported, is a scalar.
    if target_ratio is None:
        fallback = 1.0 if model_type == "full_transformer" else 0.0
        decision_ratio = float(outputs.get("avg_decision_ratio", fallback))
        flops_input_ratio = decision_ratio
    else:
        decision_ratio = float(outputs.get("avg_decision_ratio", target_ratio))
        # DPA FLOPs are estimated at the configured target, not the measured ratio.
        flops_input_ratio = target_ratio

    flops = estimate_flops(
        model_type, batch.shape[-1], cfg["hidden_size"],
        cfg["num_layers"], cfg["num_heads"], flops_input_ratio,
    )

    result = BenchmarkResult(
        model_type=label,
        accuracy=0.0,  # filled in real eval
        perplexity=_perplexity(outputs),
        decision_ratio=decision_ratio,
        flops_ratio=flops / full_flops,
        latency_ms=latency_ms,
        kv_cache_mb=0,
        num_params=count_params(model),
    )
    print(f"  {result.model_type}: ppl={result.perplexity:.2f}, "
          f"flops_ratio={result.flops_ratio:.2%}, "
          f"latency={result.latency_ms:.1f}ms, "
          f"params={result.num_params:,}")
    return result


def run_simulation(args):
    """Run DPA simulation experiment (no training, attention masking only).

    Builds each baseline once and each DPA variant once per routing ratio,
    times one forward pass on random token ids, prints a summary line per
    model, and writes all results to ``<output_dir>/simulation_results.json``.

    Returns:
        list of BenchmarkResult, in the order the models were run.
    """
    print("=" * 60)
    print("DPA Simulation Experiment")
    print("=" * 60)

    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Device: {device}")

    model_types = ["full_transformer", "pure_linear", "uniform_hybrid", "dpa", "dpa_fixed"]
    ratios = [0.05, 0.10, 0.15, 0.25, 0.50]

    # Common config
    cfg = dict(
        vocab_size=32000, hidden_size=args.hidden_size,
        num_layers=args.num_layers, num_heads=args.num_heads,
        max_seq_len=args.seq_len,
    )

    # Random ids/targets: enough for timing and FLOPs, not for model quality.
    batch = torch.randint(0, cfg["vocab_size"], (args.batch_size, args.seq_len), device=device)
    labels = torch.randint(0, cfg["vocab_size"], (args.batch_size, args.seq_len), device=device)

    # The flops_ratio denominator does not depend on the model under test.
    full_flops = estimate_flops(
        "full_transformer", args.seq_len, cfg["hidden_size"],
        cfg["num_layers"], cfg["num_heads"],
    )

    results = []
    for model_type in model_types:
        if model_type in ("dpa", "dpa_fixed"):
            for ratio in ratios:
                results.append(_benchmark_model(
                    model_type, f"{model_type}_r{ratio:.0%}", cfg, batch, labels,
                    device, full_flops, target_ratio=ratio,
                ))
        else:
            results.append(_benchmark_model(
                model_type, model_type, cfg, batch, labels, device, full_flops,
            ))

    # Save results
    output_path = Path(args.output_dir) / "simulation_results.json"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump([vars(r) for r in results], f, indent=2)
    print(f"\nResults saved to {output_path}")

    return results


def main():
    """Parse command-line flags and run the requested benchmark mode.

    Only ``simulate`` does work in-process; the other modes print where to
    launch them instead.
    """
    parser = argparse.ArgumentParser(description="DPA Benchmark")
    parser.add_argument("--mode", choices=["simulate", "train", "finetune"], default="simulate")
    # Integer model/data-shape flags, registered in their original order.
    int_flags = (
        ("--hidden-size", 512),
        ("--num-layers", 6),
        ("--num-heads", 8),
        ("--seq-len", 1024),
        ("--batch-size", 4),
    )
    for flag, default in int_flags:
        parser.add_argument(flag, type=int, default=default)
    parser.add_argument("--output-dir", type=str, default="results")
    args = parser.parse_args()

    # argparse `choices` guarantees mode is "simulate" or one of these keys.
    remote_hints = {
        "train": "Training mode — requires GPU. Use scripts/run_dpa.sh on Merlin.",
        "finetune": "Finetune mode — requires 8xH100. Use scripts/run_dpa.sh on Merlin.",
    }
    if args.mode != "simulate":
        print(remote_hints[args.mode])
        return
    run_simulation(args)


if __name__ == "__main__":
    main()