liangsu9988 commited on
Commit
a6a49dc
·
verified ·
1 Parent(s): ed5bd9e

Uploaded using `kernel-builder`.

Browse files
Files changed (31) hide show
  1. benchmarks/benchmark.py +315 -0
  2. build/torch210-cxx11-cu128-x86_64-linux/__init__.py +123 -0
  3. build/torch210-cxx11-cu128-x86_64-linux/_flashrt_residual_norm_quant_cuda_cf903dd.abi3.so +3 -0
  4. build/torch210-cxx11-cu128-x86_64-linux/_ops.py +9 -0
  5. build/torch210-cxx11-cu128-x86_64-linux/flashrt_residual_norm_quant/__init__.py +26 -0
  6. build/torch210-cxx11-cu128-x86_64-linux/metadata.json +23 -0
  7. build/torch210-cxx11-cu130-x86_64-linux/__init__.py +123 -0
  8. build/torch210-cxx11-cu130-x86_64-linux/_flashrt_residual_norm_quant_cuda_cf903dd.abi3.so +3 -0
  9. build/torch210-cxx11-cu130-x86_64-linux/_ops.py +9 -0
  10. build/torch210-cxx11-cu130-x86_64-linux/flashrt_residual_norm_quant/__init__.py +26 -0
  11. build/torch210-cxx11-cu130-x86_64-linux/metadata.json +21 -0
  12. build/torch211-cxx11-cu128-x86_64-linux/__init__.py +123 -0
  13. build/torch211-cxx11-cu128-x86_64-linux/_flashrt_residual_norm_quant_cuda_cf903dd.abi3.so +3 -0
  14. build/torch211-cxx11-cu128-x86_64-linux/_ops.py +9 -0
  15. build/torch211-cxx11-cu128-x86_64-linux/flashrt_residual_norm_quant/__init__.py +26 -0
  16. build/torch211-cxx11-cu128-x86_64-linux/metadata.json +23 -0
  17. build/torch211-cxx11-cu130-x86_64-linux/__init__.py +123 -0
  18. build/torch211-cxx11-cu130-x86_64-linux/_flashrt_residual_norm_quant_cuda_cf903dd.abi3.so +3 -0
  19. build/torch211-cxx11-cu130-x86_64-linux/_ops.py +9 -0
  20. build/torch211-cxx11-cu130-x86_64-linux/flashrt_residual_norm_quant/__init__.py +26 -0
  21. build/torch211-cxx11-cu130-x86_64-linux/metadata.json +21 -0
  22. build/torch212-cxx11-cu130-x86_64-linux/__init__.py +123 -0
  23. build/torch212-cxx11-cu130-x86_64-linux/_flashrt_residual_norm_quant_cuda_cf903dd.abi3.so +3 -0
  24. build/torch212-cxx11-cu130-x86_64-linux/_ops.py +9 -0
  25. build/torch212-cxx11-cu130-x86_64-linux/flashrt_residual_norm_quant/__init__.py +26 -0
  26. build/torch212-cxx11-cu130-x86_64-linux/metadata.json +21 -0
  27. build/torch212-cxx11-cu132-x86_64-linux/__init__.py +123 -0
  28. build/torch212-cxx11-cu132-x86_64-linux/_flashrt_residual_norm_quant_cuda_cf903dd.abi3.so +3 -0
  29. build/torch212-cxx11-cu132-x86_64-linux/_ops.py +9 -0
  30. build/torch212-cxx11-cu132-x86_64-linux/flashrt_residual_norm_quant/__init__.py +26 -0
  31. build/torch212-cxx11-cu132-x86_64-linux/metadata.json +21 -0
