liangsu9988 commited on
Commit
166f09b
·
verified ·
1 Parent(s): aea72f9

Uploaded using `kernel-builder`.

Browse files
benchmarks/benchmark.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Benchmark fp8-gemm."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import importlib
8
+ import json
9
+ import os
10
+ import sys
11
+ from dataclasses import asdict, dataclass
12
+ from pathlib import Path
13
+
14
+ import torch
15
+
16
+
17
+ ROOT = Path(__file__).resolve().parents[2]
18
+ PACKAGE = ROOT / "fp8-gemm"
19
+ REGISTRATION_INCLUDE = (
20
+ ROOT.parent
21
+ / "kernels"
22
+ / "kernel-builder"
23
+ / "src"
24
+ / "pyproject"
25
+ / "templates"
26
+ / "torch"
27
+ )
28
+
29
+ SHAPES = {
30
+ "decode_m1_k4096_n2048": (1, 4096, 2048),
31
+ "decode_m1_k4096_n8192": (1, 4096, 8192),
32
+ "small_m16_k4096_n4096": (16, 4096, 4096),
33
+ "small_m32_k4096_n8192": (32, 4096, 8192),
34
+ "small_m64_k512_n1024": (64, 512, 1024),
35
+ }
36
+
37
+ MODES = {
38
+ "smoke": ["decode_m1_k4096_n2048", "small_m16_k4096_n4096"],
39
+ "headline": list(SHAPES),
40
+ }
41
+
42
+
43
+ @dataclass
44
+ class Result:
45
+ shape: str
46
+ M: int
47
+ K: int
48
+ N: int
49
+ variant: int
50
+ tile: str
51
+ flashrt_us: float
52
+ torch_eager_us: float
53
+ torch_compile_us: float | None
54
+ speedup_vs_eager: float
55
+ speedup_vs_compile: float | None
56
+ max_abs: float
57
+ mean_abs: float
58
+ p99_abs: float
59
+ cosine: float
60
+ status: str
61
+
62
+
63
+ class SourceOps:
64
+ def __init__(self, namespace: str) -> None:
65
+ self._ops = getattr(torch.ops, namespace)
66
+
67
+ @staticmethod
68
+ def select_fp8_linear_tile(m: int, n: int, k: int, variant: int = 0) -> str:
69
+ return select_tile(m, n, k, variant)
70
+
71
+ def fp8_linear_bf16(self, x, w, alpha=1.0, out=None, variant=0):
72
+ if out is None:
73
+ out = torch.empty((x.shape[0], w.shape[0]), device=x.device, dtype=torch.bfloat16)
74
+ self._ops.fp8_linear_bf16(x, w, float(alpha), int(variant), out)
75
+ return out
76
+
77
+
78
+ def _current_arch_list() -> str:
79
+ major, minor = torch.cuda.get_device_capability(0)
80
+ if major >= 12:
81
+ return "12.0a"
82
+ return f"{major}.{minor}"
83
+
84
+
85
+ def load_source_ops() -> SourceOps:
86
+ from torch.utils.cpp_extension import load
87
+
88
+ os.environ.setdefault("TORCH_CUDA_ARCH_LIST", _current_arch_list())
89
+ namespace = "fp8_gemm_source_bench"
90
+ load(
91
+ name=namespace,
92
+ sources=[
93
+ str(PACKAGE / "torch-ext" / "torch_binding.cpp"),
94
+ str(PACKAGE / "csrc" / "fp8_gemv_m1_sm120.cu"),
95
+ str(PACKAGE / "csrc" / "fp8_smallM_handtuned_sm120.cu"),
96
+ str(PACKAGE / "csrc" / "fp8_smallM_handtuned_ldmatrix_sm120.cu"),
97
+ ],
98
+ extra_include_paths=[str(PACKAGE / "csrc"), str(REGISTRATION_INCLUDE)],
99
+ extra_cflags=["-O3", "-DCUDA_KERNEL"],
100
+ extra_cuda_cflags=["-O3", "--expt-relaxed-constexpr", "-DCUDA_KERNEL"],
101
+ verbose=False,
102
+ )
103
+ return SourceOps(namespace)
104
+
105
+
106
+ def load_installed_ops(artifact: str | None):
107
+ if artifact:
108
+ sys.path.insert(0, artifact)
109
+ try:
110
+ return importlib.import_module("fp8_gemm")
111
+ finally:
112
+ if artifact:
113
+ sys.path.remove(artifact)
114
+
115
+
116
+ def select_tile(m: int, n: int, k: int, variant: int = 0) -> str:
117
+ if m == 1:
118
+ if variant == 4:
119
+ return "gemv_fp8_m1_w4"
120
+ if variant == 8:
121
+ return "gemv_fp8_m1_w8"
122
+ if variant == 16:
123
+ return "gemv_fp8_m1_w16"
124
+ if n <= 2048:
125
+ return "gemv_fp8_m1_w4"
126
+ if n <= 8192:
127
+ return "gemv_fp8_m1_w8"
128
+ return "gemv_fp8_m1_w16"
129
+ if m <= 16:
130
+ if k % 256 == 0:
131
+ return "ld_fp8_gemm_16x128x256_w4" if n % 128 == 0 else "ld_fp8_gemm_16x64x256_w4"
132
+ if n % 256 == 0:
133
+ return "ld_fp8_gemm_16x256x128_w8"
134
+ if n % 192 == 0:
135
+ return "ld_fp8_gemm_16x192x128_w4"
136
+ if n % 128 == 0:
137
+ return "ld_fp8_gemm_16x128x128_w4"
138
+ return "ld_fp8_gemm_16x64x128_w4"
139
+ if m <= 32:
140
+ if k % 256 == 0:
141
+ return "ld_fp8_gemm_32x128x256_w4" if n % 128 == 0 else "ld_fp8_gemm_32x64x256_w4"
142
+ if n % 192 == 0:
143
+ return "ld_fp8_gemm_32x192x128_w4"
144
+ if n % 128 == 0:
145
+ return "ld_fp8_gemm_32x128x128_w4"
146
+ return "ld_fp8_gemm_32x64x128_w4"
147
+ if m <= 64:
148
+ if k % 256 == 0:
149
+ return "ld_fp8_gemm_64x128x256_w4" if n % 128 == 0 else "ld_fp8_gemm_64x64x256_w4"
150
+ if n % 128 == 0:
151
+ return "ld_fp8_gemm_64x128x128_w4"
152
+ return "ld_fp8_gemm_64x64x128_w4"
153
+ if m <= 64:
154
+ if k % 256 == 0:
155
+ return "ld_fp8_gemm_64x128x256_w4" if n % 128 == 0 else "ld_fp8_gemm_64x64x256_w4"
156
+ if n % 128 == 0:
157
+ return "ld_fp8_gemm_64x128x128_w4"
158
+ return "ld_fp8_gemm_64x64x128_w4"
159
+ raise RuntimeError("unsupported M")
160
+
161
+
162
+ def make_inputs(m: int, k: int, n: int, seed: int):
163
+ gen = torch.Generator(device="cuda")
164
+ gen.manual_seed(seed)
165
+ x = (torch.randn((m, k), device="cuda", generator=gen) * 0.25).to(torch.bfloat16).to(torch.float8_e4m3fn)
166
+ w = (torch.randn((n, k), device="cuda", generator=gen) * 0.25).to(torch.bfloat16).to(torch.float8_e4m3fn)
167
+ return x, w
168
+
169
+
170
+ def ref_fn(x, w):
171
+ return (x.float() @ w.float().T).to(torch.bfloat16)
172
+
173
+
174
+ def measure(fn, warmup: int, iters: int) -> float:
175
+ for _ in range(warmup):
176
+ fn()
177
+ torch.cuda.synchronize()
178
+ start = torch.cuda.Event(enable_timing=True)
179
+ end = torch.cuda.Event(enable_timing=True)
180
+ start.record()
181
+ for _ in range(iters):
182
+ fn()
183
+ end.record()
184
+ torch.cuda.synchronize()
185
+ return float(start.elapsed_time(end) * 1000.0 / iters)
186
+
187
+
188
+ def metrics(got, expected):
189
+ diff = (got.float() - expected.float()).abs().flatten()
190
+ return (
191
+ float(diff.max().item()),
192
+ float(diff.mean().item()),
193
+ float(torch.quantile(diff, 0.99).item()),
194
+ float(torch.nn.functional.cosine_similarity(got.float().flatten(), expected.float().flatten(), dim=0).item()),
195
+ )
196
+
197
+
198
+ def bench_case(ops, name: str, shape: tuple[int, int, int], variant: int, warmup: int, iters: int, compile_ref: bool):
199
+ m, k, n = shape
200
+ x, w = make_inputs(m, k, n, seed=3000 + m + k + n + variant)
201
+ out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
202
+ expected = ref_fn(x, w)
203
+ got = ops.fp8_linear_bf16(x, w, out=out, variant=variant)
204
+ torch.cuda.synchronize()
205
+ max_abs, mean_abs, p99_abs, cos = metrics(got, expected)
206
+ tile = ops.select_fp8_linear_tile(m, n, k, variant)
207
+
208
+ flashrt_us = measure(lambda: ops.fp8_linear_bf16(x, w, out=out, variant=variant), warmup, iters)
209
+ eager_us = measure(lambda: ref_fn(x, w), warmup, iters)
210
+ compile_us = None
211
+ if compile_ref:
212
+ try:
213
+ compiled = torch.compile(ref_fn, fullgraph=True)
214
+ compiled(x, w)
215
+ torch.cuda.synchronize()
216
+ compile_us = measure(lambda: compiled(x, w), warmup, iters)
217
+ except Exception:
218
+ compile_us = None
219
+
220
+ return Result(
221
+ shape=name,
222
+ M=m,
223
+ K=k,
224
+ N=n,
225
+ variant=variant,
226
+ tile=tile,
227
+ flashrt_us=flashrt_us,
228
+ torch_eager_us=eager_us,
229
+ torch_compile_us=compile_us,
230
+ speedup_vs_eager=eager_us / flashrt_us,
231
+ speedup_vs_compile=(compile_us / flashrt_us) if compile_us else None,
232
+ max_abs=max_abs,
233
+ mean_abs=mean_abs,
234
+ p99_abs=p99_abs,
235
+ cosine=cos,
236
+ status="pass" if max_abs <= 0.5 and p99_abs <= 0.25 and cos >= 0.999 else "fail",
237
+ )
238
+
239
+
240
+ def main() -> None:
241
+ parser = argparse.ArgumentParser()
242
+ parser.add_argument("--backend", choices=["source", "installed"], default="source")
243
+ parser.add_argument("--artifact", default=None)
244
+ parser.add_argument("--mode", choices=sorted(MODES), default="smoke")
245
+ parser.add_argument("--warmup", type=int, default=20)
246
+ parser.add_argument("--iterations", type=int, default=100)
247
+ parser.add_argument("--compile-ref", action="store_true")
248
+ parser.add_argument("--json-out", default=None)
249
+ args = parser.parse_args()
250
+
251
+ if not torch.cuda.is_available():
252
+ raise SystemExit("CUDA is required")
253
+ major, _minor = torch.cuda.get_device_capability(0)
254
+ if major < 12:
255
+ raise SystemExit("fp8-gemm requires Blackwell/SM120 for this package")
256
+
257
+ ops = load_source_ops() if args.backend == "source" else load_installed_ops(args.artifact)
258
+ rows: list[Result] = []
259
+ for name in MODES[args.mode]:
260
+ shape = SHAPES[name]
261
+ variants = [0]
262
+ if shape[0] == 1:
263
+ variants = [0, 4, 8, 16]
264
+ for variant in variants:
265
+ rows.append(bench_case(ops, name, shape, variant, args.warmup, args.iterations, args.compile_ref))
266
+
267
+ payload = {"rows": [asdict(row) for row in rows]}
268
+ print(json.dumps(payload, indent=2, sort_keys=True))
269
+ if args.json_out:
270
+ Path(args.json_out).write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n")
271
+ if any(row.status != "pass" for row in rows):
272
+ raise SystemExit(1)
273
+
274
+
275
+ if __name__ == "__main__":
276
+ main()
build/torch211-cxx11-cu128-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FlashRT FP8 GEMM kernels."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import torch
6
+
7
+ from ._ops import add_op_namespace_prefix, ops
8
+
9
+
10
+ @torch.library.register_fake(add_op_namespace_prefix("fp8_linear_bf16"))
11
+ def _fp8_linear_bf16_fake(
12
+ input: torch.Tensor,
13
+ weight: torch.Tensor,
14
+ alpha: float,
15
+ variant: int,
16
+ out: torch.Tensor,
17
+ ) -> None:
18
+ if input.dim() != 2 or weight.dim() != 2:
19
+ raise RuntimeError("input and weight must be rank-2 tensors")
20
+ if out.shape != (input.shape[0], weight.shape[0]):
21
+ raise RuntimeError("out must have shape (input.shape[0], weight.shape[0])")
22
+ return None
23
+
24
+
25
+ @torch.library.register_fake(add_op_namespace_prefix("fp8_linear_residual_bf16"))
26
+ def _fp8_linear_residual_bf16_fake(
27
+ input: torch.Tensor,
28
+ weight: torch.Tensor,
29
+ alpha: float,
30
+ variant: int,
31
+ residual: torch.Tensor,
32
+ ) -> None:
33
+ if input.shape[0] != 1:
34
+ raise RuntimeError("residual path supports only M=1")
35
+ if residual.shape != (1, weight.shape[0]):
36
+ raise RuntimeError("residual must have shape (1, weight.shape[0])")
37
+ return None
38
+
39
+
40
+ def select_fp8_linear_tile(m: int, n: int, k: int, variant: int = 0) -> str:
41
+ """Return the FlashRT tile selected by the public dispatcher."""
42
+
43
+ m = int(m)
44
+ n = int(n)
45
+ k = int(k)
46
+ variant = int(variant)
47
+ if m <= 0 or n <= 0 or k <= 0:
48
+ raise RuntimeError("m, n, and k must be positive")
49
+ if k % 32 != 0:
50
+ raise RuntimeError("k must be divisible by 32")
51
+ if m == 1:
52
+ if variant == 4:
53
+ return "gemv_fp8_m1_w4"
54
+ if variant == 8:
55
+ return "gemv_fp8_m1_w8"
56
+ if variant == 16:
57
+ return "gemv_fp8_m1_w16"
58
+ if variant != 0:
59
+ raise RuntimeError("M=1 variant must be 0, 4, 8, or 16")
60
+ if n <= 2048:
61
+ return "gemv_fp8_m1_w4"
62
+ if n <= 8192:
63
+ return "gemv_fp8_m1_w8"
64
+ return "gemv_fp8_m1_w16"
65
+ if variant != 0:
66
+ raise RuntimeError("small-M dispatcher currently supports variant=0 only")
67
+ if m <= 16:
68
+ if k % 256 == 0:
69
+ return "ld_fp8_gemm_16x128x256_w4" if n % 128 == 0 else "ld_fp8_gemm_16x64x256_w4"
70
+ if n % 256 == 0:
71
+ return "ld_fp8_gemm_16x256x128_w8"
72
+ if n % 192 == 0:
73
+ return "ld_fp8_gemm_16x192x128_w4"
74
+ if n % 128 == 0:
75
+ return "ld_fp8_gemm_16x128x128_w4"
76
+ return "ld_fp8_gemm_16x64x128_w4"
77
+ if m <= 32:
78
+ if k % 256 == 0:
79
+ return "ld_fp8_gemm_32x128x256_w4" if n % 128 == 0 else "ld_fp8_gemm_32x64x256_w4"
80
+ if n % 192 == 0:
81
+ return "ld_fp8_gemm_32x192x128_w4"
82
+ if n % 128 == 0:
83
+ return "ld_fp8_gemm_32x128x128_w4"
84
+ return "ld_fp8_gemm_32x64x128_w4"
85
+ if m <= 64:
86
+ if k % 256 == 0:
87
+ return "ld_fp8_gemm_64x128x256_w4" if n % 128 == 0 else "ld_fp8_gemm_64x64x256_w4"
88
+ if n % 128 == 0:
89
+ return "ld_fp8_gemm_64x128x128_w4"
90
+ return "ld_fp8_gemm_64x64x128_w4"
91
+ raise RuntimeError("only M=1 decode or 2 <= M <= 64 small-M rows are supported")
92
+
93
+
94
+ def fp8_linear_bf16(
95
+ input: torch.Tensor,
96
+ weight: torch.Tensor,
97
+ alpha: float = 1.0,
98
+ out: torch.Tensor | None = None,
99
+ variant: int = 0,
100
+ ) -> torch.Tensor:
101
+ """Compute ``(input @ weight.T) * alpha`` with BF16 output.
102
+
103
+ ``input`` and ``weight`` must be FP8 E4M3 CUDA tensors with shapes
104
+ ``(M, K)`` and ``(N, K)``. ``alpha`` is a host float, normally the product
105
+ of static per-tensor input and weight scales.
106
+ """
107
+
108
+ if out is None:
109
+ out = torch.empty(
110
+ (input.shape[0], weight.shape[0]),
111
+ device=input.device,
112
+ dtype=torch.bfloat16,
113
+ )
114
+ ops.fp8_linear_bf16(input, weight, float(alpha), int(variant), out)
115
+ return out
116
+
117
+
118
+ def fp8_linear_residual_bf16(
119
+ input: torch.Tensor,
120
+ weight: torch.Tensor,
121
+ residual: torch.Tensor,
122
+ alpha: float = 1.0,
123
+ variant: int = 0,
124
+ ) -> torch.Tensor:
125
+ """In-place ``residual += (input @ weight.T) * alpha`` for M=1 decode."""
126
+
127
+ ops.fp8_linear_residual_bf16(input, weight, float(alpha), int(variant), residual)
128
+ return residual
build/torch211-cxx11-cu128-x86_64-linux/_fp8_gemm_cuda_9407aee.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:132b5c476156393d09a3c0acac7b25ff7daeacd35d8e7fdbdbe5675a20142c5d
3
+ size 2458144
build/torch211-cxx11-cu128-x86_64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _fp8_gemm_cuda_9407aee
3
+ ops = torch.ops._fp8_gemm_cuda_9407aee
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_fp8_gemm_cuda_9407aee::{op_name}"
build/torch211-cxx11-cu128-x86_64-linux/fp8_gemm/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import importlib.util
3
+ import sys
4
+ from pathlib import Path
5
+ from types import ModuleType
6
+
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch211-cxx11-cu128-x86_64-linux/metadata.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "fp8-gemm",
3
+ "id": "_fp8_gemm_cuda_9407aee",
4
+ "version": 1,
5
+ "license": "Apache-2.0",
6
+ "python-depends": [],
7
+ "backend": {
8
+ "type": "cuda",
9
+ "archs": [
10
+ "12.0a"
11
+ ]
12
+ },
13
+ "digest": {
14
+ "algorithm": "sha256",
15
+ "files": {
16
+ "__init__.py": "Bm2+gGxw1Jrges8cKNwvxFr7dcD5K9hQbYHhO+S60ns=",
17
+ "_fp8_gemm_cuda_9407aee.abi3.so": "EytcR2FWOT0Jo8CsrHsl/32urNNdjn/b2+VnWiAULF0=",
18
+ "_ops.py": "GSkYb8wEgANAFGWUVgH9d5mYNlFxprVIHBdLjxNrwE0=",
19
+ "fp8_gemm/__init__.py": "DFYPlrhXwYjEqCl/8n0SmWGZV8NFml5DPhMjKfv98GY="
20
+ }
21
+ }
22
+ }
build/torch211-cxx11-cu130-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FlashRT FP8 GEMM kernels."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import torch
6
+
7
+ from ._ops import add_op_namespace_prefix, ops
8
+
9
+
10
+ @torch.library.register_fake(add_op_namespace_prefix("fp8_linear_bf16"))
11
+ def _fp8_linear_bf16_fake(
12
+ input: torch.Tensor,
13
+ weight: torch.Tensor,
14
+ alpha: float,
15
+ variant: int,
16
+ out: torch.Tensor,
17
+ ) -> None:
18
+ if input.dim() != 2 or weight.dim() != 2:
19
+ raise RuntimeError("input and weight must be rank-2 tensors")
20
+ if out.shape != (input.shape[0], weight.shape[0]):
21
+ raise RuntimeError("out must have shape (input.shape[0], weight.shape[0])")
22
+ return None
23
+
24
+
25
+ @torch.library.register_fake(add_op_namespace_prefix("fp8_linear_residual_bf16"))
26
+ def _fp8_linear_residual_bf16_fake(
27
+ input: torch.Tensor,
28
+ weight: torch.Tensor,
29
+ alpha: float,
30
+ variant: int,
31
+ residual: torch.Tensor,
32
+ ) -> None:
33
+ if input.shape[0] != 1:
34
+ raise RuntimeError("residual path supports only M=1")
35
+ if residual.shape != (1, weight.shape[0]):
36
+ raise RuntimeError("residual must have shape (1, weight.shape[0])")
37
+ return None
38
+
39
+
40
+ def select_fp8_linear_tile(m: int, n: int, k: int, variant: int = 0) -> str:
41
+ """Return the FlashRT tile selected by the public dispatcher."""
42
+
43
+ m = int(m)
44
+ n = int(n)
45
+ k = int(k)
46
+ variant = int(variant)
47
+ if m <= 0 or n <= 0 or k <= 0:
48
+ raise RuntimeError("m, n, and k must be positive")
49
+ if k % 32 != 0:
50
+ raise RuntimeError("k must be divisible by 32")
51
+ if m == 1:
52
+ if variant == 4:
53
+ return "gemv_fp8_m1_w4"
54
+ if variant == 8:
55
+ return "gemv_fp8_m1_w8"
56
+ if variant == 16:
57
+ return "gemv_fp8_m1_w16"
58
+ if variant != 0:
59
+ raise RuntimeError("M=1 variant must be 0, 4, 8, or 16")
60
+ if n <= 2048:
61
+ return "gemv_fp8_m1_w4"
62
+ if n <= 8192:
63
+ return "gemv_fp8_m1_w8"
64
+ return "gemv_fp8_m1_w16"
65
+ if variant != 0:
66
+ raise RuntimeError("small-M dispatcher currently supports variant=0 only")
67
+ if m <= 16:
68
+ if k % 256 == 0:
69
+ return "ld_fp8_gemm_16x128x256_w4" if n % 128 == 0 else "ld_fp8_gemm_16x64x256_w4"
70
+ if n % 256 == 0:
71
+ return "ld_fp8_gemm_16x256x128_w8"
72
+ if n % 192 == 0:
73
+ return "ld_fp8_gemm_16x192x128_w4"
74
+ if n % 128 == 0:
75
+ return "ld_fp8_gemm_16x128x128_w4"
76
+ return "ld_fp8_gemm_16x64x128_w4"
77
+ if m <= 32:
78
+ if k % 256 == 0:
79
+ return "ld_fp8_gemm_32x128x256_w4" if n % 128 == 0 else "ld_fp8_gemm_32x64x256_w4"
80
+ if n % 192 == 0:
81
+ return "ld_fp8_gemm_32x192x128_w4"
82
+ if n % 128 == 0:
83
+ return "ld_fp8_gemm_32x128x128_w4"
84
+ return "ld_fp8_gemm_32x64x128_w4"
85
+ if m <= 64:
86
+ if k % 256 == 0:
87
+ return "ld_fp8_gemm_64x128x256_w4" if n % 128 == 0 else "ld_fp8_gemm_64x64x256_w4"
88
+ if n % 128 == 0:
89
+ return "ld_fp8_gemm_64x128x128_w4"
90
+ return "ld_fp8_gemm_64x64x128_w4"
91
+ raise RuntimeError("only M=1 decode or 2 <= M <= 64 small-M rows are supported")
92
+
93
+
94
+ def fp8_linear_bf16(
95
+ input: torch.Tensor,
96
+ weight: torch.Tensor,
97
+ alpha: float = 1.0,
98
+ out: torch.Tensor | None = None,
99
+ variant: int = 0,
100
+ ) -> torch.Tensor:
101
+ """Compute ``(input @ weight.T) * alpha`` with BF16 output.
102
+
103
+ ``input`` and ``weight`` must be FP8 E4M3 CUDA tensors with shapes
104
+ ``(M, K)`` and ``(N, K)``. ``alpha`` is a host float, normally the product
105
+ of static per-tensor input and weight scales.
106
+ """
107
+
108
+ if out is None:
109
+ out = torch.empty(
110
+ (input.shape[0], weight.shape[0]),
111
+ device=input.device,
112
+ dtype=torch.bfloat16,
113
+ )
114
+ ops.fp8_linear_bf16(input, weight, float(alpha), int(variant), out)
115
+ return out
116
+
117
+
118
+ def fp8_linear_residual_bf16(
119
+ input: torch.Tensor,
120
+ weight: torch.Tensor,
121
+ residual: torch.Tensor,
122
+ alpha: float = 1.0,
123
+ variant: int = 0,
124
+ ) -> torch.Tensor:
125
+ """In-place ``residual += (input @ weight.T) * alpha`` for M=1 decode."""
126
+
127
+ ops.fp8_linear_residual_bf16(input, weight, float(alpha), int(variant), residual)
128
+ return residual
build/torch211-cxx11-cu130-x86_64-linux/_fp8_gemm_cuda_9407aee.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8727dca1fbdabfb2f0b5729c640f2f9a77ab30a47db2b2270eb1f6b21564ba73
3
+ size 2599784
build/torch211-cxx11-cu130-x86_64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _fp8_gemm_cuda_9407aee
3
+ ops = torch.ops._fp8_gemm_cuda_9407aee
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_fp8_gemm_cuda_9407aee::{op_name}"
build/torch211-cxx11-cu130-x86_64-linux/fp8_gemm/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import importlib.util
3
+ import sys
4
+ from pathlib import Path
5
+ from types import ModuleType
6
+
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch211-cxx11-cu130-x86_64-linux/metadata.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "fp8-gemm",
3
+ "id": "_fp8_gemm_cuda_9407aee",
4
+ "version": 1,
5
+ "license": "Apache-2.0",
6
+ "python-depends": [],
7
+ "backend": {
8
+ "type": "cuda",
9
+ "archs": [
10
+ "12.0a"
11
+ ]
12
+ },
13
+ "digest": {
14
+ "algorithm": "sha256",
15
+ "files": {
16
+ "__init__.py": "Bm2+gGxw1Jrges8cKNwvxFr7dcD5K9hQbYHhO+S60ns=",
17
+ "_fp8_gemm_cuda_9407aee.abi3.so": "hyfcofvav7LwtXKcZA8vmnerMKR9srInDrH2shVkunM=",
18
+ "_ops.py": "GSkYb8wEgANAFGWUVgH9d5mYNlFxprVIHBdLjxNrwE0=",
19
+ "fp8_gemm/__init__.py": "DFYPlrhXwYjEqCl/8n0SmWGZV8NFml5DPhMjKfv98GY="
20
+ }
21
+ }
22
+ }
build/torch212-cxx11-cu130-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FlashRT FP8 GEMM kernels."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import torch
6
+
7
+ from ._ops import add_op_namespace_prefix, ops
8
+
9
+
10
+ @torch.library.register_fake(add_op_namespace_prefix("fp8_linear_bf16"))
11
+ def _fp8_linear_bf16_fake(
12
+ input: torch.Tensor,
13
+ weight: torch.Tensor,
14
+ alpha: float,
15
+ variant: int,
16
+ out: torch.Tensor,
17
+ ) -> None:
18
+ if input.dim() != 2 or weight.dim() != 2:
19
+ raise RuntimeError("input and weight must be rank-2 tensors")
20
+ if out.shape != (input.shape[0], weight.shape[0]):
21
+ raise RuntimeError("out must have shape (input.shape[0], weight.shape[0])")
22
+ return None
23
+
24
+
25
+ @torch.library.register_fake(add_op_namespace_prefix("fp8_linear_residual_bf16"))
26
+ def _fp8_linear_residual_bf16_fake(
27
+ input: torch.Tensor,
28
+ weight: torch.Tensor,
29
+ alpha: float,
30
+ variant: int,
31
+ residual: torch.Tensor,
32
+ ) -> None:
33
+ if input.shape[0] != 1:
34
+ raise RuntimeError("residual path supports only M=1")
35
+ if residual.shape != (1, weight.shape[0]):
36
+ raise RuntimeError("residual must have shape (1, weight.shape[0])")
37
+ return None
38
+
39
+
40
+ def select_fp8_linear_tile(m: int, n: int, k: int, variant: int = 0) -> str:
41
+ """Return the FlashRT tile selected by the public dispatcher."""
42
+
43
+ m = int(m)
44
+ n = int(n)
45
+ k = int(k)
46
+ variant = int(variant)
47
+ if m <= 0 or n <= 0 or k <= 0:
48
+ raise RuntimeError("m, n, and k must be positive")
49
+ if k % 32 != 0:
50
+ raise RuntimeError("k must be divisible by 32")
51
+ if m == 1:
52
+ if variant == 4:
53
+ return "gemv_fp8_m1_w4"
54
+ if variant == 8:
55
+ return "gemv_fp8_m1_w8"
56
+ if variant == 16:
57
+ return "gemv_fp8_m1_w16"
58
+ if variant != 0:
59
+ raise RuntimeError("M=1 variant must be 0, 4, 8, or 16")
60
+ if n <= 2048:
61
+ return "gemv_fp8_m1_w4"
62
+ if n <= 8192:
63
+ return "gemv_fp8_m1_w8"
64
+ return "gemv_fp8_m1_w16"
65
+ if variant != 0:
66
+ raise RuntimeError("small-M dispatcher currently supports variant=0 only")
67
+ if m <= 16:
68
+ if k % 256 == 0:
69
+ return "ld_fp8_gemm_16x128x256_w4" if n % 128 == 0 else "ld_fp8_gemm_16x64x256_w4"
70
+ if n % 256 == 0:
71
+ return "ld_fp8_gemm_16x256x128_w8"
72
+ if n % 192 == 0:
73
+ return "ld_fp8_gemm_16x192x128_w4"
74
+ if n % 128 == 0:
75
+ return "ld_fp8_gemm_16x128x128_w4"
76
+ return "ld_fp8_gemm_16x64x128_w4"
77
+ if m <= 32:
78
+ if k % 256 == 0:
79
+ return "ld_fp8_gemm_32x128x256_w4" if n % 128 == 0 else "ld_fp8_gemm_32x64x256_w4"
80
+ if n % 192 == 0:
81
+ return "ld_fp8_gemm_32x192x128_w4"
82
+ if n % 128 == 0:
83
+ return "ld_fp8_gemm_32x128x128_w4"
84
+ return "ld_fp8_gemm_32x64x128_w4"
85
+ if m <= 64:
86
+ if k % 256 == 0:
87
+ return "ld_fp8_gemm_64x128x256_w4" if n % 128 == 0 else "ld_fp8_gemm_64x64x256_w4"
88
+ if n % 128 == 0:
89
+ return "ld_fp8_gemm_64x128x128_w4"
90
+ return "ld_fp8_gemm_64x64x128_w4"
91
+ raise RuntimeError("only M=1 decode or 2 <= M <= 64 small-M rows are supported")
92
+
93
+
94
+ def fp8_linear_bf16(
95
+ input: torch.Tensor,
96
+ weight: torch.Tensor,
97
+ alpha: float = 1.0,
98
+ out: torch.Tensor | None = None,
99
+ variant: int = 0,
100
+ ) -> torch.Tensor:
101
+ """Compute ``(input @ weight.T) * alpha`` with BF16 output.
102
+
103
+ ``input`` and ``weight`` must be FP8 E4M3 CUDA tensors with shapes
104
+ ``(M, K)`` and ``(N, K)``. ``alpha`` is a host float, normally the product
105
+ of static per-tensor input and weight scales.
106
+ """
107
+
108
+ if out is None:
109
+ out = torch.empty(
110
+ (input.shape[0], weight.shape[0]),
111
+ device=input.device,
112
+ dtype=torch.bfloat16,
113
+ )
114
+ ops.fp8_linear_bf16(input, weight, float(alpha), int(variant), out)
115
+ return out
116
+
117
+
118
+ def fp8_linear_residual_bf16(
119
+ input: torch.Tensor,
120
+ weight: torch.Tensor,
121
+ residual: torch.Tensor,
122
+ alpha: float = 1.0,
123
+ variant: int = 0,
124
+ ) -> torch.Tensor:
125
+ """In-place ``residual += (input @ weight.T) * alpha`` for M=1 decode."""
126
+
127
+ ops.fp8_linear_residual_bf16(input, weight, float(alpha), int(variant), residual)
128
+ return residual
build/torch212-cxx11-cu130-x86_64-linux/_fp8_gemm_cuda_9407aee.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:539f23d7b7af9033ba893cb10ff99d6ad5a8f6cb00c98af42962d409ad826778
3
+ size 2610600
build/torch212-cxx11-cu130-x86_64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _fp8_gemm_cuda_9407aee
3
+ ops = torch.ops._fp8_gemm_cuda_9407aee
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_fp8_gemm_cuda_9407aee::{op_name}"
build/torch212-cxx11-cu130-x86_64-linux/fp8_gemm/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import importlib.util
3
+ import sys
4
+ from pathlib import Path
5
+ from types import ModuleType
6
+
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch212-cxx11-cu130-x86_64-linux/metadata.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "fp8-gemm",
3
+ "id": "_fp8_gemm_cuda_9407aee",
4
+ "version": 1,
5
+ "license": "Apache-2.0",
6
+ "python-depends": [],
7
+ "backend": {
8
+ "type": "cuda",
9
+ "archs": [
10
+ "12.0a"
11
+ ]
12
+ },
13
+ "digest": {
14
+ "algorithm": "sha256",
15
+ "files": {
16
+ "__init__.py": "Bm2+gGxw1Jrges8cKNwvxFr7dcD5K9hQbYHhO+S60ns=",
17
+ "_fp8_gemm_cuda_9407aee.abi3.so": "U58j17evkDO6iTyxD/mdatWo9ssAyYr0KWLUCa2CZ3g=",
18
+ "_ops.py": "GSkYb8wEgANAFGWUVgH9d5mYNlFxprVIHBdLjxNrwE0=",
19
+ "fp8_gemm/__init__.py": "DFYPlrhXwYjEqCl/8n0SmWGZV8NFml5DPhMjKfv98GY="
20
+ }
21
+ }
22
+ }
build/torch212-cxx11-cu132-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FlashRT FP8 GEMM kernels."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import torch
6
+
7
+ from ._ops import add_op_namespace_prefix, ops
8
+
9
+
10
+ @torch.library.register_fake(add_op_namespace_prefix("fp8_linear_bf16"))
11
+ def _fp8_linear_bf16_fake(
12
+ input: torch.Tensor,
13
+ weight: torch.Tensor,
14
+ alpha: float,
15
+ variant: int,
16
+ out: torch.Tensor,
17
+ ) -> None:
18
+ if input.dim() != 2 or weight.dim() != 2:
19
+ raise RuntimeError("input and weight must be rank-2 tensors")
20
+ if out.shape != (input.shape[0], weight.shape[0]):
21
+ raise RuntimeError("out must have shape (input.shape[0], weight.shape[0])")
22
+ return None
23
+
24
+
25
+ @torch.library.register_fake(add_op_namespace_prefix("fp8_linear_residual_bf16"))
26
+ def _fp8_linear_residual_bf16_fake(
27
+ input: torch.Tensor,
28
+ weight: torch.Tensor,
29
+ alpha: float,
30
+ variant: int,
31
+ residual: torch.Tensor,
32
+ ) -> None:
33
+ if input.shape[0] != 1:
34
+ raise RuntimeError("residual path supports only M=1")
35
+ if residual.shape != (1, weight.shape[0]):
36
+ raise RuntimeError("residual must have shape (1, weight.shape[0])")
37
+ return None
38
+
39
+
40
+ def select_fp8_linear_tile(m: int, n: int, k: int, variant: int = 0) -> str:
41
+ """Return the FlashRT tile selected by the public dispatcher."""
42
+
43
+ m = int(m)
44
+ n = int(n)
45
+ k = int(k)
46
+ variant = int(variant)
47
+ if m <= 0 or n <= 0 or k <= 0:
48
+ raise RuntimeError("m, n, and k must be positive")
49
+ if k % 32 != 0:
50
+ raise RuntimeError("k must be divisible by 32")
51
+ if m == 1:
52
+ if variant == 4:
53
+ return "gemv_fp8_m1_w4"
54
+ if variant == 8:
55
+ return "gemv_fp8_m1_w8"
56
+ if variant == 16:
57
+ return "gemv_fp8_m1_w16"
58
+ if variant != 0:
59
+ raise RuntimeError("M=1 variant must be 0, 4, 8, or 16")
60
+ if n <= 2048:
61
+ return "gemv_fp8_m1_w4"
62
+ if n <= 8192:
63
+ return "gemv_fp8_m1_w8"
64
+ return "gemv_fp8_m1_w16"
65
+ if variant != 0:
66
+ raise RuntimeError("small-M dispatcher currently supports variant=0 only")
67
+ if m <= 16:
68
+ if k % 256 == 0:
69
+ return "ld_fp8_gemm_16x128x256_w4" if n % 128 == 0 else "ld_fp8_gemm_16x64x256_w4"
70
+ if n % 256 == 0:
71
+ return "ld_fp8_gemm_16x256x128_w8"
72
+ if n % 192 == 0:
73
+ return "ld_fp8_gemm_16x192x128_w4"
74
+ if n % 128 == 0:
75
+ return "ld_fp8_gemm_16x128x128_w4"
76
+ return "ld_fp8_gemm_16x64x128_w4"
77
+ if m <= 32:
78
+ if k % 256 == 0:
79
+ return "ld_fp8_gemm_32x128x256_w4" if n % 128 == 0 else "ld_fp8_gemm_32x64x256_w4"
80
+ if n % 192 == 0:
81
+ return "ld_fp8_gemm_32x192x128_w4"
82
+ if n % 128 == 0:
83
+ return "ld_fp8_gemm_32x128x128_w4"
84
+ return "ld_fp8_gemm_32x64x128_w4"
85
+ if m <= 64:
86
+ if k % 256 == 0:
87
+ return "ld_fp8_gemm_64x128x256_w4" if n % 128 == 0 else "ld_fp8_gemm_64x64x256_w4"
88
+ if n % 128 == 0:
89
+ return "ld_fp8_gemm_64x128x128_w4"
90
+ return "ld_fp8_gemm_64x64x128_w4"
91
+ raise RuntimeError("only M=1 decode or 2 <= M <= 64 small-M rows are supported")
92
+
93
+
94
+ def fp8_linear_bf16(
95
+ input: torch.Tensor,
96
+ weight: torch.Tensor,
97
+ alpha: float = 1.0,
98
+ out: torch.Tensor | None = None,
99
+ variant: int = 0,
100
+ ) -> torch.Tensor:
101
+ """Compute ``(input @ weight.T) * alpha`` with BF16 output.
102
+
103
+ ``input`` and ``weight`` must be FP8 E4M3 CUDA tensors with shapes
104
+ ``(M, K)`` and ``(N, K)``. ``alpha`` is a host float, normally the product
105
+ of static per-tensor input and weight scales.
106
+ """
107
+
108
+ if out is None:
109
+ out = torch.empty(
110
+ (input.shape[0], weight.shape[0]),
111
+ device=input.device,
112
+ dtype=torch.bfloat16,
113
+ )
114
+ ops.fp8_linear_bf16(input, weight, float(alpha), int(variant), out)
115
+ return out
116
+
117
+
118
+ def fp8_linear_residual_bf16(
119
+ input: torch.Tensor,
120
+ weight: torch.Tensor,
121
+ residual: torch.Tensor,
122
+ alpha: float = 1.0,
123
+ variant: int = 0,
124
+ ) -> torch.Tensor:
125
+ """In-place ``residual += (input @ weight.T) * alpha`` for M=1 decode."""
126
+
127
+ ops.fp8_linear_residual_bf16(input, weight, float(alpha), int(variant), residual)
128
+ return residual
build/torch212-cxx11-cu132-x86_64-linux/_fp8_gemm_cuda_9407aee.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e570c5e4e56bb65cb3164500b916a3c3dad695a1d08844404aa3219f8d6dc6a5
3
+ size 2622888
build/torch212-cxx11-cu132-x86_64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _fp8_gemm_cuda_9407aee
3
+ ops = torch.ops._fp8_gemm_cuda_9407aee
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_fp8_gemm_cuda_9407aee::{op_name}"
build/torch212-cxx11-cu132-x86_64-linux/fp8_gemm/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import importlib.util
3
+ import sys
4
+ from pathlib import Path
5
+ from types import ModuleType
6
+
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch212-cxx11-cu132-x86_64-linux/metadata.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "fp8-gemm",
3
+ "id": "_fp8_gemm_cuda_9407aee",
4
+ "version": 1,
5
+ "license": "Apache-2.0",
6
+ "python-depends": [],
7
+ "backend": {
8
+ "type": "cuda",
9
+ "archs": [
10
+ "12.0a"
11
+ ]
12
+ },
13
+ "digest": {
14
+ "algorithm": "sha256",
15
+ "files": {
16
+ "__init__.py": "Bm2+gGxw1Jrges8cKNwvxFr7dcD5K9hQbYHhO+S60ns=",
17
+ "_fp8_gemm_cuda_9407aee.abi3.so": "5XDF5OVrtlyzFkUAuRajw9rWlaHQiERASqMhn41txqU=",
18
+ "_ops.py": "GSkYb8wEgANAFGWUVgH9d5mYNlFxprVIHBdLjxNrwE0=",
19
+ "fp8_gemm/__init__.py": "DFYPlrhXwYjEqCl/8n0SmWGZV8NFml5DPhMjKfv98GY="
20
+ }
21
+ }
22
+ }