File size: 20,286 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
import torch
import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from arbitor.kernel import ternary_scale as tscale
from arbitor.kernel.ternary_scale import TernaryScaleTensor, TScaleType, TILE_SIZE, GROUP_SIZES
from arbitor.optim.sign_sgd import SignSGD
from arbitor.components import StickyZoneSTE
from arbitor.config import VOCAB, CTX, SPECIAL_VOCAB
from arbitor.main import ARBModel


# ─── TernaryScaleTensor Tests ───

def test_tscale_shape():
    """Check the output shape of a default ``TernaryScaleTensor(32, 16)``.

    A random ``(2, 10, 32)`` input is pushed through the layer and the result
    is expected to have shape ``(2, 10, 16)``.
    """
    layer = TernaryScaleTensor(32, 16)
    batch = torch.randn(2, 10, 32)
    result = layer(batch)
    expected_shape = (2, 10, 16)
    assert result.shape == expected_shape, f"Shape: {result.shape}"
    print(" PASS test_tscale_shape")


def test_tscale_ternary_output():
    """Every entry returned by ``_get_T()`` must be one of -1, 0 or 1."""
    layer = TernaryScaleTensor(32, 16, threshold=0.05)
    # Collapse the tensor to the set of distinct Python scalars it contains.
    seen_values = set(layer._get_T().detach().flatten().tolist())
    allowed_values = {-1, 0, 1}
    assert seen_values <= allowed_values, f"Non-ternary values in T: {seen_values}"
    print(" PASS test_tscale_ternary_output")


def test_tscale_T64_per_element_s():
    """``dequantize()`` on a T64 ``TernaryScaleTensor(32, 16)`` yields a (16, 32) tensor."""
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T64)
    dense = layer.dequantize()
    expected_shape = (16, 32)
    assert dense.shape == expected_shape, f"Dequantize shape: {dense.shape}"
    print(" PASS test_tscale_T64_per_element_s")


def test_tscale_T32_group_s():
    """T32 layer: ``E`` holds at least one row per output row and dequantizes to (16, 96)."""
    layer = TernaryScaleTensor(96, 16, tscale_type=TScaleType.T32)
    dense = layer.dequantize()
    # Number of rows of E available per output row (integer division).
    groups_per_row = layer.E.shape[0] // layer.out_dim
    assert groups_per_row > 0, f"Groups per row: {groups_per_row}"
    expected_shape = (16, 96)
    assert dense.shape == expected_shape, f"Dequantize shape: {dense.shape}"
    print(" PASS test_tscale_T32_group_s")


def test_tscale_to_switching():
    """``tscale_to`` updates ``tscale_type`` and keeps the dequantized shape stable.

    Starts from T64, then switches to T32 and finally to T4, checking after
    each switch that the reported type changed and that ``dequantize()`` still
    returns a tensor of the original shape.
    """
    layer = TernaryScaleTensor(96, 16, tscale_type=TScaleType.T64)
    reference_shape = layer.dequantize().shape
    assert layer.tscale_type == TScaleType.T64

    for target in (TScaleType.T32, TScaleType.T4):
        layer.tscale_to(target)
        assert layer.tscale_type == target
        assert layer.dequantize().shape == reference_shape
    print(" PASS test_tscale_to_switching")


def test_tscale_cast_alias():
    """``tscale_cast`` must return the same object and switch its ``tscale_type``."""
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T64)
    returned = layer.tscale_cast(TScaleType.T8)
    assert returned is layer, "tscale_cast should return self"
    assert layer.tscale_type == TScaleType.T8
    print(" PASS test_tscale_cast_alias")


def test_tscale_gradient_flow():
    """Backward through a T32 layer must populate ``.grad`` on the input tensor."""
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32)
    inputs = torch.randn(2, 10, 32).requires_grad_(True)
    # Reduce to a scalar so backward() needs no explicit upstream gradient.
    layer(inputs).sum().backward()
    assert inputs.grad is not None, "No gradient on input"
    print(" PASS test_tscale_gradient_flow")