benchmarks/benchmark.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Benchmark flashrt-residual-norm-quant against PyTorch eager references."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import ctypes
8
+ import ctypes.util
9
+ import importlib
10
+ import json
11
+ import math
12
+ import os
13
+ import sys
14
+ from dataclasses import asdict, dataclass
15
+ from pathlib import Path
16
+
17
+ import torch
18
+
19
+
20
+ ROOT = Path(__file__).resolve().parents[2]
21
+ PACKAGE = ROOT / "flashrt-residual-norm-quant"
22
+ REGISTRATION_INCLUDE = (
23
+ ROOT.parent
24
+ / "kernels"
25
+ / "kernel-builder"
26
+ / "src"
27
+ / "pyproject"
28
+ / "templates"
29
+ / "torch"
30
+ )
31
+
32
+ SHAPES = {
33
+ "pi05_decoder": (10, 1024),
34
+ "pi05_vision": (512, 1152),
35
+ "groot_vl": (1024, 2048),
36
+ "video_prefill": (2520, 2048),
37
+ }
38
+ SHAPE_GROUPS = {
39
+ "smoke": ["pi05_decoder"],
40
+ "headline": ["pi05_decoder", "pi05_vision", "groot_vl"],
41
+ "all": list(SHAPES.keys()),
42
+ }
43
+
44
+
45
+ @dataclass
46
+ class Result:
47
+ shape: str
48
+ rows: int
49
+ dim: int
50
+ kernel: str
51
+ flashrt_us: float
52
+ torch_eager_us: float
53
+ speedup_vs_eager: float
54
+ max_abs: float
55
+ mean_abs: float
56
+ p99_abs: float
57
+ cosine: float
58
+ status: str
59
+
60
+
61
+ class SourceOps:
62
+ def __init__(self, namespace: str) -> None:
63
+ self._ops = getattr(torch.ops, namespace)
64
+
65
+ def rms_norm_quant_fp8_static_bf16(self, x, weight, scale, eps=1e-6, out=None):
66
+ if out is None:
67
+ out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
68
+ self._ops.rms_norm_quant_fp8_static_bf16(x, weight, scale, float(eps), out)
69
+ return out
70
+
71
+ def residual_add_rms_norm_quant_fp8_static_bf16(
72
+ self, residual, x, weight, scale, eps=1e-6, out=None
73
+ ):
74
+ if out is None:
75
+ out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
76
+ self._ops.residual_add_rms_norm_quant_fp8_static_bf16(
77
+ residual, x, weight, scale, float(eps), out
78
+ )
79
+ return out
80
+
81
+
82
+ def _preload_cublaslt() -> None:
83
+ for parent in Path(torch.__file__).resolve().parents:
84
+ candidate = parent / "nvidia" / "cublas" / "lib" / "libcublasLt.so.12"
85
+ if candidate.exists():
86
+ ctypes.CDLL(str(candidate), mode=ctypes.RTLD_GLOBAL)
87
+ return
88
+ library = ctypes.util.find_library("cublasLt")
89
+ if library:
90
+ ctypes.CDLL(library, mode=ctypes.RTLD_GLOBAL)
91
+
92
+
93
+ def _current_arch_list() -> str:
94
+ major, minor = torch.cuda.get_device_capability(0)
95
+ return f"{major}.{minor}"
96
+
97
+
98
+ def load_source_ops() -> SourceOps:
99
+ from torch.utils.cpp_extension import load
100
+
101
+ if not REGISTRATION_INCLUDE.is_dir():
102
+ raise RuntimeError(f"missing kernel-builder registration include: {REGISTRATION_INCLUDE}")
103
+ _preload_cublaslt()
104
+ os.environ.setdefault("TORCH_CUDA_ARCH_LIST", _current_arch_list())
105
+ namespace = "flashrt_residual_norm_quant_benchmark"
106
+ load(
107
+ name=namespace,
108
+ sources=[
109
+ str(PACKAGE / "torch-ext" / "torch_binding.cpp"),
110
+ str(PACKAGE / "csrc" / "residual_norm_quant.cu"),
111
+ ],
112
+ extra_include_paths=[str(PACKAGE / "csrc"), str(REGISTRATION_INCLUDE)],
113
+ extra_cflags=["-O3", "-DCUDA_KERNEL"],
114
+ extra_cuda_cflags=["-O3", "--expt-relaxed-constexpr", "-DCUDA_KERNEL"],
115
+ verbose=False,
116
+ )
117
+ return SourceOps(namespace)
118
+
119
+
120
+ def load_installed_ops(artifact: str | None):
121
+ if artifact:
122
+ sys.path.insert(0, artifact)
123
+ try:
124
+ return importlib.import_module("flashrt_residual_norm_quant")
125
+ finally:
126
+ if artifact:
127
+ sys.path.remove(artifact)
128
+
129
+
130
+ def quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
131
+ return torch.clamp(x.float() / scale.float(), -448.0, 448.0).to(torch.float8_e4m3fn)
132
+
133
+
134
+ def torch_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
135
+ rms = torch.rsqrt(torch.mean(x.float() * x.float(), dim=1, keepdim=True) + eps)
136
+ return x.float() * rms * weight.float()
137
+
138
+
139
+ def torch_rms_norm_quant(x, weight, scale, eps) -> torch.Tensor:
140
+ return quantize_fp8(torch_rms_norm(x, weight, eps), scale)
141
+
142
+
143
+ def torch_residual_add_rms_norm_quant(residual, x, weight, scale, eps) -> torch.Tensor:
144
+ added = residual.float() + x.float()
145
+ residual.copy_(added.to(torch.bfloat16))
146
+ rms = torch.rsqrt(torch.mean(added * added, dim=1, keepdim=True) + eps)
147
+ return quantize_fp8(residual.float() * rms * weight.float(), scale)
148
+
149
+
150
+ def make_case(rows: int, dim: int):
151
+ x = torch.randn((rows, dim), device="cuda", dtype=torch.bfloat16)
152
+ residual = torch.randn((rows, dim), device="cuda", dtype=torch.bfloat16)
153
+ weight = (1.0 + 0.1 * torch.randn((dim,), device="cuda", dtype=torch.bfloat16)).contiguous()
154
+ scale = torch.tensor([0.04], device="cuda", dtype=torch.float32)
155
+ out = torch.empty((rows, dim), device="cuda", dtype=torch.float8_e4m3fn)
156
+ return x, residual, weight, scale, out
157
+
158
+
159
+ def time_us(fn, warmup: int, iters: int) -> float:
160
+ for _ in range(warmup):
161
+ fn()
162
+ torch.cuda.synchronize()
163
+ start = torch.cuda.Event(enable_timing=True)
164
+ end = torch.cuda.Event(enable_timing=True)
165
+ start.record()
166
+ for _ in range(iters):
167
+ fn()
168
+ end.record()
169
+ torch.cuda.synchronize()
170
+ return start.elapsed_time(end) * 1000.0 / iters
171
+
172
+
173
+ def percentile(x: torch.Tensor, q: float) -> torch.Tensor:
174
+ flat = x.flatten()
175
+ k = max(1, min(flat.numel(), math.ceil(q * flat.numel())))
176
+ return flat.kthvalue(k).values
177
+
178
+
179
+ def metrics(got: torch.Tensor, expected: torch.Tensor):
180
+ diff = (got.float() - expected.float()).abs().flatten()
181
+ cosine = torch.nn.functional.cosine_similarity(
182
+ got.float().flatten(), expected.float().flatten(), dim=0
183
+ )
184
+ return {
185
+ "max_abs": float(diff.max().item()),
186
+ "mean_abs": float(diff.mean().item()),
187
+ "p99_abs": float(percentile(diff, 0.99).item()),
188
+ "cosine": float(cosine.item()),
189
+ }
190
+
191
+
192
+ def run_one(ops, name: str, rows: int, dim: int, args) -> list[Result]:
193
+ x, residual, weight, scale, out = make_case(rows, dim)
194
+ eps = args.eps
195
+ results = []
196
+
197
+ got = ops.rms_norm_quant_fp8_static_bf16(x, weight, scale, eps, out)
198
+ expected = torch_rms_norm_quant(x, weight, scale, eps)
199
+ m = metrics(got, expected)
200
+ kernel_us = time_us(
201
+ lambda: ops.rms_norm_quant_fp8_static_bf16(x, weight, scale, eps, out),
202
+ args.warmup,
203
+ args.iters,
204
+ )
205
+ torch_us = time_us(lambda: torch_rms_norm_quant(x, weight, scale, eps), args.warmup, args.iters)
206
+ results.append(
207
+ Result(
208
+ shape=name,
209
+ rows=rows,
210
+ dim=dim,
211
+ kernel="rms_norm_quant_fp8_static_bf16",
212
+ flashrt_us=kernel_us,
213
+ torch_eager_us=torch_us,
214
+ speedup_vs_eager=torch_us / kernel_us,
215
+ status="PASS" if m["p99_abs"] <= args.p99_abs_limit else "FAIL",
216
+ **m,
217
+ )
218
+ )
219
+
220
+ residual0 = residual.clone()
221
+ residual_kernel = residual0.clone()
222
+ got = ops.residual_add_rms_norm_quant_fp8_static_bf16(
223
+ residual_kernel, x, weight, scale, eps, out
224
+ )
225
+ residual_ref = residual0.clone()
226
+ expected = torch_residual_add_rms_norm_quant(residual_ref, x, weight, scale, eps)
227
+ m = metrics(got, expected)
228
+ residual_kernel = residual0.clone()
229
+ residual_ref = residual0.clone()
230
+ kernel_us = time_us(
231
+ lambda: ops.residual_add_rms_norm_quant_fp8_static_bf16(
232
+ residual_kernel, x, weight, scale, eps, out
233
+ ),
234
+ args.warmup,
235
+ args.iters,
236
+ )
237
+ torch_us = time_us(
238
+ lambda: torch_residual_add_rms_norm_quant(residual_ref, x, weight, scale, eps),
239
+ args.warmup,
240
+ args.iters,
241
+ )
242
+ results.append(
243
+ Result(
244
+ shape=name,
245
+ rows=rows,
246
+ dim=dim,
247
+ kernel="residual_add_rms_norm_quant_fp8_static_bf16",
248
+ flashrt_us=kernel_us,
249
+ torch_eager_us=torch_us,
250
+ speedup_vs_eager=torch_us / kernel_us,
251
+ status="PASS" if m["p99_abs"] <= args.p99_abs_limit else "FAIL",
252
+ **m,
253
+ )
254
+ )
255
+ return results
256
+
257
+
258
+ def write_markdown(path: Path, results: list[Result]) -> None:
259
+ lines = [
260
+ "| Shape | Rows,Dim | Kernel | FlashRT us | Eager us | vs eager | Max abs | Mean abs | P99 abs | Cosine | Status |",
261
+ "|---|---:|---|---:|---:|---:|---:|---:|---:|---:|---|",
262
+ ]
263
+ for r in results:
264
+ lines.append(
265
+ f"| {r.shape} | {r.rows},{r.dim} | {r.kernel} | {r.flashrt_us:.3f} | "
266
+ f"{r.torch_eager_us:.3f} | {r.speedup_vs_eager:.2f}x | "
267
+ f"{r.max_abs:.6f} | {r.mean_abs:.6f} | {r.p99_abs:.6f} | "
268
+ f"{r.cosine:.8f} | {r.status} |"
269
+ )
270
+ path.write_text("\n".join(lines) + "\n")
271
+
272
+
273
+ def main() -> None:
274
+ parser = argparse.ArgumentParser()
275
+ parser.add_argument("--backend", choices=["source", "installed"], default="source")
276
+ parser.add_argument("--artifact", default=None)
277
+ parser.add_argument("--shapes", choices=sorted(SHAPE_GROUPS), default="smoke")
278
+ parser.add_argument("--warmup", type=int, default=5)
279
+ parser.add_argument("--iters", type=int, default=20)
280
+ parser.add_argument("--eps", type=float, default=1e-6)
281
+ parser.add_argument("--p99-abs-limit", type=float, default=0.5)
282
+ parser.add_argument("--output", default=None)
283
+ parser.add_argument("--markdown", default=None)
284
+ args = parser.parse_args()
285
+
286
+ if not torch.cuda.is_available():
287
+ raise SystemExit("CUDA is required")
288
+ torch.manual_seed(29)
289
+ ops = load_source_ops() if args.backend == "source" else load_installed_ops(args.artifact)
290
+
291
+ results = []
292
+ for name in SHAPE_GROUPS[args.shapes]:
293
+ rows, dim = SHAPES[name]
294
+ results.extend(run_one(ops, name, rows, dim, args))
295
+
296
+ for r in results:
297
+ print(
298
+ f"{r.status} {r.shape}/{r.kernel}: flashrt={r.flashrt_us:.3f}us "
299
+ f"eager={r.torch_eager_us:.3f}us speedup={r.speedup_vs_eager:.2f}x "
300
+ f"p99_abs={r.p99_abs:.6f} cosine={r.cosine:.8f}"
301
+ )
302
+
303
+ if args.output:
304
+ Path(args.output).parent.mkdir(parents=True, exist_ok=True)
305
+ Path(args.output).write_text(json.dumps([asdict(r) for r in results], indent=2) + "\n")
306
+ if args.markdown:
307
+ Path(args.markdown).parent.mkdir(parents=True, exist_ok=True)
308
+ write_markdown(Path(args.markdown), results)
309
+
310
+ if any(r.status != "PASS" for r in results):
311
+ raise SystemExit(1)
312
+
313
+
314
+ if __name__ == "__main__":
315
+ main()
build/torch210-cxx11-cu128-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FlashRT residual/RMSNorm/static-FP8 quantization kernels."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import torch
6
+
7
+ from ._ops import add_op_namespace_prefix, ops
8
+
9
+
10
+ def _check_rank2_same_shape(x: torch.Tensor, out: torch.Tensor, out_name: str) -> None:
11
+ if x.dim() != 2:
12
+ raise RuntimeError("x must be rank-2")
13
+ if out.shape != x.shape:
14
+ raise RuntimeError(f"{out_name} must have the same shape as x")
15
+
16
+
17
+ @torch.library.register_fake(add_op_namespace_prefix("rms_norm_bf16"))
18
+ def _rms_norm_bf16_fake(
19
+ x: torch.Tensor,
20
+ weight: torch.Tensor,
21
+ eps: float,
22
+ out: torch.Tensor,
23
+ ) -> None:
24
+ _check_rank2_same_shape(x, out, "out")
25
+ if weight.shape != (x.shape[1],):
26
+ raise RuntimeError("weight must have shape (x.shape[1],)")
27
+ return None
28
+
29
+
30
+ @torch.library.register_fake(add_op_namespace_prefix("rms_norm_quant_fp8_static_bf16"))
31
+ def _rms_norm_quant_fp8_static_bf16_fake(
32
+ x: torch.Tensor,
33
+ weight: torch.Tensor,
34
+ scale: torch.Tensor,
35
+ eps: float,
36
+ out: torch.Tensor,
37
+ ) -> None:
38
+ _check_rank2_same_shape(x, out, "out")
39
+ if weight.shape != (x.shape[1],):
40
+ raise RuntimeError("weight must have shape (x.shape[1],)")
41
+ if scale.numel() != 1:
42
+ raise RuntimeError("scale must contain exactly one value")
43
+ return None
44
+
45
+
46
+ @torch.library.register_fake(
47
+ add_op_namespace_prefix("residual_add_rms_norm_quant_fp8_static_bf16")
48
+ )
49
+ def _residual_add_rms_norm_quant_fp8_static_bf16_fake(
50
+ residual: torch.Tensor,
51
+ x: torch.Tensor,
52
+ weight: torch.Tensor,
53
+ scale: torch.Tensor,
54
+ eps: float,
55
+ out: torch.Tensor,
56
+ ) -> None:
57
+ if residual.shape != x.shape:
58
+ raise RuntimeError("residual and x must have the same shape")
59
+ _check_rank2_same_shape(x, out, "out")
60
+ if weight.shape != (x.shape[1],):
61
+ raise RuntimeError("weight must have shape (x.shape[1],)")
62
+ if scale.numel() != 1:
63
+ raise RuntimeError("scale must contain exactly one value")
64
+ return None
65
+
66
+
67
+ def rms_norm_bf16(
68
+ x: torch.Tensor,
69
+ weight: torch.Tensor,
70
+ eps: float = 1e-6,
71
+ out: torch.Tensor | None = None,
72
+ ) -> torch.Tensor:
73
+ """BF16 RMSNorm with affine weight."""
74
+
75
+ if out is None:
76
+ out = torch.empty_like(x, dtype=torch.bfloat16)
77
+ ops.rms_norm_bf16(x, weight, float(eps), out)
78
+ return out
79
+
80
+
81
+ def rms_norm_quant_fp8_static_bf16(
82
+ x: torch.Tensor,
83
+ weight: torch.Tensor,
84
+ scale: torch.Tensor,
85
+ eps: float = 1e-6,
86
+ out: torch.Tensor | None = None,
87
+ ) -> torch.Tensor:
88
+ """BF16 RMSNorm followed by static-scale FP8 E4M3 quantization."""
89
+
90
+ if out is None:
91
+ out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
92
+ ops.rms_norm_quant_fp8_static_bf16(x, weight, scale, float(eps), out)
93
+ return out
94
+
95
+
96
+ def residual_add_rms_norm_quant_fp8_static_bf16(
97
+ residual: torch.Tensor,
98
+ x: torch.Tensor,
99
+ weight: torch.Tensor,
100
+ scale: torch.Tensor,
101
+ eps: float = 1e-6,
102
+ out: torch.Tensor | None = None,
103
+ ) -> torch.Tensor:
104
+ """In-place ``residual += x`` then RMSNorm and static FP8 quantization."""
105
+
106
+ if out is None:
107
+ out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
108
+ ops.residual_add_rms_norm_quant_fp8_static_bf16(
109
+ residual,
110
+ x,
111
+ weight,
112
+ scale,
113
+ float(eps),
114
+ out,
115
+ )
116
+ return out
117
+
118
+
119
+ __all__ = [
120
+ "residual_add_rms_norm_quant_fp8_static_bf16",
121
+ "rms_norm_bf16",
122
+ "rms_norm_quant_fp8_static_bf16",
123
+ ]
build/torch210-cxx11-cu128-x86_64-linux/_flashrt_residual_norm_quant_cuda_cf903dd.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:244000c5b33e1f609987b8b9aef434d0d6bee50bdf5287442ac889b2ac0c0df4
3
+ size 2471360
build/torch210-cxx11-cu128-x86_64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _flashrt_residual_norm_quant_cuda_cf903dd
3
+ ops = torch.ops._flashrt_residual_norm_quant_cuda_cf903dd
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_flashrt_residual_norm_quant_cuda_cf903dd::{op_name}"
build/torch210-cxx11-cu128-x86_64-linux/flashrt_residual_norm_quant/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import importlib.util
3
+ import sys
4
+ from pathlib import Path
5
+ from types import ModuleType
6
+
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch210-cxx11-cu128-x86_64-linux/metadata.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "flashrt-residual-norm-quant",
3
+ "id": "_flashrt_residual_norm_quant_cuda_cf903dd",
4
+ "version": 1,
5
+ "license": "Apache-2.0",
6
+ "python-depends": [],
7
+ "backend": {
8
+ "type": "cuda",
9
+ "archs": [
10
+ "10.0",
11
+ "10.1",
12
+ "12.0+PTX",
13
+ "7.0",
14
+ "7.2",
15
+ "7.5",
16
+ "8.0",
17
+ "8.6",
18
+ "8.7",
19
+ "8.9",
20
+ "9.0"
21
+ ]
22
+ }
23
+ }
build/torch210-cxx11-cu130-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FlashRT residual/RMSNorm/static-FP8 quantization kernels."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import torch
6
+
7
+ from ._ops import add_op_namespace_prefix, ops
8
+
9
+
10
+ def _check_rank2_same_shape(x: torch.Tensor, out: torch.Tensor, out_name: str) -> None:
11
+ if x.dim() != 2:
12
+ raise RuntimeError("x must be rank-2")
13
+ if out.shape != x.shape:
14
+ raise RuntimeError(f"{out_name} must have the same shape as x")
15
+
16
+
17
+ @torch.library.register_fake(add_op_namespace_prefix("rms_norm_bf16"))
18
+ def _rms_norm_bf16_fake(
19
+ x: torch.Tensor,
20
+ weight: torch.Tensor,
21
+ eps: float,
22
+ out: torch.Tensor,
23
+ ) -> None:
24
+ _check_rank2_same_shape(x, out, "out")
25
+ if weight.shape != (x.shape[1],):
26
+ raise RuntimeError("weight must have shape (x.shape[1],)")
27
+ return None
28
+
29
+
30
+ @torch.library.register_fake(add_op_namespace_prefix("rms_norm_quant_fp8_static_bf16"))
31
+ def _rms_norm_quant_fp8_static_bf16_fake(
32
+ x: torch.Tensor,
33
+ weight: torch.Tensor,
34
+ scale: torch.Tensor,
35
+ eps: float,
36
+ out: torch.Tensor,
37
+ ) -> None:
38
+ _check_rank2_same_shape(x, out, "out")
39
+ if weight.shape != (x.shape[1],):
40
+ raise RuntimeError("weight must have shape (x.shape[1],)")
41
+ if scale.numel() != 1:
42
+ raise RuntimeError("scale must contain exactly one value")
43
+ return None
44
+
45
+
46
+ @torch.library.register_fake(
47
+ add_op_namespace_prefix("residual_add_rms_norm_quant_fp8_static_bf16")
48
+ )
49
+ def _residual_add_rms_norm_quant_fp8_static_bf16_fake(
50
+ residual: torch.Tensor,
51
+ x: torch.Tensor,
52
+ weight: torch.Tensor,
53
+ scale: torch.Tensor,
54
+ eps: float,
55
+ out: torch.Tensor,
56
+ ) -> None:
57
+ if residual.shape != x.shape:
58
+ raise RuntimeError("residual and x must have the same shape")
59
+ _check_rank2_same_shape(x, out, "out")
60
+ if weight.shape != (x.shape[1],):
61
+ raise RuntimeError("weight must have shape (x.shape[1],)")
62
+ if scale.numel() != 1:
63
+ raise RuntimeError("scale must contain exactly one value")
64
+ return None
65
+
66
+
67
+ def rms_norm_bf16(
68
+ x: torch.Tensor,
69
+ weight: torch.Tensor,
70
+ eps: float = 1e-6,
71
+ out: torch.Tensor | None = None,
72
+ ) -> torch.Tensor:
73
+ """BF16 RMSNorm with affine weight."""
74
+
75
+ if out is None:
76
+ out = torch.empty_like(x, dtype=torch.bfloat16)
77
+ ops.rms_norm_bf16(x, weight, float(eps), out)
78
+ return out
79
+
80
+
81
+ def rms_norm_quant_fp8_static_bf16(
82
+ x: torch.Tensor,
83
+ weight: torch.Tensor,
84
+ scale: torch.Tensor,
85
+ eps: float = 1e-6,
86
+ out: torch.Tensor | None = None,
87
+ ) -> torch.Tensor:
88
+ """BF16 RMSNorm followed by static-scale FP8 E4M3 quantization."""
89
+
90
+ if out is None:
91
+ out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
92
+ ops.rms_norm_quant_fp8_static_bf16(x, weight, scale, float(eps), out)
93
+ return out
94
+
95
+
96
+ def residual_add_rms_norm_quant_fp8_static_bf16(
97
+ residual: torch.Tensor,
98
+ x: torch.Tensor,
99
+ weight: torch.Tensor,
100
+ scale: torch.Tensor,
101
+ eps: float = 1e-6,
102
+ out: torch.Tensor | None = None,
103
+ ) -> torch.Tensor:
104
+ """In-place ``residual += x`` then RMSNorm and static FP8 quantization."""
105
+
106
+ if out is None:
107
+ out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
108
+ ops.residual_add_rms_norm_quant_fp8_static_bf16(
109
+ residual,
110
+ x,
111
+ weight,
112
+ scale,
113
+ float(eps),
114
+ out,
115
+ )
116
+ return out
117
+
118
+
119
+ __all__ = [
120
+ "residual_add_rms_norm_quant_fp8_static_bf16",
121
+ "rms_norm_bf16",
122
+ "rms_norm_quant_fp8_static_bf16",
123
+ ]
build/torch210-cxx11-cu130-x86_64-linux/_flashrt_residual_norm_quant_cuda_cf903dd.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ac048c6ebb52c526e68aa3e2325e0bc28fcd0bf11ceabc084c26b7c1dcb7710
3
+ size 2414152
build/torch210-cxx11-cu130-x86_64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _flashrt_residual_norm_quant_cuda_cf903dd
3
+ ops = torch.ops._flashrt_residual_norm_quant_cuda_cf903dd
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_flashrt_residual_norm_quant_cuda_cf903dd::{op_name}"
build/torch210-cxx11-cu130-x86_64-linux/flashrt_residual_norm_quant/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import importlib.util
3
+ import sys
4
+ from pathlib import Path
5
+ from types import ModuleType
6
+
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch210-cxx11-cu130-x86_64-linux/metadata.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "flashrt-residual-norm-quant",
3
+ "id": "_flashrt_residual_norm_quant_cuda_cf903dd",
4
+ "version": 1,
5
+ "license": "Apache-2.0",
6
+ "python-depends": [],
7
+ "backend": {
8
+ "type": "cuda",
9
+ "archs": [
10
+ "10.0",
11
+ "11.0",
12
+ "12.0+PTX",
13
+ "7.5",
14
+ "8.0",
15
+ "8.6",
16
+ "8.7",
17
+ "8.9",
18
+ "9.0"
19
+ ]
20
+ }
21
+ }
build/torch211-cxx11-cu128-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FlashRT residual/RMSNorm/static-FP8 quantization kernels."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import torch
6
+
7
+ from ._ops import add_op_namespace_prefix, ops
8
+
9
+
10
+ def _check_rank2_same_shape(x: torch.Tensor, out: torch.Tensor, out_name: str) -> None:
11
+ if x.dim() != 2:
12
+ raise RuntimeError("x must be rank-2")
13
+ if out.shape != x.shape:
14
+ raise RuntimeError(f"{out_name} must have the same shape as x")
15
+
16
+
17
+ @torch.library.register_fake(add_op_namespace_prefix("rms_norm_bf16"))
18
+ def _rms_norm_bf16_fake(
19
+ x: torch.Tensor,
20
+ weight: torch.Tensor,
21
+ eps: float,
22
+ out: torch.Tensor,
23
+ ) -> None:
24
+ _check_rank2_same_shape(x, out, "out")
25
+ if weight.shape != (x.shape[1],):
26
+ raise RuntimeError("weight must have shape (x.shape[1],)")
27
+ return None
28
+
29
+
30
+ @torch.library.register_fake(add_op_namespace_prefix("rms_norm_quant_fp8_static_bf16"))
31
+ def _rms_norm_quant_fp8_static_bf16_fake(
32
+ x: torch.Tensor,
33
+ weight: torch.Tensor,
34
+ scale: torch.Tensor,
35
+ eps: float,
36
+ out: torch.Tensor,
37
+ ) -> None:
38
+ _check_rank2_same_shape(x, out, "out")
39
+ if weight.shape != (x.shape[1],):
40
+ raise RuntimeError("weight must have shape (x.shape[1],)")
41
+ if scale.numel() != 1:
42
+ raise RuntimeError("scale must contain exactly one value")
43
+ return None
44
+
45
+
46
+ @torch.library.register_fake(
47
+ add_op_namespace_prefix("residual_add_rms_norm_quant_fp8_static_bf16")
48
+ )
49
+ def _residual_add_rms_norm_quant_fp8_static_bf16_fake(
50
+ residual: torch.Tensor,
51
+ x: torch.Tensor,
52
+ weight: torch.Tensor,
53
+ scale: torch.Tensor,
54
+ eps: float,
55
+ out: torch.Tensor,
56
+ ) -> None:
57
+ if residual.shape != x.shape:
58
+ raise RuntimeError("residual and x must have the same shape")
59
+ _check_rank2_same_shape(x, out, "out")
60
+ if weight.shape != (x.shape[1],):
61
+ raise RuntimeError("weight must have shape (x.shape[1],)")
62
+ if scale.numel() != 1:
63
+ raise RuntimeError("scale must contain exactly one value")
64
+ return None
65
+
66
+
67
+ def rms_norm_bf16(
68
+ x: torch.Tensor,
69
+ weight: torch.Tensor,
70
+ eps: float = 1e-6,
71
+ out: torch.Tensor | None = None,
72
+ ) -> torch.Tensor:
73
+ """BF16 RMSNorm with affine weight."""
74
+
75
+ if out is None:
76
+ out = torch.empty_like(x, dtype=torch.bfloat16)
77
+ ops.rms_norm_bf16(x, weight, float(eps), out)
78
+ return out
79
+
80
+
81
+ def rms_norm_quant_fp8_static_bf16(
82
+ x: torch.Tensor,
83
+ weight: torch.Tensor,
84
+ scale: torch.Tensor,
85
+ eps: float = 1e-6,
86
+ out: torch.Tensor | None = None,
87
+ ) -> torch.Tensor:
88
+ """BF16 RMSNorm followed by static-scale FP8 E4M3 quantization."""
89
+
90
+ if out is None:
91
+ out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
92
+ ops.rms_norm_quant_fp8_static_bf16(x, weight, scale, float(eps), out)
93
+ return out
94
+
95
+
96
+ def residual_add_rms_norm_quant_fp8_static_bf16(
97
+ residual: torch.Tensor,
98
+ x: torch.Tensor,
99
+ weight: torch.Tensor,
100
+ scale: torch.Tensor,
101
+ eps: float = 1e-6,
102
+ out: torch.Tensor | None = None,
103
+ ) -> torch.Tensor:
104
+ """In-place ``residual += x`` then RMSNorm and static FP8 quantization."""
105
+
106
+ if out is None:
107
+ out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
108
+ ops.residual_add_rms_norm_quant_fp8_static_bf16(
109
+ residual,
110
+ x,
111
+ weight,
112
+ scale,
113
+ float(eps),
114
+ out,
115
+ )
116
+ return out
117
+
118
+
119
+ __all__ = [
120
+ "residual_add_rms_norm_quant_fp8_static_bf16",
121
+ "rms_norm_bf16",
122
+ "rms_norm_quant_fp8_static_bf16",
123
+ ]
build/torch211-cxx11-cu128-x86_64-linux/_flashrt_residual_norm_quant_cuda_cf903dd.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:562ffcdba2e68ce168e6b0be94d3ace7b54ee7f5f3cb701850bcf90b93f2f106
3
+ size 2464400
build/torch211-cxx11-cu128-x86_64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _flashrt_residual_norm_quant_cuda_cf903dd
3
+ ops = torch.ops._flashrt_residual_norm_quant_cuda_cf903dd
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_flashrt_residual_norm_quant_cuda_cf903dd::{op_name}"
build/torch211-cxx11-cu128-x86_64-linux/flashrt_residual_norm_quant/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import importlib.util
3
+ import sys
4
+ from pathlib import Path
5
+ from types import ModuleType
6
+
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch211-cxx11-cu128-x86_64-linux/metadata.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "flashrt-residual-norm-quant",
3
+ "id": "_flashrt_residual_norm_quant_cuda_cf903dd",
4
+ "version": 1,
5
+ "license": "Apache-2.0",
6
+ "python-depends": [],
7
+ "backend": {
8
+ "type": "cuda",
9
+ "archs": [
10
+ "10.0",
11
+ "10.1",
12
+ "12.0+PTX",
13
+ "7.0",
14
+ "7.2",
15
+ "7.5",
16
+ "8.0",
17
+ "8.6",
18
+ "8.7",
19
+ "8.9",
20
+ "9.0"
21
+ ]
22
+ }
23
+ }
build/torch211-cxx11-cu130-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FlashRT residual/RMSNorm/static-FP8 quantization kernels."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import torch
6
+
7
+ from ._ops import add_op_namespace_prefix, ops
8
+
9
+
10
+ def _check_rank2_same_shape(x: torch.Tensor, out: torch.Tensor, out_name: str) -> None:
11
+ if x.dim() != 2:
12
+ raise RuntimeError("x must be rank-2")
13
+ if out.shape != x.shape:
14
+ raise RuntimeError(f"{out_name} must have the same shape as x")
15
+
16
+
17
+ @torch.library.register_fake(add_op_namespace_prefix("rms_norm_bf16"))
18
+ def _rms_norm_bf16_fake(
19
+ x: torch.Tensor,
20
+ weight: torch.Tensor,
21
+ eps: float,
22
+ out: torch.Tensor,
23
+ ) -> None:
24
+ _check_rank2_same_shape(x, out, "out")
25
+ if weight.shape != (x.shape[1],):
26
+ raise RuntimeError("weight must have shape (x.shape[1],)")
27
+ return None
28
+
29
+
30
+ @torch.library.register_fake(add_op_namespace_prefix("rms_norm_quant_fp8_static_bf16"))
31
+ def _rms_norm_quant_fp8_static_bf16_fake(
32
+ x: torch.Tensor,
33
+ weight: torch.Tensor,
34
+ scale: torch.Tensor,
35
+ eps: float,
36
+ out: torch.Tensor,
37
+ ) -> None:
38
+ _check_rank2_same_shape(x, out, "out")
39
+ if weight.shape != (x.shape[1],):
40
+ raise RuntimeError("weight must have shape (x.shape[1],)")
41
+ if scale.numel() != 1:
42
+ raise RuntimeError("scale must contain exactly one value")
43
+ return None
44
+
45
+
46
+ @torch.library.register_fake(
47
+ add_op_namespace_prefix("residual_add_rms_norm_quant_fp8_static_bf16")
48
+ )
49
+ def _residual_add_rms_norm_quant_fp8_static_bf16_fake(
50
+ residual: torch.Tensor,
51
+ x: torch.Tensor,
52
+ weight: torch.Tensor,
53
+ scale: torch.Tensor,
54
+ eps: float,
55
+ out: torch.Tensor,
56
+ ) -> None:
57
+ if residual.shape != x.shape:
58
+ raise RuntimeError("residual and x must have the same shape")
59
+ _check_rank2_same_shape(x, out, "out")
60
+ if weight.shape != (x.shape[1],):
61
+ raise RuntimeError("weight must have shape (x.shape[1],)")
62
+ if scale.numel() != 1:
63
+ raise RuntimeError("scale must contain exactly one value")
64
+ return None
65
+
66
+
67
+ def rms_norm_bf16(
68
+ x: torch.Tensor,
69
+ weight: torch.Tensor,
70
+ eps: float = 1e-6,
71
+ out: torch.Tensor | None = None,
72
+ ) -> torch.Tensor:
73
+ """BF16 RMSNorm with affine weight."""
74
+
75
+ if out is None:
76
+ out = torch.empty_like(x, dtype=torch.bfloat16)
77
+ ops.rms_norm_bf16(x, weight, float(eps), out)
78
+ return out
79
+
80
+
81
+ def rms_norm_quant_fp8_static_bf16(
82
+ x: torch.Tensor,
83
+ weight: torch.Tensor,
84
+ scale: torch.Tensor,
85
+ eps: float = 1e-6,
86
+ out: torch.Tensor | None = None,
87
+ ) -> torch.Tensor:
88
+ """BF16 RMSNorm followed by static-scale FP8 E4M3 quantization."""
89
+
90
+ if out is None:
91
+ out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
92
+ ops.rms_norm_quant_fp8_static_bf16(x, weight, scale, float(eps), out)
93
+ return out
94
+
95
+
96
+ def residual_add_rms_norm_quant_fp8_static_bf16(
97
+ residual: torch.Tensor,
98
+ x: torch.Tensor,
99
+ weight: torch.Tensor,
100
+ scale: torch.Tensor,
101
+ eps: float = 1e-6,
102
+ out: torch.Tensor | None = None,
103
+ ) -> torch.Tensor:
104
+ """In-place ``residual += x`` then RMSNorm and static FP8 quantization."""
105
+
106
+ if out is None:
107
+ out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
108
+ ops.residual_add_rms_norm_quant_fp8_static_bf16(
109
+ residual,
110
+ x,
111
+ weight,
112
+ scale,
113
+ float(eps),
114
+ out,
115
+ )
116
+ return out
117
+
118
+
119
+ __all__ = [
120
+ "residual_add_rms_norm_quant_fp8_static_bf16",
121
+ "rms_norm_bf16",
122
+ "rms_norm_quant_fp8_static_bf16",
123
+ ]
build/torch211-cxx11-cu130-x86_64-linux/_flashrt_residual_norm_quant_cuda_cf903dd.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b197c38ab8ee5e98974c81fe7a883057c027bb0238c737e2494ceb627872af6f
3
+ size 2398992
build/torch211-cxx11-cu130-x86_64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _flashrt_residual_norm_quant_cuda_cf903dd
3
+ ops = torch.ops._flashrt_residual_norm_quant_cuda_cf903dd
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_flashrt_residual_norm_quant_cuda_cf903dd::{op_name}"
build/torch211-cxx11-cu130-x86_64-linux/flashrt_residual_norm_quant/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import importlib.util
3
+ import sys
4
+ from pathlib import Path
5
+ from types import ModuleType
6
+
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch211-cxx11-cu130-x86_64-linux/metadata.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "flashrt-residual-norm-quant",
3
+ "id": "_flashrt_residual_norm_quant_cuda_cf903dd",
4
+ "version": 1,
5
+ "license": "Apache-2.0",
6
+ "python-depends": [],
7
+ "backend": {
8
+ "type": "cuda",
9
+ "archs": [
10
+ "10.0",
11
+ "11.0",
12
+ "12.0+PTX",
13
+ "7.5",
14
+ "8.0",
15
+ "8.6",
16
+ "8.7",
17
+ "8.9",
18
+ "9.0"
19
+ ]
20
+ }
21
+ }
build/torch212-cxx11-cu130-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FlashRT residual/RMSNorm/static-FP8 quantization kernels."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import torch
6
+
7
+ from ._ops import add_op_namespace_prefix, ops
8
+
9
+
10
+ def _check_rank2_same_shape(x: torch.Tensor, out: torch.Tensor, out_name: str) -> None:
11
+ if x.dim() != 2:
12
+ raise RuntimeError("x must be rank-2")
13
+ if out.shape != x.shape:
14
+ raise RuntimeError(f"{out_name} must have the same shape as x")
15
+
16
+
17
+ @torch.library.register_fake(add_op_namespace_prefix("rms_norm_bf16"))
18
+ def _rms_norm_bf16_fake(
19
+ x: torch.Tensor,
20
+ weight: torch.Tensor,
21
+ eps: float,
22
+ out: torch.Tensor,
23
+ ) -> None:
24
+ _check_rank2_same_shape(x, out, "out")
25
+ if weight.shape != (x.shape[1],):
26
+ raise RuntimeError("weight must have shape (x.shape[1],)")
27
+ return None
28
+
29
+
30
+ @torch.library.register_fake(add_op_namespace_prefix("rms_norm_quant_fp8_static_bf16"))
31
+ def _rms_norm_quant_fp8_static_bf16_fake(
32
+ x: torch.Tensor,
33
+ weight: torch.Tensor,
34
+ scale: torch.Tensor,
35
+ eps: float,
36
+ out: torch.Tensor,
37
+ ) -> None:
38
+ _check_rank2_same_shape(x, out, "out")
39
+ if weight.shape != (x.shape[1],):
40
+ raise RuntimeError("weight must have shape (x.shape[1],)")
41
+ if scale.numel() != 1:
42
+ raise RuntimeError("scale must contain exactly one value")
43
+ return None
44
+
45
+
46
+ @torch.library.register_fake(
47
+ add_op_namespace_prefix("residual_add_rms_norm_quant_fp8_static_bf16")
48
+ )
49
+ def _residual_add_rms_norm_quant_fp8_static_bf16_fake(
50
+ residual: torch.Tensor,
51
+ x: torch.Tensor,
52
+ weight: torch.Tensor,
53
+ scale: torch.Tensor,
54
+ eps: float,
55
+ out: torch.Tensor,
56
+ ) -> None:
57
+ if residual.shape != x.shape:
58
+ raise RuntimeError("residual and x must have the same shape")
59
+ _check_rank2_same_shape(x, out, "out")
60
+ if weight.shape != (x.shape[1],):
61
+ raise RuntimeError("weight must have shape (x.shape[1],)")
62
+ if scale.numel() != 1:
63
+ raise RuntimeError("scale must contain exactly one value")
64
+ return None
65
+
66
+
67
+ def rms_norm_bf16(
68
+ x: torch.Tensor,
69
+ weight: torch.Tensor,
70
+ eps: float = 1e-6,
71
+ out: torch.Tensor | None = None,
72
+ ) -> torch.Tensor:
73
+ """BF16 RMSNorm with affine weight."""
74
+
75
+ if out is None:
76
+ out = torch.empty_like(x, dtype=torch.bfloat16)
77
+ ops.rms_norm_bf16(x, weight, float(eps), out)
78
+ return out
79
+
80
+
81
+ def rms_norm_quant_fp8_static_bf16(
82
+ x: torch.Tensor,
83
+ weight: torch.Tensor,
84
+ scale: torch.Tensor,
85
+ eps: float = 1e-6,
86
+ out: torch.Tensor | None = None,
87
+ ) -> torch.Tensor:
88
+ """BF16 RMSNorm followed by static-scale FP8 E4M3 quantization."""
89
+
90
+ if out is None:
91
+ out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
92
+ ops.rms_norm_quant_fp8_static_bf16(x, weight, scale, float(eps), out)
93
+ return out
94
+
95
+
96
+ def residual_add_rms_norm_quant_fp8_static_bf16(
97
+ residual: torch.Tensor,
98
+ x: torch.Tensor,
99
+ weight: torch.Tensor,
100
+ scale: torch.Tensor,
101
+ eps: float = 1e-6,
102
+ out: torch.Tensor | None = None,
103
+ ) -> torch.Tensor:
104
+ """In-place ``residual += x`` then RMSNorm and static FP8 quantization."""
105
+
106
+ if out is None:
107
+ out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
108
+ ops.residual_add_rms_norm_quant_fp8_static_bf16(
109
+ residual,
110
+ x,
111
+ weight,
112
+ scale,
113
+ float(eps),
114
+ out,
115
+ )
116
+ return out
117
+
118
+
119
+ __all__ = [
120
+ "residual_add_rms_norm_quant_fp8_static_bf16",
121
+ "rms_norm_bf16",
122
+ "rms_norm_quant_fp8_static_bf16",
123
+ ]
build/torch212-cxx11-cu130-x86_64-linux/_flashrt_residual_norm_quant_cuda_cf903dd.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c386ef3e62e6314e355fb016afc4bb3536653e99749178e29900ff04278bc5cc
3
+ size 2400424
build/torch212-cxx11-cu130-x86_64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _flashrt_residual_norm_quant_cuda_cf903dd
3
+ ops = torch.ops._flashrt_residual_norm_quant_cuda_cf903dd
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_flashrt_residual_norm_quant_cuda_cf903dd::{op_name}"
build/torch212-cxx11-cu130-x86_64-linux/flashrt_residual_norm_quant/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import importlib.util
3
+ import sys
4
+ from pathlib import Path
5
+ from types import ModuleType
6
+
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch212-cxx11-cu130-x86_64-linux/metadata.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "flashrt-residual-norm-quant",
3
+ "id": "_flashrt_residual_norm_quant_cuda_cf903dd",
4
+ "version": 1,
5
+ "license": "Apache-2.0",
6
+ "python-depends": [],
7
+ "backend": {
8
+ "type": "cuda",
9
+ "archs": [
10
+ "10.0",
11
+ "11.0",
12
+ "12.0+PTX",
13
+ "7.5",
14
+ "8.0",
15
+ "8.6",
16
+ "8.7",
17
+ "8.9",
18
+ "9.0"
19
+ ]
20
+ }
21
+ }
build/torch212-cxx11-cu132-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FlashRT residual/RMSNorm/static-FP8 quantization kernels."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import torch
6
+
7
+ from ._ops import add_op_namespace_prefix, ops
8
+
9
+
10
+ def _check_rank2_same_shape(x: torch.Tensor, out: torch.Tensor, out_name: str) -> None:
11
+ if x.dim() != 2:
12
+ raise RuntimeError("x must be rank-2")
13
+ if out.shape != x.shape:
14
+ raise RuntimeError(f"{out_name} must have the same shape as x")
15
+
16
+
17
+ @torch.library.register_fake(add_op_namespace_prefix("rms_norm_bf16"))
18
+ def _rms_norm_bf16_fake(
19
+ x: torch.Tensor,
20
+ weight: torch.Tensor,
21
+ eps: float,
22
+ out: torch.Tensor,
23
+ ) -> None:
24
+ _check_rank2_same_shape(x, out, "out")
25
+ if weight.shape != (x.shape[1],):
26
+ raise RuntimeError("weight must have shape (x.shape[1],)")
27
+ return None
28
+
29
+
30
+ @torch.library.register_fake(add_op_namespace_prefix("rms_norm_quant_fp8_static_bf16"))
31
+ def _rms_norm_quant_fp8_static_bf16_fake(
32
+ x: torch.Tensor,
33
+ weight: torch.Tensor,
34
+ scale: torch.Tensor,
35
+ eps: float,
36
+ out: torch.Tensor,
37
+ ) -> None:
38
+ _check_rank2_same_shape(x, out, "out")
39
+ if weight.shape != (x.shape[1],):
40
+ raise RuntimeError("weight must have shape (x.shape[1],)")
41
+ if scale.numel() != 1:
42
+ raise RuntimeError("scale must contain exactly one value")
43
+ return None
44
+
45
+
46
+ @torch.library.register_fake(
47
+ add_op_namespace_prefix("residual_add_rms_norm_quant_fp8_static_bf16")
48
+ )
49
+ def _residual_add_rms_norm_quant_fp8_static_bf16_fake(
50
+ residual: torch.Tensor,
51
+ x: torch.Tensor,
52
+ weight: torch.Tensor,
53
+ scale: torch.Tensor,
54
+ eps: float,
55
+ out: torch.Tensor,
56
+ ) -> None:
57
+ if residual.shape != x.shape:
58
+ raise RuntimeError("residual and x must have the same shape")
59
+ _check_rank2_same_shape(x, out, "out")
60
+ if weight.shape != (x.shape[1],):
61
+ raise RuntimeError("weight must have shape (x.shape[1],)")
62
+ if scale.numel() != 1:
63
+ raise RuntimeError("scale must contain exactly one value")
64
+ return None
65
+
66
+
67
+ def rms_norm_bf16(
68
+ x: torch.Tensor,
69
+ weight: torch.Tensor,
70
+ eps: float = 1e-6,
71
+ out: torch.Tensor | None = None,
72
+ ) -> torch.Tensor:
73
+ """BF16 RMSNorm with affine weight."""
74
+
75
+ if out is None:
76
+ out = torch.empty_like(x, dtype=torch.bfloat16)
77
+ ops.rms_norm_bf16(x, weight, float(eps), out)
78
+ return out
79
+
80
+
81
+ def rms_norm_quant_fp8_static_bf16(
82
+ x: torch.Tensor,
83
+ weight: torch.Tensor,
84
+ scale: torch.Tensor,
85
+ eps: float = 1e-6,
86
+ out: torch.Tensor | None = None,
87
+ ) -> torch.Tensor:
88
+ """BF16 RMSNorm followed by static-scale FP8 E4M3 quantization."""
89
+
90
+ if out is None:
91
+ out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
92
+ ops.rms_norm_quant_fp8_static_bf16(x, weight, scale, float(eps), out)
93
+ return out
94
+
95
+
96
+ def residual_add_rms_norm_quant_fp8_static_bf16(
97
+ residual: torch.Tensor,
98
+ x: torch.Tensor,
99
+ weight: torch.Tensor,
100
+ scale: torch.Tensor,
101
+ eps: float = 1e-6,
102
+ out: torch.Tensor | None = None,
103
+ ) -> torch.Tensor:
104
+ """In-place ``residual += x`` then RMSNorm and static FP8 quantization."""
105
+
106
+ if out is None:
107
+ out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
108
+ ops.residual_add_rms_norm_quant_fp8_static_bf16(
109
+ residual,
110
+ x,
111
+ weight,
112
+ scale,
113
+ float(eps),
114
+ out,
115
+ )
116
+ return out
117
+
118
+
119
+ __all__ = [
120
+ "residual_add_rms_norm_quant_fp8_static_bf16",
121
+ "rms_norm_bf16",
122
+ "rms_norm_quant_fp8_static_bf16",
123
+ ]
build/torch212-cxx11-cu132-x86_64-linux/_flashrt_residual_norm_quant_cuda_cf903dd.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9ee5cc9d8958c8f3c04f62671615ec5e5f1969eeaaf5687ea31ccfcca98609d3
3
+ size 2400392
build/torch212-cxx11-cu132-x86_64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _flashrt_residual_norm_quant_cuda_cf903dd
3
+ ops = torch.ops._flashrt_residual_norm_quant_cuda_cf903dd
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_flashrt_residual_norm_quant_cuda_cf903dd::{op_name}"
build/torch212-cxx11-cu132-x86_64-linux/flashrt_residual_norm_quant/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import importlib.util
3
+ import sys
4
+ from pathlib import Path
5
+ from types import ModuleType
6
+
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch212-cxx11-cu132-x86_64-linux/metadata.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "flashrt-residual-norm-quant",
3
+ "id": "_flashrt_residual_norm_quant_cuda_cf903dd",
4
+ "version": 1,
5
+ "license": "Apache-2.0",
6
+ "python-depends": [],
7
+ "backend": {
8
+ "type": "cuda",
9
+ "archs": [
10
+ "10.0",
11
+ "11.0",
12
+ "12.0+PTX",
13
+ "7.5",
14
+ "8.0",
15
+ "8.6",
16
+ "8.7",
17
+ "8.9",
18
+ "9.0"
19
+ ]
20
+ }
21
+ }