File size: 7,189 Bytes
b5bff9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
"""
Benchmark Generation — Prefill + Decode across Goliath configs
==============================================================

Measures:
  - Prefill throughput (tok/s)
  - Decode throughput (tok/s) and per-token latency (ms)
  - VRAM usage (GB)

Configs tested:
  - Goliath FP4    (goliath_bits=4)
  - Goliath FP8    (goliath_bits=8)
  - Goliath Auto   (goliath_bits='auto')
  - Legacy path    (use_goliath=False)

Context lengths: 512, 2048, 8192

Usage:
    python3 benchmark_generation.py
"""

import gc
import sys
import time
import torch

# Ensure the kernel directory is importable
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from fireecho_kernel import FireEchoConfig, FireEchoEngine, _GOLIATH_AVAILABLE

# Optional capability probe: only importable when the Goliath kernels are
# present. A None sentinel lets main() skip the capability report safely.
if _GOLIATH_AVAILABLE:
    from goliath_kernel import _can_use_goliath_dot_scaled
else:
    _can_use_goliath_dot_scaled = None


# ============================================================================
# Engine Factory
# ============================================================================

def create_bench_engine(goliath_bits=4, use_goliath=True, num_layers=8):
    """Build a 7B-scale FireEcho engine (reduced layer count) for benchmarking.

    Args:
        goliath_bits: Goliath quantization width (4, 8, or 'auto').
        use_goliath: Route matmuls through the Goliath path when True.
        num_layers: Transformer layer count (reduced from full scale to
            keep benchmark setup fast).

    Returns:
        A CUDA-resident FireEchoEngine in eval mode.
    """
    bench_config = FireEchoConfig(
        dim=4096,
        num_heads=32,
        num_kv_heads=8,
        num_layers=num_layers,
        vocab_size=32000,
        intermediate_size=11008,
        max_seq_len=16384,
        max_kv_blocks=1024,
        use_nvfp4=True,
        quantize_weights=True,
        goliath_bits=goliath_bits,
        use_goliath=use_goliath,
        use_hebbian=False,
        use_vision=False,
        use_audio=False,
    )
    # .cuda() and .eval() both return the module, so the chain is equivalent
    # to building, moving, switching mode, then returning.
    return FireEchoEngine(bench_config).cuda().eval()


# ============================================================================
# Benchmark Helpers
# ============================================================================

def bench_prefill(engine, seq_len, warmup=3, iters=5, vocab_size=32000):
    """Benchmark prefill (forward pass on the full prompt).

    Args:
        engine: CUDA-resident engine exposing ``reset_cache()`` and ``__call__``.
        seq_len: Prompt length in tokens.
        warmup: Untimed iterations run first to stabilize kernels/caches.
        iters: Timed iterations; the reported latency is their mean.
        vocab_size: Exclusive upper bound for random token ids. Defaults to
            32000 to match :func:`create_bench_engine`'s vocabulary; pass a
            smaller value when benchmarking an engine with a smaller vocab.

    Returns:
        dict with 'ms' (mean latency of one prefill), 'tok_s' (throughput),
        and 'vram_gb' (currently-allocated CUDA memory in GB — note this is
        the allocation at measurement time, not the peak).
    """
    input_ids = torch.randint(0, vocab_size, (1, seq_len), device='cuda')

    # Warmup: lets lazy kernel compilation / autotuning happen outside timing.
    for _ in range(warmup):
        engine.reset_cache()
        with torch.no_grad():
            _ = engine(input_ids, use_cache=False)
    torch.cuda.synchronize()

    # GPU-side timing via CUDA events (avoids host-sync overhead per iter).
    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)

    start_evt.record()
    for _ in range(iters):
        # reset_cache() is deliberately inside the timed region so every
        # iteration performs a full cold prefill.
        engine.reset_cache()
        with torch.no_grad():
            _ = engine(input_ids, use_cache=False)
    end_evt.record()
    torch.cuda.synchronize()

    ms = start_evt.elapsed_time(end_evt) / iters
    tok_s = seq_len / (ms / 1000.0)
    vram_gb = torch.cuda.memory_allocated() / 1e9

    return {'ms': ms, 'tok_s': tok_s, 'vram_gb': vram_gb}