def test_tscale_all_types_forward():
    """Forward pass works for every member of ``TScaleType``.

    For each type a fresh ``TernaryScaleTensor(96, 16)`` is built and fed a
    random ``(2, 4, 96)`` input; the output must be ``(2, 4, 16)`` and finite.
    """
    expected_shape = (2, 4, 16)
    for kind in TScaleType:
        layer = TernaryScaleTensor(96, 16, tscale_type=kind)
        result = layer(torch.randn(2, 4, 96))
        assert result.shape == expected_shape, f"{kind.name}: shape {result.shape}"
        assert torch.isfinite(result).all(), f"{kind.name}: non-finite output"
    print(" PASS test_tscale_all_types_forward")


def test_tscale_dequantize():
    """``dequantize()`` on a T32 layer returns a finite (16, 32) tensor."""
    layer = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32)
    effective_weight = layer.dequantize()
    expected_shape = (16, 32)
    assert effective_weight.shape == expected_shape, f"Shape: {effective_weight.shape}"
    assert torch.isfinite(effective_weight).all()
    print(" PASS test_tscale_dequantize")


def test_tscale_effective_bpw():
    """A T4 layer must report a strictly lower ``effective_bpw`` than a T64 layer.

    Both layers are 384x384; the two values are also printed for inspection.
    """
    layer_t64 = TernaryScaleTensor(384, 384, tscale_type=TScaleType.T64)
    layer_t4 = TernaryScaleTensor(384, 384, tscale_type=TScaleType.T4)
    assert layer_t4.effective_bpw < layer_t64.effective_bpw, "T4 should have lower BPW than T64"
    print(f"   T64 BPW: {layer_t64.effective_bpw:.2f}, T4 BPW: {layer_t4.effective_bpw:.2f}")
    print(" PASS test_tscale_effective_bpw")


def test_tscale_model_integration():
    """Forward + backward through ``ARBModel`` for the T64, T32 and T8 scale types.

    Each model receives random token ids of shape ``(2, 10)`` with
    ``targets=x[:, 3:]``; the returned losses object must not be ``None`` and
    ``losses.total`` must support ``backward()``.
    """
    scale_types = (TScaleType.T64, TScaleType.T32, TScaleType.T8)
    for kind in scale_types:
        net = ARBModel(tscale_type=kind)
        tokens = torch.randint(0, VOCAB, (2, 10))
        # The model is expected to return exactly four values.
        _logits, loss_bundle, _, _ = net(tokens, targets=tokens[:, 3:])
        assert loss_bundle is not None
        loss_bundle.total.backward()
    print(" PASS test_tscale_model_integration")


def test_tscale_runtime_switch():
    """Switching every ``TernaryScaleTensor`` in a model from T64 to T4 keeps it usable.

    Logits are computed before and after the switch on the same input; the
    post-switch logits must be finite and have the same shape as before.
    """
    net = ARBModel(tscale_type=TScaleType.T64)
    tokens = torch.randint(0, VOCAB, (1, 10))

    logits_t64, _, _, _ = net(tokens)

    # Lazily walk the module tree and re-type each ternary layer in place.
    ternary_layers = (m for m in net.modules() if isinstance(m, TernaryScaleTensor))
    for layer in ternary_layers:
        layer.tscale_to(TScaleType.T4)

    logits_t4, _, _, _ = net(tokens)

    assert torch.isfinite(logits_t4).all(), "Non-finite after tscale.to(T4)"
    assert logits_t4.shape == logits_t64.shape, "Shape mismatch after tscale switch"
    print(" PASS test_tscale_runtime_switch")


# ─── SignSGD Tests ───

def test_sign_sgd_step():
    """One ``SignSGD.step()`` after a backward pass must change ``Linear.weight``."""
    net = torch.nn.Linear(10, 5)
    opt = SignSGD(net.parameters(), lr=0.01)
    inputs = torch.randn(2, 10)
    net(inputs).sum().backward()
    # Snapshot taken after backward but before the update is applied.
    weight_snapshot = net.weight.clone()
    opt.step()
    assert not torch.equal(net.weight, weight_snapshot), "Weights did not change"
    print(" PASS test_sign_sgd_step")


def test_sign_sgd_no_momentum():
    """A freshly constructed ``SignSGD`` must have an empty ``state`` mapping."""
    net = torch.nn.Linear(10, 5)
    opt = SignSGD(net.parameters(), lr=0.01)
    state_entries = len(opt.state)
    assert state_entries == 0, "SignSGD should have no state (no momentum)"
    print(" PASS test_sign_sgd_no_momentum")