def bench_decode(engine, prompt_len, num_decode_tokens=50, warmup=2,
                 vocab_size=32000):
    """Benchmark decode (token-by-token generation after prefill).

    Args:
        engine: CUDA-resident engine exposing ``reset_cache()`` and
            ``generate()``.
        prompt_len: Prompt length in tokens.
        num_decode_tokens: New tokens requested in the timed generation.
        warmup: Untimed short generation runs performed first.
        vocab_size: Exclusive upper bound for random token ids (defaults to
            32000, matching :func:`create_bench_engine`'s vocabulary).

    Returns:
        dict with 'total_ms', 'per_token_ms', 'tok_s', 'vram_gb'
        (currently-allocated CUDA memory, not the peak) and 'gen_tokens'
        (tokens actually produced, which may be fewer than requested).
    """
    prompt = torch.randint(0, vocab_size, (1, prompt_len), device='cuda')

    # Warmup with a short generation to stabilize kernels before timing.
    for _ in range(warmup):
        engine.reset_cache()
        with torch.no_grad():
            _ = engine.generate(prompt, max_new_tokens=5, use_cache=False)
    torch.cuda.synchronize()

    gc.collect()
    torch.cuda.empty_cache()

    # NOTE(review): use_cache=False in a *decode* benchmark means each step
    # likely re-processes the whole prompt rather than exercising KV-cached
    # decode — confirm this is intended for the engine under test.
    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)

    engine.reset_cache()
    start_evt.record()
    with torch.no_grad():
        output = engine.generate(prompt, max_new_tokens=num_decode_tokens,
                                 use_cache=False)
    end_evt.record()
    torch.cuda.synchronize()

    gen_tokens = output.shape[1] - prompt_len
    total_ms = start_evt.elapsed_time(end_evt)
    # Guard the divisions: generation may stop early (e.g. EOS) or finish
    # faster than timer resolution.
    per_token_ms = total_ms / max(gen_tokens, 1)
    tok_s = gen_tokens / (total_ms / 1000.0) if total_ms > 0 else 0.0
    vram_gb = torch.cuda.memory_allocated() / 1e9

    return {
        'total_ms': total_ms,
        'per_token_ms': per_token_ms,
        'tok_s': tok_s,
        'vram_gb': vram_gb,
        'gen_tokens': gen_tokens,
    }


# ============================================================================
# Main Benchmark
# ============================================================================

def main():
    """Run prefill and decode benchmarks across every Goliath configuration."""
    if not torch.cuda.is_available():
        print("CUDA not available.")
        return

    gpu = torch.cuda.get_device_properties(0)
    banner = "=" * 85
    print(banner)
    print("GENERATION BENCHMARK — Goliath FP4/FP8 Configs")
    print(banner)
    print(f"GPU: {gpu.name}")
    print(f"VRAM: {gpu.total_memory / 1e9:.1f} GB")
    print(f"Goliath available: {_GOLIATH_AVAILABLE}")
    if _can_use_goliath_dot_scaled is not None:
        print(f"Goliath dot_scaled (native FP4 TCs): {_can_use_goliath_dot_scaled()}")
    print()

    bench_configs = [
        ('Goliath FP4',  dict(goliath_bits=4,      use_goliath=True)),
        ('Goliath FP8',  dict(goliath_bits=8,      use_goliath=True)),
        ('Goliath Auto', dict(goliath_bits='auto', use_goliath=True)),
        ('Legacy path',  dict(goliath_bits=4,      use_goliath=False)),
    ]
    ctx_lengths = [512, 2048, 8192]

    def run_section(title, header, bench_fn, format_row):
        # Shared driver for one benchmark table: for each config, build the
        # engine once, sweep all context lengths, then free GPU memory.
        rule = "-" * 85
        print(rule)
        print(title)
        print(rule)
        print(header)
        print("-" * len(header))
        for cfg_name, cfg_kwargs in bench_configs:
            try:
                engine = create_bench_engine(**cfg_kwargs)
            except Exception as e:
                print(f"{cfg_name:<16} | {'ERROR':>5} | {str(e)[:40]}")
                continue
            for ctx in ctx_lengths:
                try:
                    print(format_row(cfg_name, ctx, bench_fn(engine, ctx)))
                except Exception as e:
                    print(f"{cfg_name:<16} | {ctx:>5} | ERROR: {str(e)[:30]}")
            del engine
            gc.collect()
            torch.cuda.empty_cache()

    run_section(
        "PREFILL BENCHMARK",
        f"{'Config':<16} | {'Ctx':>5} | {'Prefill ms':>11} | {'Prefill tok/s':>14} | {'VRAM GB':>8}",
        bench_prefill,
        lambda name, ctx, r: (
            f"{name:<16} | {ctx:>5} | {r['ms']:>9.2f}ms | {r['tok_s']:>12,.0f} | {r['vram_gb']:>7.2f}"
        ),
    )

    print()
    run_section(
        "DECODE BENCHMARK (50 tokens)",
        f"{'Config':<16} | {'Ctx':>5} | {'Decode tok/s':>13} | {'ms/token':>9} | {'VRAM GB':>8}",
        bench_decode,
        lambda name, ctx, r: (
            f"{name:<16} | {ctx:>5} | {r['tok_s']:>11,.1f} | {r['per_token_ms']:>7.2f}ms | {r['vram_gb']:>7.2f}"
        ),
    )

    print()
    print(banner)
    print("BENCHMARK COMPLETE")
    print(banner)


if __name__ == "__main__":
    main()