def test_sign_sgd_memory():
    """``SignSGD.get_memory_mb()`` must report a positive value for a 100x100 Linear."""
    net = torch.nn.Linear(100, 100)
    opt = SignSGD(net.parameters(), lr=0.01)
    reported_mb = opt.get_memory_mb()
    assert reported_mb > 0, "Memory should be positive"
    print(f"   SignSGD memory: {reported_mb:.2f} MB")
    print(" PASS test_sign_sgd_memory")


def test_sign_sgd_with_tscale_model():
    """``SignSGD`` keeps an empty ``state`` after a full step on a T32 ``ARBModel``.

    Runs forward, backward, ``optimizer.step()`` and the model's
    ``_ternary_update_memory()`` once, then checks the optimizer state.
    """
    net = ARBModel(tscale_type=TScaleType.T32)
    opt = SignSGD(net.parameters(), lr=0.01)
    tokens = torch.randint(0, VOCAB, (2, 10))
    _logits, loss_bundle, _, _ = net(tokens, targets=tokens[:, 3:])
    loss_bundle.total.backward()
    opt.step()
    net._ternary_update_memory()
    state_entries = len(opt.state)
    assert state_entries == 0, "SignSGD should have no state"
    print(" PASS test_sign_sgd_with_tscale_model")


def test_sign_sgd_weight_decay():
    """A ``SignSGD`` step with ``weight_decay=0.01`` must move ``Linear.weight``.

    Only the total absolute change is checked; the size of the decay
    contribution itself is not verified here.
    """
    net = torch.nn.Linear(10, 5)
    opt = SignSGD(net.parameters(), lr=0.01, weight_decay=0.01)
    inputs = torch.randn(2, 10)
    net(inputs).sum().backward()
    weight_snapshot = net.weight.clone()
    opt.step()
    total_shift = (net.weight - weight_snapshot).abs().sum().item()
    assert total_shift > 0, "Weights should change with weight_decay"
    print(" PASS test_sign_sgd_weight_decay")


# ─── TileLang PyTorch Reference Tests ───

def test_dequant_gemm_pytorch_ref():
    """Smoke-test ``dequant_gemm_pytorch_ref`` loaded from the TileLang kernel file.

    The reference module is loaded by path from
    ``../tilelang/kernels/dequant_gemm.py`` relative to this file. If that
    file is absent the test prints a SKIP line and returns. Otherwise random
    int8 signs/exponents and a float16 input are passed in and the result must
    be a finite ``(M, N)`` tensor.
    """
    import importlib.util

    here = os.path.dirname(__file__)
    ref_file = os.path.join(here, "..", "tilelang", "kernels", "dequant_gemm.py")
    reference_present = os.path.exists(ref_file)
    if not reference_present:
        print(" SKIP test_dequant_gemm_pytorch_ref (tilelang reference file missing)")
        return

    # Load the module straight from its path; it is not imported as a package.
    module_spec = importlib.util.spec_from_file_location("dequant_gemm", ref_file)
    ref_module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(ref_module)
    reference_gemm = ref_module.dequant_gemm_pytorch_ref

    rows, cols, depth, group = 4, 8, 96, 12
    sign_matrix = torch.randint(-1, 2, (cols, depth), dtype=torch.int8)
    # One exponent per group of `group` consecutive columns.
    exponent_matrix = torch.randint(-3, 4, (cols, depth // group), dtype=torch.int8)
    activations = torch.randn(rows, depth, dtype=torch.float16)

    result = reference_gemm(sign_matrix, exponent_matrix, activations, group)
    assert result.shape == (rows, cols), f"Shape: {result.shape}"
    assert torch.isfinite(result).all(), "Non-finite output"
    print(" PASS test_dequant_gemm_pytorch_ref")


def test_dequant_gemm_matches_manual():
    """Compare ``dequant_gemm_pytorch_ref`` against a hand-written dequant + matmul.

    The reference module is loaded by path from
    ``../tilelang/kernels/dequant_gemm.py`` relative to this file; when the
    file is missing the test prints a SKIP line and returns. The manual
    computation expands each group exponent across ``group_size`` columns,
    turns it into a scale via integer shifts, multiplies by the signs and
    evaluates ``x @ w.t()`` in float16.
    """
    import importlib.util

    kernel_path = os.path.join(os.path.dirname(__file__), "..", "tilelang", "kernels", "dequant_gemm.py")
    if not os.path.exists(kernel_path):
        print(" SKIP test_dequant_gemm_matches_manual (tilelang reference file missing)")
        return
    spec = importlib.util.spec_from_file_location("dequant_gemm", kernel_path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    dequant_gemm_pytorch_ref = mod.dequant_gemm_pytorch_ref

    M, N, K, group_size = 2, 4, 48, 12
    signs = torch.randint(-1, 2, (N, K), dtype=torch.int8)
    exponents = torch.randint(-3, 4, (N, K // group_size), dtype=torch.int8)
    x = torch.randn(M, K, dtype=torch.float16)

    result = dequant_gemm_pytorch_ref(signs, exponents, x, group_size)

    # Broadcast each per-group exponent to every column of its group, widened
    # once to int32 so the shifts below do not operate on int8.
    exp_expanded = exponents.repeat_interleave(group_size, dim=1).to(torch.int32)
    pos_mask = exp_expanded >= 0
    # torch.where evaluates both branches on the full tensor; entries from the
    # branch that does not apply are discarded by the mask.
    # NOTE(review): `1 >> k` is 0 for every k >= 1, so negative exponents give
    # a zero scale here rather than 2**-k. Presumably this mirrors what the
    # reference implementation does - confirm against dequant_gemm.py.
    two_pow = torch.where(pos_mask,
                          (1 << exp_expanded).to(torch.float16),
                          (1 >> (-exp_expanded)).to(torch.float16))
    w = signs.to(torch.float16) * two_pow
    expected = x @ w.t()

    assert torch.allclose(result, expected, atol=1e-3), "PyTorch ref mismatch"
    print(" PASS test_dequant_gemm_matches_manual")


# ─── Integration: SignSGD + TernaryScaleTensor training step ───

def test_full_training_step():
    """One SignSGD training step on a T32 ``ARBModel`` must leave the loss finite.

    Performs forward, backward, ``optimizer.step()`` and
    ``_ternary_update_memory()``, then runs a second forward pass on the same
    batch and checks that ``losses.total`` is finite.
    """
    net = ARBModel(tscale_type=TScaleType.T32)
    opt = SignSGD(net.parameters(), lr=0.01)
    tokens = torch.randint(0, VOCAB, (2, 10))

    # First pass: compute gradients and apply the update.
    _logits, first_losses, _, _ = net(tokens, targets=tokens[:, 3:])
    first_losses.total.backward()
    opt.step()
    net._ternary_update_memory()

    # Second pass: the updated model must still produce a finite loss.
    _logits_after, second_losses, _, _ = net(tokens, targets=tokens[:, 3:])
    assert torch.isfinite(second_losses.total), "Non-finite loss after step"
    print(" PASS test_full_training_step")


def test_multiple_steps_converge():
    """Run 50 SignSGD steps on a T32 ``ARBModel`` and require every loss to be finite.

    Despite the name, only finiteness of the recorded losses is asserted; the
    observed loss range is printed for manual inspection.
    """
    net = ARBModel(tscale_type=TScaleType.T32)
    opt = SignSGD(net.parameters(), lr=0.001)
    tokens = torch.randint(0, VOCAB, (4, 10))

    history = []
    for _ in range(50):
        opt.zero_grad()
        _logits, loss_bundle, _, _ = net(tokens, targets=tokens[:, 3:])
        total_loss = loss_bundle.total
        total_loss.backward()
        opt.step()
        net._ternary_update_memory(accum_threshold=3)
        history.append(total_loss.item())

    all_finite = torch.isfinite(torch.tensor(history)).all()
    assert all_finite, "Non-finite loss during training"
    print(f"   Loss range: {min(history):.4f} – {max(history):.4f} over 50 steps")
    print(" PASS test_multiple_steps_converge")


def test_cuda_triton_correctness_linear():
    """Compare CPU and CUDA forward/backward of ``TernaryScaleTensor`` per scale type.

    Skipped (prints a SKIP line and returns) when CUDA is unavailable or
    ``tscale._HAS_TRITON`` is falsy. For each scale type a CPU layer and a CUDA
    layer sharing the same state dict receive the same input and upstream
    gradient; the maximum absolute difference of the outputs and of the input
    gradients must each stay below ``ATOL``.
    """
    if not torch.cuda.is_available() or not tscale._HAS_TRITON:
        print(" SKIP test_cuda_triton_correctness_linear (CUDA/Triton unavailable)")
        return
    # The function-scope imports of TernaryRMSNorm, _triton_ternary_embed and
    # ByteEmbedding were never used here and have been dropped, so an import
    # problem in those names can no longer fail this linear-layer test.
    ATOL = 1e-3
    for tt in [TScaleType.T4, TScaleType.T6, TScaleType.T8, TScaleType.T16, TScaleType.T32, TScaleType.T64]:
        # CPU side.
        lin_cpu = TernaryScaleTensor(32, 16, tscale_type=tt)
        x = torch.randn(4, 4, 32, requires_grad=True)
        cpu_out = lin_cpu(x)
        grad_out = torch.randn_like(cpu_out)
        cpu_out.backward(grad_out)
        cpu_grad_x = x.grad.clone()

        # CUDA side: same weights, same input values, same upstream gradient.
        lin_gpu = TernaryScaleTensor(32, 16, tscale_type=tt).cuda()
        lin_gpu.load_state_dict(lin_cpu.state_dict())
        x_gpu = x.detach().clone().cuda().requires_grad_(True)
        gpu_out = lin_gpu(x_gpu)
        gpu_out.backward(grad_out.cuda())
        gpu_grad_x = x_gpu.grad.clone()

        fwd_diff = (cpu_out - gpu_out.cpu()).abs().max().item()
        bwd_diff = (cpu_grad_x - gpu_grad_x.cpu()).abs().max().item()
        assert fwd_diff < ATOL, f"{tt.name} fwd_diff={fwd_diff}"
        assert bwd_diff < ATOL, f"{tt.name} bwd_diff={bwd_diff}"
    print(" PASS test_cuda_triton_correctness_linear")


def test_cuda_triton_correctness_rmsnorm():
    """Compare CPU and CUDA forward/backward of ``TernaryRMSNorm`` per scale type.

    Skipped when CUDA is unavailable or ``tscale._HAS_TRITON`` is falsy. Both
    devices use the same state dict and input values; output and input-gradient
    differences must each be below 1e-5.
    """
    # `and` short-circuits, so tscale is only consulted when CUDA is present.
    gpu_path_available = torch.cuda.is_available() and tscale._HAS_TRITON
    if not gpu_path_available:
        print(" SKIP test_cuda_triton_correctness_rmsnorm (CUDA/Triton unavailable)")
        return
    from arbitor.kernel.ternary_scale import TernaryRMSNorm

    scale_types = (
        TScaleType.T4,
        TScaleType.T6,
        TScaleType.T8,
        TScaleType.T16,
        TScaleType.T32,
        TScaleType.T64,
    )
    for kind in scale_types:
        # CPU side.
        cpu_norm = TernaryRMSNorm(256, tscale_type=kind)
        cpu_in = torch.randn(2, 4, 256, requires_grad=True)
        cpu_result = cpu_norm(cpu_in)
        cpu_result.sum().backward()
        cpu_in_grad = cpu_in.grad.clone()

        # CUDA side, initialised from the CPU module's state.
        gpu_norm = TernaryRMSNorm(256, tscale_type=kind).cuda()
        gpu_norm.load_state_dict(cpu_norm.state_dict())
        gpu_in = cpu_in.detach().clone().cuda().requires_grad_(True)
        gpu_result = gpu_norm(gpu_in)
        gpu_result.sum().backward()
        gpu_in_grad = gpu_in.grad.clone()

        forward_gap = (cpu_result - gpu_result.cpu()).abs().max().item()
        backward_gap = (cpu_in_grad - gpu_in_grad.cpu()).abs().max().item()
        assert forward_gap < 1e-5, f"{kind.name} rmsnorm fwd_diff={forward_gap}"
        assert backward_gap < 1e-5, f"{kind.name} rmsnorm bwd_diff={backward_gap}"
    print(" PASS test_cuda_triton_correctness_rmsnorm")


def test_cuda_triton_correctness_embedding():
    """Compare CPU and CUDA ``ByteEmbedding`` outputs per scale type.

    Skipped when CUDA is unavailable or ``tscale._HAS_TRITON`` is falsy. The
    forward outputs must agree to within 1e-5. If both modules expose a
    ``_hook_grad_T_sign`` attribute after backward, more than 99% of its
    entries must match between devices.
    """
    # `and` short-circuits, so tscale is only consulted when CUDA is present.
    gpu_path_available = torch.cuda.is_available() and tscale._HAS_TRITON
    if not gpu_path_available:
        print(" SKIP test_cuda_triton_correctness_embedding (CUDA/Triton unavailable)")
        return
    from arbitor.main import ByteEmbedding

    scale_types = (
        TScaleType.T4,
        TScaleType.T6,
        TScaleType.T8,
        TScaleType.T16,
        TScaleType.T32,
        TScaleType.T64,
    )
    for kind in scale_types:
        cpu_embed = ByteEmbedding(tscale_type=kind)
        token_ids = torch.tensor([0, 1, 2, 5, 10])
        cpu_result = cpu_embed(token_ids)
        cpu_result.sum().backward()

        gpu_embed = ByteEmbedding(tscale_type=kind).cuda()
        gpu_embed.load_state_dict(cpu_embed.state_dict())
        gpu_result = gpu_embed(token_ids.cuda())
        gpu_result.sum().backward()

        forward_gap = (cpu_result - gpu_result.cpu()).abs().max().item()
        assert forward_gap < 1e-5, f"{kind.name} embed fwd_diff={forward_gap}"

        # The grad-sign hook is optional; only compare when both sides kept it.
        both_have_hook = hasattr(cpu_embed, '_hook_grad_T_sign') and hasattr(gpu_embed, '_hook_grad_T_sign')
        if both_have_hook:
            agreement = gpu_embed._hook_grad_T_sign.cpu() == cpu_embed._hook_grad_T_sign
            match_ratio = agreement.float().mean().item()
            assert match_ratio > 0.99, f"{kind.name} embed grad_sign match={match_ratio}"
    print(" PASS test_cuda_triton_correctness_embedding")


def test_cuda_triton_correctness_update_E():
    """Compare ``E`` and ``E_accum`` after ``update_E()`` on CPU versus CUDA.

    Skipped when CUDA is unavailable or ``tscale._HAS_TRITON`` is falsy. Both
    layers start from the same state dict and see the same input; after one
    forward/backward/``update_E()`` cycle the maximum absolute difference of
    ``E`` and of ``E_accum`` must each be below 0.01.
    """
    # `and` short-circuits, so tscale is only consulted when CUDA is present.
    gpu_path_available = torch.cuda.is_available() and tscale._HAS_TRITON
    if not gpu_path_available:
        print(" SKIP test_cuda_triton_correctness_update_E (CUDA/Triton unavailable)")
        return

    scale_types = (
        TScaleType.T4,
        TScaleType.T6,
        TScaleType.T8,
        TScaleType.T16,
        TScaleType.T32,
        TScaleType.T64,
    )
    for kind in scale_types:
        cpu_layer = TernaryScaleTensor(32, 16, tscale_type=kind)
        gpu_layer = TernaryScaleTensor(32, 16, tscale_type=kind).cuda()
        gpu_layer.load_state_dict(cpu_layer.state_dict())

        cpu_in = torch.randn(4, 4, 32, requires_grad=True)
        gpu_in = cpu_in.detach().clone().cuda().requires_grad_(True)

        # CPU cycle.
        cpu_layer(cpu_in).sum().backward()
        cpu_layer.update_E()
        cpu_E = cpu_layer.E.clone()
        cpu_E_accum = cpu_layer.E_accum.clone()

        # CUDA cycle.
        gpu_layer(gpu_in).sum().backward()
        gpu_layer.update_E()
        gpu_E = gpu_layer.E.clone()
        gpu_E_accum = gpu_layer.E_accum.clone()

        # Compare the updated exponent state in float on the CPU.
        exponent_gap = (cpu_E.float() - gpu_E.cpu().float()).abs().max().item()
        assert exponent_gap < 0.01, f"{kind.name} CPU-GPU E update mismatch: {exponent_gap}"
        accum_gap = (cpu_E_accum.float() - gpu_E_accum.cpu().float()).abs().max().item()
        assert accum_gap < 0.01, f"{kind.name} CPU-GPU E_accum update mismatch: {accum_gap}"
    print(" PASS test_cuda_triton_correctness_update_E")


def test_cuda_triton_correctness_ternary_step():
    """Require identical ``T`` and ``T_accum`` after ``ternary_step`` on CPU and CUDA.

    Skipped when CUDA is unavailable or ``tscale._HAS_TRITON`` is falsy. Both
    layers share a state dict and input; after forward, backward and
    ``ternary_step(accum_threshold=3)`` every element of ``_get_T()`` and of
    ``T_accum`` must match exactly across devices.
    """
    # `and` short-circuits, so tscale is only consulted when CUDA is present.
    gpu_path_available = torch.cuda.is_available() and tscale._HAS_TRITON
    if not gpu_path_available:
        print(" SKIP test_cuda_triton_correctness_ternary_step (CUDA/Triton unavailable)")
        return

    scale_types = (
        TScaleType.T4,
        TScaleType.T6,
        TScaleType.T8,
        TScaleType.T16,
        TScaleType.T32,
        TScaleType.T64,
    )
    for kind in scale_types:
        cpu_layer = TernaryScaleTensor(32, 16, tscale_type=kind)
        gpu_layer = TernaryScaleTensor(32, 16, tscale_type=kind).cuda()
        gpu_layer.load_state_dict(cpu_layer.state_dict())

        cpu_in = torch.randn(4, 4, 32, requires_grad=True)
        gpu_in = cpu_in.detach().clone().cuda().requires_grad_(True)

        # CPU cycle.
        cpu_layer(cpu_in).sum().backward()
        cpu_layer.ternary_step(accum_threshold=3)
        cpu_T = cpu_layer._get_T().clone()
        cpu_T_accum = cpu_layer.T_accum.clone()

        # CUDA cycle.
        gpu_layer(gpu_in).sum().backward()
        gpu_layer.ternary_step(accum_threshold=3)
        gpu_T = gpu_layer._get_T().clone()
        gpu_T_accum = gpu_layer.T_accum.clone()

        # Fraction of matching elements; 1.0 means every element agrees.
        ternary_agreement = (cpu_T == gpu_T.cpu()).float().mean().item()
        accum_agreement = (cpu_T_accum == gpu_T_accum.cpu()).float().mean().item()
        assert ternary_agreement == 1.0, f"{kind.name} T_match={ternary_agreement}"
        assert accum_agreement == 1.0, f"{kind.name} Taccum_match={accum_agreement}"
    print(" PASS test_cuda_triton_correctness_ternary_step")


def test_cuda_triton_tscale_path():
    """Exercise the CUDA path of ``TernaryScaleTensor`` end to end.

    Skipped (prints a SKIP line and returns) when CUDA is unavailable or
    ``tscale._HAS_TRITON`` is falsy.

    Part 1 runs forward/backward on a T32 layer and checks that the output and
    input gradient live on CUDA, that ``T_accum`` received updates, that no
    ``_hook_*`` attributes are left on the layer, that ``E_accum`` changed, and
    that ``T_packed`` / ``E`` are still CUDA tensors.

    Part 2 sets ``_hook_grad_2d`` / ``_hook_x_2d`` to all-ones tensors on a
    fresh layer, calls ``ternary_step(accum_threshold=0)`` and expects every
    ternary value to be -1 with ``T_accum`` reset to zero.
    """
    if not torch.cuda.is_available() or not tscale._HAS_TRITON:
        print(" SKIP test_cuda_triton_tscale_path (CUDA/Triton unavailable)")
        return

    lin = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32).cuda()
    x = torch.randn(2, 4, 32, device="cuda", requires_grad=True)
    # Snapshot E_accum BEFORE forward/backward. It used to be taken after
    # backward, which made the "did it change" comparison below compare the
    # state against itself (always equal).
    E_accum_before = lin.E_accum.clone()
    out = lin(x)
    assert out.is_cuda, "Triton path should produce CUDA output"
    assert out.shape == (2, 4, 16), f"Shape: {out.shape}"
    grad_out = torch.randn_like(out)
    out.backward(grad_out)
    assert x.grad is not None and x.grad.is_cuda, "CUDA grad_x missing"
    assert lin.T_accum.abs().sum().item() > 0, \
        "Triton path should stream updates into int8 T_accum"
    assert not hasattr(lin, "_hook_grad_T_sign"), \
        "Triton path should not retain full weight-shaped grad-sign hooks"
    assert not hasattr(lin, "_hook_grad_2d") and not hasattr(lin, "_hook_x_2d"), \
        "Triton path should not retain fp32 grad/x views"
    # Wait for any outstanding GPU work before reading the state back.
    torch.cuda.synchronize()
    # The non-zero fallback is kept so this check accepts everything the
    # previous version accepted.
    assert not torch.equal(lin.E_accum, E_accum_before) or lin.E_accum.abs().sum().item() > 0, \
        "Streaming CUDA E update did not modify exponent residual state"
    assert not hasattr(lin, "_hook_grad_T_sign"), \
        "No retained grad-sign hook should remain after streaming backward"
    assert lin.T_packed.is_cuda and lin.E.is_cuda, "Ternary buffers moved off CUDA after update"

    # Part 2: drive ternary_step directly from hand-set all-ones hook tensors.
    lin_force = TernaryScaleTensor(32, 16, tscale_type=TScaleType.T32).cuda()
    lin_force._hook_grad_2d = torch.ones(2, 16, device="cuda")
    lin_force._hook_x_2d = torch.ones(2, 32, device="cuda")
    lin_force.ternary_step(accum_threshold=0)
    forced_T = lin_force._get_T()
    assert forced_T.is_cuda, "Unpacked CUDA ternary state should stay on CUDA"
    assert (forced_T == -1).all(), "CUDA ternary repack should move positive gradients in descent direction"
    assert lin_force.T_accum.abs().sum().item() == 0, "CUDA ternary repack should reset flipped accumulators"
    print(" PASS test_cuda_triton_tscale_path")


if __name__ == "__main__":
    # Minimal script runner: executes every test in order, prints a FAIL line
    # plus traceback for each exception, and reports a summary. Tests that
    # print SKIP and return normally are counted as passed.
    import traceback

    tests = [
        test_tscale_shape,
        test_tscale_ternary_output,
        test_tscale_T64_per_element_s,
        test_tscale_T32_group_s,
        test_tscale_to_switching,
        test_tscale_cast_alias,
        test_tscale_gradient_flow,
        test_tscale_all_types_forward,
        test_tscale_dequantize,
        test_tscale_effective_bpw,
        test_tscale_model_integration,
        test_tscale_runtime_switch,
        test_sign_sgd_step,
        test_sign_sgd_no_momentum,
        test_sign_sgd_memory,
        test_sign_sgd_with_tscale_model,
        test_sign_sgd_weight_decay,
        test_dequant_gemm_pytorch_ref,
        test_dequant_gemm_matches_manual,
        test_cuda_triton_correctness_linear,
        test_cuda_triton_correctness_rmsnorm,
        test_cuda_triton_correctness_embedding,
        test_cuda_triton_correctness_update_E,
        test_cuda_triton_correctness_ternary_step,
        test_cuda_triton_tscale_path,
        test_full_training_step,
        test_multiple_steps_converge,
    ]
    print("Running TernaryScale + SignSGD + TileLang Phase 2 tests...\n")
    passed = 0
    failed = 0
    for test in tests:
        try:
            test()
        except Exception as e:  # top-level boundary: report and keep running
            print(f" FAIL {test.__name__}: {e}")
            traceback.print_exc()
            failed += 1
        else:
            passed += 1
    print(f"\n{passed} passed, {failed} failed out of {len(tests)} tests")
    # Surface failures to the calling shell / CI; previously the script always
    # exited with status 0 even when tests failed.
    if failed:
        sys.exit(1)