icarus112 commited on
Commit
c475135
·
verified ·
1 Parent(s): a0ab607

Update Feather a10g-large training runtime image

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .dockerignore +16 -0
  2. .guardian_trigger_20260512_211050 +1 -0
  3. .rebuild_sentry +1 -0
  4. FORCE_REBUILD +3 -0
  5. README.md +6 -5
  6. REBUILD_FLAG_1778645488 +0 -0
  7. entrypoint.py +1 -20
  8. overlay/.dockerignore +20 -0
  9. overlay/BUILD_STAMP +1 -0
  10. overlay/harness/benchmark_validity.py +210 -0
  11. overlay/harness/tps_manifest_validity.py +209 -0
  12. overlay/htm_rust/.cargo/config.toml +2 -0
  13. overlay/htm_rust/.claude/CLAUDE.md +0 -0
  14. overlay/htm_rust/.letta/claude/conversations.json +6 -0
  15. overlay/htm_rust/.letta/claude/session-c892b9c9-7fe5-4f14-8157-ec8740e965d1.json +0 -0
  16. overlay/htm_rust/Cargo.lock +42 -0
  17. overlay/htm_rust/Cargo.toml +3 -1
  18. overlay/htm_rust/DLB_PERKS_IMPLEMENTATION_PLAN.md +194 -0
  19. overlay/htm_rust/bench_gpu.py +81 -0
  20. overlay/htm_rust/docs/GPU_HTM.md +302 -0
  21. overlay/htm_rust/src/gpu/fused.rs +58 -10
  22. overlay/htm_rust/src/gpu/mod.rs +134 -1
  23. overlay/htm_rust/src/lib.rs +27 -0
  24. overlay/htm_rust/src/region.rs +2 -0
  25. overlay/htm_rust/src/sp.rs +5 -1
  26. overlay/htm_rust/src/tm.rs +6 -2
  27. overlay/htm_rust/uv.lock +8 -0
  28. overlay/hydra/model.py +96 -8
  29. overlay/hydra/optimizer.py +118 -44
  30. overlay/hydra/training.py +66 -25
  31. overlay/kernels/__init__.py +0 -0
  32. overlay/kernels/cuda/decode_kernels.cu +10 -0
  33. overlay/kernels/cuda/flashfftconv/LICENSE +201 -0
  34. overlay/kernels/cuda/flashfftconv/README.md +57 -0
  35. overlay/kernels/cuda/flashfftconv/UPSTREAM_COMMIT +1 -0
  36. overlay/kernels/cuda/flashfftconv/csrc/.gitignore +10 -0
  37. overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly.h +374 -0
  38. overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_cuda.cu +699 -0
  39. overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_cuda_bf16.cu +725 -0
  40. overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_ifft_cuda.cu +723 -0
  41. overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_ifft_cuda_bf16.cu +705 -0
  42. overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_padded_cuda.cu +871 -0
  43. overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_padded_cuda_bf16.cu +897 -0
  44. overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_padded_ifft_cuda.cu +905 -0
  45. overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_padded_ifft_cuda_bf16.cu +917 -0
  46. overlay/kernels/cuda/flashfftconv/csrc/butterfly/shared.h +60 -0
  47. overlay/kernels/cuda/flashfftconv/csrc/conv1d/conv1d.h +96 -0
  48. overlay/kernels/cuda/flashfftconv/csrc/conv1d/conv1d_bhl.cu +132 -0
  49. overlay/kernels/cuda/flashfftconv/csrc/conv1d/conv1d_blh.cu +202 -0
  50. overlay/kernels/cuda/flashfftconv/csrc/conv1d/conv1d_bwd_cuda_bhl.cu +106 -0
.dockerignore ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Keep HF runtime image context deterministic and small.
2
+ **/__pycache__/
3
+ **/*.py[cod]
4
+ **/.pytest_cache/
5
+ **/.mypy_cache/
6
+ **/.ruff_cache/
7
+ **/.venv/
8
+ **/target/
9
+ **/logs/
10
+ **/*.log
11
+ **/*.out
12
+ **/*.pt
13
+ **/*.safetensors
14
+ **/*.parquet
15
+ **/*.npz
16
+ **/.git/
.guardian_trigger_20260512_211050 ADDED
@@ -0,0 +1 @@
 
 
1
+ Guardian forced rebuild at 2026-05-12T21:10:50.366196
.rebuild_sentry ADDED
@@ -0,0 +1 @@
 
 
1
+ FORCE_REBUILD_e9883655-cf86-4724-84bd-68740a3feefb
FORCE_REBUILD ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ FORCE_SPACE_REBUILD=$(date -u +%s)
2
+ # This flag forces the Space image to rebuild with the latest overlay code
3
+ # containing the retina_contrastive fix
README.md CHANGED
@@ -1,10 +1,11 @@
1
  ---
2
- title: Feather A10g Large Runtime
3
- emoji: 🌍
4
- colorFrom: pink
5
- colorTo: pink
6
  sdk: docker
 
7
  pinned: false
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Feather H200 Runtime Slim
3
+ emoji: 📚
4
+ colorFrom: blue
5
+ colorTo: indigo
6
  sdk: docker
7
+ app_port: 7860
8
  pinned: false
9
  ---
10
 
11
+ Feather runtime image used as a Docker Space source for Hugging Face Jobs.
REBUILD_FLAG_1778645488 ADDED
File without changes
entrypoint.py CHANGED
@@ -217,25 +217,6 @@ def _run_training_subprocess(cmd: list[str]) -> int:
217
  def run_job_mode() -> int:
218
  os.chdir(REPO_ROOT)
219
 
220
- # Guardian: force contrastive_rank=0 and disk-patch sdr_semantic.py
221
- os.environ["HYDRA_CONTRASTIVE_RANK"] = "0"
222
- _sdr_path = REPO_ROOT / 'subsystems' / 'sdr_semantic.py'
223
- if _sdr_path.exists():
224
- _text = _sdr_path.read_text()
225
- if 'retina_contrastive' not in _text:
226
- print('[guardian] patching sdr_semantic.py on disk ...', flush=True)
227
- _text = _text.replace(
228
- 'super().__init__()\n' +
229
- ' # Audit 2026-05-13: allow disabling',
230
- 'super().__init__()\n' +
231
- ' self.retina_contrastive = None # guardian patch\n' +
232
- ' # Audit 2026-05-13: allow disabling',
233
- )
234
- _sdr_path.write_text(_text)
235
- print('[guardian] patched sdr_semantic.py on disk', flush=True)
236
- print('[guardian] HYDRA_CONTRASTIVE_RANK=0 enforced for checkpoint compat', flush=True)
237
-
238
-
239
  # Dynamic live patch from GitHub to bypass Space build errors
240
  GIT_REF = os.environ.get('FEATHER_GIT_REF')
241
  if GIT_REF:
@@ -307,4 +288,4 @@ def main() -> int:
307
 
308
 
309
  if __name__ == '__main__':
310
- raise SystemExit(main())
 
217
  def run_job_mode() -> int:
218
  os.chdir(REPO_ROOT)
219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  # Dynamic live patch from GitHub to bypass Space build errors
221
  GIT_REF = os.environ.get('FEATHER_GIT_REF')
222
  if GIT_REF:
 
288
 
289
 
290
  if __name__ == '__main__':
291
+ raise SystemExit(main())
overlay/.dockerignore ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .git
2
+ .github
3
+ .venv
4
+ .remember
5
+ .letta
6
+ .claude
7
+ __pycache__
8
+ *.pyc
9
+ *.pyo
10
+ *.pyd
11
+ *.log
12
+ run_*.log
13
+ run*.log
14
+ *.txt
15
+ WORKER_COMPLETE
16
+ autoresearch_loop.log
17
+ data/
18
+ state_store/
19
+ htm_rust/target/
20
+ hydra-core/target/
overlay/BUILD_STAMP ADDED
@@ -0,0 +1 @@
 
 
1
+ 1778646814_120314
overlay/harness/benchmark_validity.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Benchmark validity and comparable-group helpers for HYDRA scorecards.
2
+
3
+ This module deliberately separates benchmark validity from model quality. A run
4
+ can be useful diagnostic evidence while still being invalid for promotion if its
5
+ corpus or eval protocol differs from the baseline.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import hashlib
11
+ import json
12
+ from copy import deepcopy
13
+ from typing import Any
14
+
15
+ PUBLIC_FULL_BLEND_ID = "public_full_blend_v0"
16
+ PUBLIC_FULL_BLEND_WEIGHTS = {
17
+ "fineweb-edu": 0.55,
18
+ "wikipedia": 0.25,
19
+ "cosmopedia": 0.15,
20
+ "fineweb": 0.05,
21
+ }
22
+ GATED_OR_PRIVATE_MARKERS = (
23
+ "stack-v2",
24
+ "nemotron-math",
25
+ "nemotron-specialized",
26
+ "nvidia/nemotron",
27
+ "Nemotron-CC-Math",
28
+ "Nemotron-Pretraining-Specialized",
29
+ )
30
+
31
+
32
+ def _text_blob(row: dict[str, Any]) -> str:
33
+ return json.dumps(row, sort_keys=True, default=str)
34
+
35
+
36
+ def _ablation(row: dict[str, Any]) -> dict[str, Any]:
37
+ ablation = row.get("ablation")
38
+ return ablation if isinstance(ablation, dict) else {}
39
+
40
+
41
+ def _has_public_full_blend(row: dict[str, Any]) -> bool:
42
+ ablation = _ablation(row)
43
+ corpus_profile = str(row.get("corpus_profile") or "").lower()
44
+ corpus_standard = str(ablation.get("corpus_standard") or row.get("corpus_standard") or "").lower()
45
+ notes = str(row.get("notes") or "").lower()
46
+ blend_weights = row.get("full_blend_weights")
47
+ single_config = str(
48
+ ablation.get("HYDRA_NEMOTRON_SINGLE_CONFIG")
49
+ or row.get("HYDRA_NEMOTRON_SINGLE_CONFIG")
50
+ or ""
51
+ ).strip().lower()
52
+
53
+ has_full_blend_marker = (
54
+ row.get("HYDRA_USE_FULL_BLEND") == "1"
55
+ or row.get("HYDRA_USE_FULL_BLEND") == 1
56
+ or row.get("HYDRA_USE_FULL_BLEND") is True
57
+ or "hydra_use_full_blend=1" in corpus_standard
58
+ or corpus_profile == PUBLIC_FULL_BLEND_ID
59
+ or blend_weights == PUBLIC_FULL_BLEND_WEIGHTS
60
+ or "public benchmark blend" in corpus_standard
61
+ or "public full-blend" in notes
62
+ or "full-blend eval settings" in notes
63
+ )
64
+ single_config_is_blank = single_config in {"", "<unset>", "none", "null"}
65
+ return bool(has_full_blend_marker and single_config_is_blank)
66
+
67
+
68
+ def _uses_private_or_gated_corpus(row: dict[str, Any]) -> bool:
69
+ blob = _text_blob(row).lower()
70
+ return any(marker.lower() in blob for marker in GATED_OR_PRIVATE_MARKERS)
71
+
72
+
73
+ def _eval_tokens(row: dict[str, Any]) -> int | None:
74
+ raw = row.get("eval_tokens")
75
+ if raw in (None, ""):
76
+ return None
77
+ try:
78
+ return int(raw)
79
+ except (TypeError, ValueError):
80
+ return None
81
+
82
+
83
+ def _eval_batch(row: dict[str, Any]) -> int | None:
84
+ raw = row.get("eval_batch", 1)
85
+ if raw in (None, ""):
86
+ return None
87
+ try:
88
+ return int(raw)
89
+ except (TypeError, ValueError):
90
+ return None
91
+
92
+
93
+ def _eval_protocol(row: dict[str, Any]) -> str:
94
+ val_source = str(row.get("val_source") or "").lower()
95
+ row_type = str(row.get("type") or "").lower()
96
+ if "fresh_checkpoint_eval" in val_source or "fresh_checkpoint_eval" in row_type:
97
+ return "fresh_checkpoint_eval"
98
+ if "in_process" in val_source or "in_process" in row_type:
99
+ return "in_process_eval"
100
+ return val_source or row_type or "unknown_eval"
101
+
102
+
103
+ def _gpu_flavor(row: dict[str, Any]) -> str:
104
+ return str(row.get("gpu_flavor") or row.get("FEATHER_HF_FLAVOR") or "a10g-large").lower()
105
+
106
+
107
+ def _runtime_profile(row: dict[str, Any]) -> str:
108
+ return str(
109
+ row.get("runtime_profile")
110
+ or row.get("FEATHER_HF_RUNTIME_PROFILE")
111
+ or "a10-compromise-telemetry"
112
+ ).lower()
113
+
114
+
115
+ def benchmark_invalid_reason(row: dict[str, Any]) -> str:
116
+ """Return an empty string when a row is benchmark-valid."""
117
+ if row.get("crashed") is True:
118
+ return "run crashed"
119
+ if row.get("metrics_write_failed") is True and row.get("val_bpb") in (None, 0, 0.0):
120
+ return "metrics missing or failed"
121
+ val_bpb = row.get("val_bpb")
122
+ try:
123
+ if val_bpb is None or float(val_bpb) <= 0:
124
+ return "missing positive val_bpb"
125
+ except (TypeError, ValueError):
126
+ return "missing positive val_bpb"
127
+ if not _has_public_full_blend(row):
128
+ return "not public full blend / full blend invariant missing"
129
+ if _uses_private_or_gated_corpus(row):
130
+ return "uses private/gated corpus marker"
131
+ if _eval_tokens(row) is None:
132
+ return "missing eval_tokens"
133
+ if _eval_batch(row) is None:
134
+ return "missing eval_batch"
135
+ if _eval_protocol(row) != "fresh_checkpoint_eval":
136
+ return "not fresh checkpoint eval"
137
+ return ""
138
+
139
+
140
+ def comparable_group_id(row: dict[str, Any]) -> str:
141
+ """Build a stable comparable-group identifier from protocol fields only.
142
+
143
+ Deliberately excludes checkpoint/model/ablation identities so architecture
144
+ variants can be compared when corpus and eval protocol match.
145
+ """
146
+ parts = {
147
+ "corpus": PUBLIC_FULL_BLEND_ID if _has_public_full_blend(row) else "non_public_or_unknown_corpus",
148
+ "eval_protocol": _eval_protocol(row),
149
+ "eval_tokens": _eval_tokens(row),
150
+ "eval_batch": _eval_batch(row),
151
+ "gpu_flavor": _gpu_flavor(row),
152
+ "runtime_profile": _runtime_profile(row),
153
+ }
154
+ digest = hashlib.sha1(json.dumps(parts, sort_keys=True).encode()).hexdigest()[:10]
155
+ return "cmp_" + digest
156
+
157
+
158
+ def normalize_scorecard_row(row: dict[str, Any]) -> dict[str, Any]:
159
+ """Return a row copy annotated with v0 benchmark validity metadata."""
160
+ normalized = deepcopy(row)
161
+ invalid_reason = benchmark_invalid_reason(normalized)
162
+ normalized["benchmark_valid"] = not invalid_reason
163
+ normalized["benchmark_status"] = "comparable" if not invalid_reason else "diagnostic"
164
+ normalized["invalid_reason"] = invalid_reason
165
+ normalized["corpus_profile"] = PUBLIC_FULL_BLEND_ID if _has_public_full_blend(normalized) else "non_public_or_unknown"
166
+ normalized["full_blend_weights"] = PUBLIC_FULL_BLEND_WEIGHTS if _has_public_full_blend(normalized) else None
167
+ normalized["eval_tokens"] = _eval_tokens(normalized)
168
+ normalized["eval_batch"] = _eval_batch(normalized)
169
+ normalized["eval_protocol"] = _eval_protocol(normalized)
170
+ normalized["gpu_flavor"] = _gpu_flavor(normalized)
171
+ normalized["runtime_profile"] = _runtime_profile(normalized)
172
+ normalized["comparable_group_id"] = comparable_group_id(normalized)
173
+ return normalized
174
+
175
+
176
+ def are_comparable(left: dict[str, Any], right: dict[str, Any]) -> bool:
177
+ left_n = normalize_scorecard_row(left)
178
+ right_n = normalize_scorecard_row(right)
179
+ return bool(
180
+ left_n["benchmark_valid"]
181
+ and right_n["benchmark_valid"]
182
+ and left_n["comparable_group_id"] == right_n["comparable_group_id"]
183
+ )
184
+
185
+
186
+ def compare_candidate(candidate: dict[str, Any], baseline: dict[str, Any]) -> dict[str, Any]:
187
+ """Compare two scorecard rows with validity-first promotion semantics."""
188
+ candidate_n = normalize_scorecard_row(candidate)
189
+ baseline_n = normalize_scorecard_row(baseline)
190
+ if not candidate_n["benchmark_valid"]:
191
+ return {"decision": "invalid_candidate", "reason": candidate_n["invalid_reason"]}
192
+ if not baseline_n["benchmark_valid"]:
193
+ return {"decision": "invalid_baseline", "reason": baseline_n["invalid_reason"]}
194
+ if candidate_n["comparable_group_id"] != baseline_n["comparable_group_id"]:
195
+ return {
196
+ "decision": "not_comparable",
197
+ "reason": (
198
+ "comparable_group_id mismatch: "
199
+ f"candidate={candidate_n['comparable_group_id']} "
200
+ f"baseline={baseline_n['comparable_group_id']}"
201
+ ),
202
+ }
203
+ delta_bpb = float(candidate_n["val_bpb"]) - float(baseline_n["val_bpb"])
204
+ if delta_bpb < 0:
205
+ decision = "promote_candidate"
206
+ elif delta_bpb > 0:
207
+ decision = "keep_baseline"
208
+ else:
209
+ decision = "tie_requires_replication"
210
+ return {"decision": decision, "delta_bpb": delta_bpb, "reason": "same comparable_group_id"}
overlay/harness/tps_manifest_validity.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TPS/profiling manifest validity helpers for Feather kernel-fusion sweeps.
2
+
3
+ This module is the TPS-side sibling of ``harness.benchmark_validity``. It does
4
+ not decide model quality; it decides whether a row is valid evidence for max-TPS
5
+ promotion versus attribution/diagnostic evidence. The rules are intentionally
6
+ conservative because profiling flags and CPU fallbacks can make fast-looking rows
7
+ incomparable or unfaithful.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from copy import deepcopy
13
+ from typing import Any
14
+
15
+
16
+ A10_FLAVORS = {"a10g-small", "a10g-large", "a10g-largex2", "a10g-largex4"}
17
+ PROFILE_TRUE = {"1", "true", "yes", "on"}
18
+ PROFILE_FALSE = {"0", "false", "no", "off", ""}
19
+
20
+
21
+ def _as_bool(value: Any, *, default: bool = False) -> bool:
22
+ if isinstance(value, bool):
23
+ return value
24
+ if value is None:
25
+ return default
26
+ text = str(value).strip().lower()
27
+ if text in PROFILE_TRUE:
28
+ return True
29
+ if text in PROFILE_FALSE:
30
+ return False
31
+ return default
32
+
33
+
34
+ def _int_or_none(value: Any) -> int | None:
35
+ if value in (None, ""):
36
+ return None
37
+ try:
38
+ return int(value)
39
+ except (TypeError, ValueError):
40
+ return None
41
+
42
+
43
+ def _float_or_none(value: Any) -> float | None:
44
+ if value in (None, ""):
45
+ return None
46
+ try:
47
+ return float(value)
48
+ except (TypeError, ValueError):
49
+ return None
50
+
51
+
52
+ def _nested(row: dict[str, Any], key: str) -> dict[str, Any]:
53
+ value = row.get(key)
54
+ return value if isinstance(value, dict) else {}
55
+
56
+
57
+ def _env(row: dict[str, Any]) -> dict[str, Any]:
58
+ return _nested(row, "env")
59
+
60
+
61
+ def _receipts(row: dict[str, Any]) -> dict[str, Any]:
62
+ return _nested(row, "receipts") or _nested(row, "receipts_required")
63
+
64
+
65
+ def _hardware(row: dict[str, Any]) -> dict[str, Any]:
66
+ return _nested(row, "hardware")
67
+
68
+
69
+ def _profile_forward_enabled(row: dict[str, Any]) -> bool:
70
+ env = _env(row)
71
+ receipts = _receipts(row)
72
+ if "profile_forward" in receipts:
73
+ return _as_bool(receipts.get("profile_forward"))
74
+ return _as_bool(env.get("HYDRA_PROFILE_FORWARD"))
75
+
76
+
77
+ def _tps_window(row: dict[str, Any]) -> dict[str, Any]:
78
+ receipts = _receipts(row)
79
+ window = receipts.get("training_tps_window") or row.get("training_tps_window") or row.get("tps_window")
80
+ return window if isinstance(window, dict) else {}
81
+
82
+
83
+ def _median_tps(row: dict[str, Any]) -> float | None:
84
+ window = _tps_window(row)
85
+ return _float_or_none(window.get("median") or row.get("median_tps") or row.get("tps"))
86
+
87
+
88
+ def _flavor(row: dict[str, Any]) -> str:
89
+ hardware = _hardware(row)
90
+ receipts = _receipts(row)
91
+ return str(
92
+ hardware.get("flavor")
93
+ or receipts.get("flavor_verified")
94
+ or row.get("gpu_flavor")
95
+ or row.get("FEATHER_HF_FLAVOR")
96
+ or ""
97
+ ).strip().lower()
98
+
99
+
100
+ def _duplicate_count(row: dict[str, Any]) -> int | None:
101
+ check = row.get("duplicate_active_job_check")
102
+ if not isinstance(check, dict):
103
+ return None
104
+ return _int_or_none(check.get("active_matching_jobs"))
105
+
106
+
107
+ def _scale_free_a10g_invalid_reasons(row: dict[str, Any]) -> list[str]:
108
+ """Return fail-closed reasons for bounded A10G scale-free HTM proof rows."""
109
+ env = _env(row)
110
+ reasons: list[str] = []
111
+ if _flavor(row) not in A10_FLAVORS:
112
+ return reasons
113
+ proof_requested = (
114
+ _as_bool(env.get("HYDRA_HTM_STRICT_SCALE_FREE"), default=False)
115
+ or str(row.get("runtime_profile") or "").strip().lower() in {"optimal-strict", "a10g-scale-free-proof"}
116
+ )
117
+ if not proof_requested:
118
+ return reasons
119
+
120
+ if env.get("HYDRA_TARGET_SHARDS") not in {"0", 0}:
121
+ reasons.append("scale-free A10G proof requires HYDRA_TARGET_SHARDS=0")
122
+ if env.get("HYDRA_HTM_STRICT_SCALE_FREE") != "1":
123
+ reasons.append("scale-free A10G proof requires HYDRA_HTM_STRICT_SCALE_FREE=1")
124
+ region_pool = _int_or_none(env.get("HYDRA_HTM_REGION_POOL_SIZE"))
125
+ chunk_b = _int_or_none(env.get("HYDRA_HTM_CHUNK_B"))
126
+ if region_pool is None:
127
+ reasons.append("scale-free A10G proof requires HYDRA_HTM_REGION_POOL_SIZE")
128
+ elif region_pool > 4:
129
+ reasons.append("scale-free A10G proof requires HYDRA_HTM_REGION_POOL_SIZE<=4")
130
+ if chunk_b is None:
131
+ reasons.append("scale-free A10G proof requires HYDRA_HTM_CHUNK_B")
132
+ elif region_pool is not None and chunk_b > region_pool:
133
+ reasons.append("scale-free A10G proof requires HYDRA_HTM_CHUNK_B<=HYDRA_HTM_REGION_POOL_SIZE")
134
+ if env.get("HYDRA_TOKEN_CACHE_GB") not in {"0", 0}:
135
+ reasons.append("scale-free A10G proof requires HYDRA_TOKEN_CACHE_GB=0")
136
+ if env.get("HYDRA_DISABLE_TOKEN_CACHE") != "1":
137
+ reasons.append("scale-free A10G proof requires HYDRA_DISABLE_TOKEN_CACHE=1")
138
+ for key in (
139
+ "HYDRA_HTM_REGION_POOL_SIZE_FROM_VRAM",
140
+ "HYDRA_HTM_SCALE_TO_VRAM",
141
+ "HYDRA_VRAM_TOPOLOGY_SCALE",
142
+ "FEATHER_VRAM_TOPOLOGY_SCALE",
143
+ ):
144
+ if _as_bool(env.get(key), default=False):
145
+ reasons.append(f"scale-free A10G proof forbids VRAM-derived topology scaling: {key}")
146
+ return reasons
147
+
148
+
149
+ def tps_manifest_invalid_reasons(row: dict[str, Any]) -> list[str]:
150
+ """Return all reasons a row cannot be used as max-TPS promotion evidence."""
151
+ reasons: list[str] = []
152
+ env = _env(row)
153
+ receipts = _receipts(row)
154
+ flavor = _flavor(row)
155
+
156
+ if row.get("crashed") is True:
157
+ reasons.append("run crashed")
158
+ if flavor not in A10_FLAVORS:
159
+ reasons.append(f"not A10G flavor: {flavor or 'missing'}")
160
+ if _profile_forward_enabled(row):
161
+ reasons.append("profile_forward enabled; attribution-only overhead row")
162
+ if _median_tps(row) is None:
163
+ reasons.append("missing training TPS window median")
164
+ duplicate_count = _duplicate_count(row)
165
+ if duplicate_count is None:
166
+ reasons.append("duplicate active job check missing")
167
+ elif duplicate_count > 0:
168
+ reasons.append(f"duplicate active Feather A10G jobs present: {duplicate_count}")
169
+
170
+ faithful_profile = "faithful" in str(row.get("runtime_profile") or "").lower()
171
+ htm_gpu_verified = _as_bool(receipts.get("htm_gpu_verified"), default=False)
172
+ force_htm_cpu = _as_bool(env.get("HYDRA_FORCE_HTM_CPU"), default=False)
173
+ if faithful_profile and (force_htm_cpu or not htm_gpu_verified):
174
+ reasons.append("faithful row lacks HTM GPU verification or uses CPU fallback")
175
+ if faithful_profile and env.get("HYDRA_HTM_FUSED") != "1":
176
+ reasons.append("faithful row missing HYDRA_HTM_FUSED=1")
177
+ if faithful_profile and env.get("HYDRA_HTM_BATCHED_FUSED") != "1":
178
+ reasons.append("faithful row missing HYDRA_HTM_BATCHED_FUSED=1")
179
+ if _as_bool(env.get("HYDRA_USE_NEMOTRON"), default=False) and env.get("HYDRA_TARGET_SHARDS") not in {"0", 0}:
180
+ reasons.append("Nemotron streaming TPS row must use HYDRA_TARGET_SHARDS=0")
181
+ if env.get("HYDRA_TOKEN_CACHE_GB") not in {"0", 0, None}:
182
+ reasons.append("token cache enabled/materializing during TPS row")
183
+ reasons.extend(_scale_free_a10g_invalid_reasons(row))
184
+ return reasons
185
+
186
+
187
+ def tps_manifest_invalid_reason(row: dict[str, Any]) -> str:
188
+ return "; ".join(tps_manifest_invalid_reasons(row))
189
+
190
+
191
+ def normalize_tps_manifest(row: dict[str, Any]) -> dict[str, Any]:
192
+ """Return a copy annotated with TPS/profiling validity metadata."""
193
+ normalized = deepcopy(row)
194
+ reasons = tps_manifest_invalid_reasons(normalized)
195
+ profile_forward = _profile_forward_enabled(normalized)
196
+ normalized["tps_valid"] = not reasons
197
+ if not reasons:
198
+ status = "promotion_candidate"
199
+ elif profile_forward or str(normalized.get("metric_role") or "").lower() == "profile":
200
+ status = "attribution_only"
201
+ else:
202
+ status = "diagnostic"
203
+ normalized["tps_status"] = status
204
+ normalized["invalid_reason"] = "; ".join(reasons)
205
+ normalized["gpu_flavor"] = _flavor(normalized)
206
+ normalized["median_tps"] = _median_tps(normalized)
207
+ normalized["profile_forward"] = profile_forward
208
+ normalized["duplicate_active_job_count"] = _duplicate_count(normalized)
209
+ return normalized
overlay/htm_rust/.cargo/config.toml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [target.x86_64-unknown-linux-gnu]
2
+ linker = "/usr/bin/cc"
overlay/htm_rust/.claude/CLAUDE.md ADDED
The diff for this file is too large to render. See raw diff
 
overlay/htm_rust/.letta/claude/conversations.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "c892b9c9-7fe5-4f14-8157-ec8740e965d1": {
3
+ "conversationId": "conv-b42ddc79-3745-4edf-b165-4281a8961d3b",
4
+ "agentId": "agent-2cc00bdf-45f5-4725-bb56-7b4ab142153e"
5
+ }
6
+ }
overlay/htm_rust/.letta/claude/session-c892b9c9-7fe5-4f14-8157-ec8740e965d1.json ADDED
The diff for this file is too large to render. See raw diff
 
overlay/htm_rust/Cargo.lock CHANGED
@@ -8,6 +8,15 @@ version = "1.5.0"
8
  source = "registry+https://github.com/rust-lang/crates.io-index"
9
  checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
10
 
 
 
 
 
 
 
 
 
 
11
  [[package]]
12
  name = "cfg-if"
13
  version = "1.0.4"
@@ -44,12 +53,14 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
44
  name = "htm_rust"
45
  version = "0.1.0"
46
  dependencies = [
 
47
  "cudarc",
48
  "ndarray",
49
  "numpy",
50
  "pyo3",
51
  "rand",
52
  "rand_xoshiro",
 
53
  ]
54
 
55
  [[package]]
@@ -301,6 +312,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
301
  checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa"
302
  dependencies = [
303
  "rand_core",
 
304
  ]
305
 
306
  [[package]]
@@ -321,6 +333,36 @@ version = "1.0.22"
321
  source = "registry+https://github.com/rust-lang/crates.io-index"
322
  checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
323
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
  [[package]]
325
  name = "syn"
326
  version = "2.0.117"
 
8
  source = "registry+https://github.com/rust-lang/crates.io-index"
9
  checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
10
 
11
+ [[package]]
12
+ name = "bincode"
13
+ version = "1.3.3"
14
+ source = "registry+https://github.com/rust-lang/crates.io-index"
15
+ checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad"
16
+ dependencies = [
17
+ "serde",
18
+ ]
19
+
20
  [[package]]
21
  name = "cfg-if"
22
  version = "1.0.4"
 
53
  name = "htm_rust"
54
  version = "0.1.0"
55
  dependencies = [
56
+ "bincode",
57
  "cudarc",
58
  "ndarray",
59
  "numpy",
60
  "pyo3",
61
  "rand",
62
  "rand_xoshiro",
63
+ "serde",
64
  ]
65
 
66
  [[package]]
 
312
  checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa"
313
  dependencies = [
314
  "rand_core",
315
+ "serde",
316
  ]
317
 
318
  [[package]]
 
333
  source = "registry+https://github.com/rust-lang/crates.io-index"
334
  checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
335
 
336
+ [[package]]
337
+ name = "serde"
338
+ version = "1.0.228"
339
+ source = "registry+https://github.com/rust-lang/crates.io-index"
340
+ checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e"
341
+ dependencies = [
342
+ "serde_core",
343
+ "serde_derive",
344
+ ]
345
+
346
+ [[package]]
347
+ name = "serde_core"
348
+ version = "1.0.228"
349
+ source = "registry+https://github.com/rust-lang/crates.io-index"
350
+ checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad"
351
+ dependencies = [
352
+ "serde_derive",
353
+ ]
354
+
355
+ [[package]]
356
+ name = "serde_derive"
357
+ version = "1.0.228"
358
+ source = "registry+https://github.com/rust-lang/crates.io-index"
359
+ checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79"
360
+ dependencies = [
361
+ "proc-macro2",
362
+ "quote",
363
+ "syn",
364
+ ]
365
+
366
  [[package]]
367
  name = "syn"
368
  version = "2.0.117"
overlay/htm_rust/Cargo.toml CHANGED
@@ -15,7 +15,9 @@ pyo3 = { version = "0.22", features = ["extension-module"] }
15
  numpy = "0.22"
16
  ndarray = "0.16"
17
  rand = "0.8"
18
- rand_xoshiro = "0.6"
 
 
19
  # cudarc: CUDA Rust bindings with dynamic-loading (no link-time dep on libcuda).
20
  # Kernels are embedded as PTX and JIT-compiled at runtime.
21
  cudarc = { version = "0.12", default-features = false, features = ["dynamic-linking", "driver", "cuda-12010"], optional = true }
 
15
  numpy = "0.22"
16
  ndarray = "0.16"
17
  rand = "0.8"
18
+ rand_xoshiro = { version = "0.6", features = ["serde1"] }
19
+ serde = { version = "1", features = ["derive"] }
20
+ bincode = "1.3"
21
  # cudarc: CUDA Rust bindings with dynamic-loading (no link-time dep on libcuda).
22
  # Kernels are embedded as PTX and JIT-compiled at runtime.
23
  cudarc = { version = "0.12", default-features = false, features = ["dynamic-linking", "driver", "cuda-12010"], optional = true }
overlay/htm_rust/DLB_PERKS_IMPLEMENTATION_PLAN.md ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # HTM-on-H200 Performance Plan: Persistent Kernel + Hopper Cluster mbarrier
2
+
3
+ **Goal:** Drive HTM forward from 400ms → ~40-80ms (5-10×) → tps 38k → 200-400k
4
+ **Hardware:** NVIDIA H200, 132 SMs, sm_90a, CUDA 12.4+
5
+
6
+ ---
7
+
8
+ ## The Real Bottleneck (established)
9
+
10
+ ```
11
+ Current batched cooperative kernel (grid=(16,8,1)=128 blocks):
12
+ htm_launch = 400-440 ms ← hard wall
13
+ tps = 35-38 k
14
+ ```
15
+
16
+ **Why we can't beat it with cooperative launch:**
17
+ - Cooperative kernels serialize at the device level (1 cooperative kernel at a time).
18
+ - H200 grid cap = 132 blocks (1 block/SM at block=1024). For B=8 regions batched: 16 blocks/region ceiling.
19
+ - Work × grid = constant: reshuffling blocks doesn't help.
20
+
21
+ **Why software DLB barrier made it worse (measured 650ms, 23k tps):**
22
+ - 128 blocks × 3 barriers/timestep × 2048 timesteps × ~5-10µs coordinator poll = ~300ms pure overhead.
23
+ - L2-contention tax (documented 20× slowdown on H200 vs 3060 for software atomic spin).
24
+
25
+ **The two paths that actually scale on H200 (per research):**
26
+
27
+ | Path | Pattern | Expected |
28
+ |------|---------|----------|
29
+ | **A** | PERKS-style persistent kernel + in-kernel turnstile | 1.3–1.8× = ~280-330 ms |
30
+ | **B** | Hopper Cluster mbarrier (hardware sync + TMA multicast) | 5–10× = ~40-80 ms |
31
+
32
+ Path B wins. It uses *hardware* primitives that match cooperative launch's speed while not being subject to the device-level serialization.
33
+
34
+ ---
35
+
36
+ ## Architecture: Cluster-Mapped HTM (Design 2 from research)
37
+
38
+ **Mapping:** Each of our 8 HTM regions → one Hopper Thread Block Cluster of 16 SMs
39
+ - Cluster size: 16 blocks (= current per-region grid_x)
40
+ - Total: 8 clusters × 16 SMs = 128 SMs used, 4 SMs spare
41
+ - Grid launch: `grid = (16, 8, 1)`, `cluster = (16, 1, 1)` — batched identically to today but with `CUDA_CLUSTER` launch attribute
42
+
43
+ **Per-cluster sync primitives (replace grid.sync()):**
44
+
45
+ 1. **Intra-cluster barrier:** `cluster::sync()` — hardware-level, ~10-40 ns (vs software atomic ~100-500 ns)
46
+ 2. **Cluster-distributed shared memory:** each SM in cluster can directly `cuda::memcpy_async` from another SM's smem
47
+ 3. **TMA multicast (`cp.async.bulk.tensor ... multicast`):** one TMA descriptor propagates input SDRs / column activations to all 16 SMs in cluster in a single DMA
48
+
49
+ **Between clusters (8 regions):** independent — each region updates its own state and its own cluster's mbarriers. Multiple clusters run concurrently at hardware-scheduler level, bounded only by SM count (fits because 8 × 16 = 128 ≤ 132).
50
+
51
+ **Inside the kernel body:** T=2048 timesteps run in a persistent loop. Hot state (boost, active_duty, inhibition_threshold, cell_active/winner bitsets) stays in registers / cluster-shared smem across timesteps — no per-timestep DRAM round-trip.
52
+
53
+ ---
54
+
55
+ ## Task Plan (Detailed, Dependency-Ordered)
56
+
57
+ ### Phase 1 — Feasibility & Setup (no GPU risk)
58
+
59
+ **T1. Cluster launch feasibility probe**
60
+ - Query `cuDeviceGetAttribute` for `CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR` and `CU_DEVICE_ATTRIBUTE_CLUSTER_LAUNCH`
61
+ - Verify H200 supports cluster launch with `cluster_size=16`
62
+ - Source: `cudarc::driver::result::launch_kernel_ex` with `CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION`
63
+ - Files: `htm_rust/src/gpu/fused.rs` — add probe at FusedState::new
64
+
65
+ **T2. Enable sm_90a PTX compilation + `--device-c` for rdc link**
66
+ - Current build.rs targets `sm_90`. Need `sm_90a` to access cluster intrinsics
67
+ - Add `-arch=sm_90a -rdc=true` to nvcc invocation
68
+ - Files: `htm_rust/build.rs`
69
+
70
+ **T3. Update cudarc version to 0.12 minimum**
71
+ - Current 0.12. Verify `result::launch_kernel_ex` and `CUkernelNodeAttrValue` are available
72
+ - If not, upgrade to latest 0.13+
73
+ - Files: `htm_rust/Cargo.toml`
74
+
75
+ ### Phase 2 — Cluster mbarrier primitive (isolated, testable)
76
+
77
+ **T4. Rewrite `fused_grid_barrier` as cluster barrier**
78
+ - Replace my DLB software barrier + `cg::grid_group::sync()` with:
79
+ ```cpp
80
+ namespace cg = cooperative_groups;
81
+ auto cluster = cg::this_cluster(); // sm_90a intrinsic
82
+ cluster.sync(); // hardware barrier
83
+ ```
84
+ - No more `flags[]` array, no spin-wait, no `__nanosleep`
85
+ - Files: `htm_rust/src/gpu/kernels/htm_fused_step.cu:117-160`
86
+ - Reference: CUTLASS `include/cutlass/pipeline/sm90_pipeline.hpp`
87
+
88
+ **T5. Delete `barrier_counters` allocation + plumbing**
89
+ - No longer needed with cluster barrier
90
+ - Files: `htm_rust/src/gpu/fused.rs` — remove `barrier_counters` field, FusedPtrs field, alloc
91
+
92
+ **T6. Unit test cluster sync on minimal kernel**
93
+ - Write a standalone test kernel that just does: load input, cluster::sync(), write output
94
+ - Launch with `cluster_dim=(16,1,1)`, `grid=(16,1,1)`, `block=(1024,1,1)`
95
+ - Verify no deadlock, correct values
96
+ - Files: `htm_rust/src/gpu/tests.rs`
97
+
98
+ ### Phase 3 — Persistent in-kernel timestep loop
99
+
100
+ **T7. Move T=2048 loop inside kernel body**
101
+ - Currently the T loop is inside the kernel already (`for (t = 0; t < cfg.T; t++)` at line 176)
102
+ - Persistent pattern means the SAME kernel processes all 2048 steps without relaunch
103
+ - Already the case! Just verify with cluster barrier replacing grid.sync
104
+
105
+ **T8. Cache hot state in cluster-distributed shared memory**
106
+ - Move `inhibition_threshold[n_columns]` from GMEM to cluster smem (16 SMs × 48KB = 768KB available per cluster)
107
+ - With n_columns=2048 and f32 = 8KB per cluster — trivially fits
108
+ - Similarly cache `boost[n_columns]` (8KB) and `active_duty[n_columns]` (8KB)
109
+ - Each SM in cluster holds a slice; reads from peer SM via `cuda::memcpy_async` with cluster scope
110
+ - Files: kernel `htm_fused_step_body`
111
+ - Reference: CUTLASS cluster shmem examples in `examples/49_hopper_gemm_with_collective_builder`
112
+
113
+ **T9. TMA multicast for per-timestep input broadcast**
114
+ - Each timestep broadcasts the current SDR input + prev column-activation state to all 16 SMs in cluster
115
+ - Use `cp.async.bulk.tensor.5d.shared::cluster.global.tile.mbarrier::complete_tx::bytes.multicast::cluster`
116
+ - Single DMA instead of 16 blocks each reading from GMEM
117
+ - Files: kernel, plus set up `CUtensorMap` descriptors in Rust host
118
+ - Reference: [CUDA TMA multicast docs](https://docs.nvidia.com/cuda/hopper-tuning-guide/index.html)
119
+
120
+ ### Phase 4 — Rust host update
121
+
122
+ **T10. Switch launch to `launch_kernel_ex` with cluster attribute**
123
+ - Current: `result::launch_kernel(func, grid, block, shmem, stream, params)`
124
+ - New: `launch_kernel_ex(func, grid, cluster, block, shmem, stream, params, attrs)`
125
+ - Cluster attribute: `CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION` = `(16, 1, 1)`
126
+ - Files: `htm_rust/src/gpu/fused.rs` — both `launch_fused` and `launch_fused_batched_raw`
127
+
128
+ **T11. Allocate cluster-scope CUtensorMap descriptors**
129
+ - One per region for input SDR, cols_out, anom_out
130
+ - Rust side: `cuTensorMapEncodeTiled` with appropriate swizzling
131
+ - Files: `htm_rust/src/gpu/fused.rs` — FusedState::new extended with tensor maps
132
+
133
+ **T12. Bump MAX_REGISTERS / occupancy**
134
+ - With cluster + persistent kernel, register budget per thread tightens
135
+ - May need `__launch_bounds__(1024, 2)` to force 2 blocks/SM
136
+ - Verify occupancy with `cudaOccupancyMaxActiveBlocksPerMultiprocessor`
137
+ - Files: kernel, fused.rs
138
+
139
+ ### Phase 5 — Validation + measurement
140
+
141
+ **T13. Parity test against current kernel**
142
+ - Run both old (cooperative) and new (cluster) kernels with identical input, compare outputs bit-exact
143
+ - Must match (HTM is deterministic given same seed)
144
+ - Files: `tests.rs`
145
+
146
+ **T14. Benchmark: measure PROFILE[htm_launch] + tps on H200**
147
+ - Launch HF Job, verify steady-state tps
148
+ - Target: ≥ 200k tps
149
+ - If below, profile with Nsight Compute to find remaining stalls
150
+
151
+ **T15. Document results + publish**
152
+
153
+ ---
154
+
155
+ ## Risks & Mitigations
156
+
157
+ | Risk | Mitigation |
158
+ |------|-----------|
159
+ | H200 doesn't support cluster_size=16 | Fall back to cluster_size=8, use 2 clusters per region (16 SMs) |
160
+ | Cluster barrier parity bug (deadlock) | Use CUDA-GDB's `info cuda barriers` (documented FA3 debug flow) |
161
+ | TMA multicast descriptor setup complexity | Incremental: land cluster::sync() first (T4-T6), add TMA later (T9) |
162
+ | Register pressure from in-kernel persistent state | Use `__launch_bounds__` + selective DRAM spill for cold state |
163
+ | Cluster scheduling latency | Pre-build CUtensorMap once, reuse per forward call |
164
+
165
+ ---
166
+
167
+ ## Prior Art References
168
+
169
+ - **PERKS** (closest structural analog): https://github.com/neozhang307/PERKS — persistent iterative kernel for stencils
170
+ - **CUTLASS sm90 ping-pong**: https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp
171
+ - **CUTLASS sm90 pipeline (mbarrier API)**: https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/pipeline/sm90_pipeline.hpp
172
+ - **FlashAttention-3 hopper/**: https://github.com/Dao-AILab/flash-attention
173
+ - **CuTe persistent kernels**: https://github.com/simveit/cute_persistent_kernels
174
+ - **Hopper architecture guide**: https://developer.nvidia.com/blog/nvidia-hopper-architecture-in-depth/
175
+ - **PERKS paper**: arXiv:2204.02064
176
+
177
+ ---
178
+
179
+ ## Expected Outcomes
180
+
181
+ **Best case (all phases land):**
182
+ - htm_launch: 400 ms → 40-60 ms
183
+ - forward total: 410 ms → 50-70 ms
184
+ - step time: 850 ms → 250-350 ms (bounded by backward + optimizer)
185
+ - tps: 38k → ~**160-250k** — meets 200k target
186
+
187
+ **Minimum case (only Phase 2, cluster sync without TMA multicast):**
188
+ - htm_launch: 400 ms → 250-320 ms
189
+ - tps: 38k → ~60-90k — partial win, still under 200k
190
+
191
+ **Pessimistic (cluster launch has unexpected cap):**
192
+ - Falls back to PERKS-style in-kernel turnstile (Design 1)
193
+ - htm_launch: 400 ms → 280-360 ms
194
+ - tps: 38k → ~55-75k
overlay/htm_rust/bench_gpu.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Microbenchmark: CPU vs GPU HTMLayer forward at HYDRA training sizes.
2
+
3
+ Usage:
4
+ source .venv/bin/activate
5
+ export LD_LIBRARY_PATH=/usr/lib/wsl/lib:/usr/local/cuda-12.1/lib64:$LD_LIBRARY_PATH
6
+ python htm_rust/bench_gpu.py
7
+ """
8
+ import os
9
+ import sys
10
+ import time
11
+
12
+ # Ensure /home/mikeb/work/feather is on sys.path so `subsystems` imports.
13
+ _FEATHER = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
14
+ if _FEATHER not in sys.path:
15
+ sys.path.insert(0, _FEATHER)
16
+
17
+ import numpy as np
18
+ import torch
19
+
20
+ from subsystems.htm import HTMLayer
21
+
22
+
23
+ def bench(layer: HTMLayer, sdr: torch.Tensor, warmup: int = 1, iters: int = 3) -> float:
24
+ """Return mean ms/forward."""
25
+ for _ in range(warmup):
26
+ _ = layer(sdr)
27
+ if torch.cuda.is_available():
28
+ torch.cuda.synchronize()
29
+ t0 = time.perf_counter()
30
+ for _ in range(iters):
31
+ _ = layer(sdr)
32
+ if torch.cuda.is_available():
33
+ torch.cuda.synchronize()
34
+ dt = time.perf_counter() - t0
35
+ return dt * 1000 / iters
36
+
37
+
38
+ def main() -> None:
39
+ # HYDRA training config: B=8, T=2048, bits=16384, cols=2048.
40
+ B, T, D = int(os.environ.get("B", 8)), int(os.environ.get("T", 2048)), 16384
41
+ n_cols = 2048
42
+
43
+ print(f"config: B={B} T={T} D={D} n_cols={n_cols}")
44
+ print(f"torch: {torch.__version__} cuda={torch.cuda.is_available()}")
45
+
46
+ # Build a fixed sparse SDR once.
47
+ rng = np.random.default_rng(0)
48
+ sdr = np.zeros((B, T, D), dtype=bool)
49
+ on = int(D * 0.02)
50
+ for b in range(B):
51
+ for t in range(T):
52
+ idx = rng.choice(D, size=on, replace=False)
53
+ sdr[b, t, idx] = True
54
+ sdr_t = torch.from_numpy(sdr)
55
+
56
+ # CPU baseline.
57
+ print("\n--- CPU ---")
58
+ cpu_layer = HTMLayer(
59
+ input_bits=D, n_columns=n_cols, cells_per_column=32,
60
+ batch_size=B, seed=42, use_gpu=False,
61
+ )
62
+ cpu_layer.train()
63
+ cpu_ms = bench(cpu_layer, sdr_t, warmup=1, iters=2)
64
+ print(f"CPU: {cpu_ms:.1f} ms/forward ({cpu_ms/T:.2f} ms/step × T={T})")
65
+
66
+ # GPU.
67
+ print("\n--- GPU ---")
68
+ gpu_layer = HTMLayer(
69
+ input_bits=D, n_columns=n_cols, cells_per_column=32,
70
+ batch_size=B, seed=42, use_gpu=True,
71
+ )
72
+ gpu_layer.train()
73
+ sdr_cuda = sdr_t.cuda()
74
+ gpu_ms = bench(gpu_layer, sdr_cuda, warmup=1, iters=2)
75
+ print(f"GPU: {gpu_ms:.1f} ms/forward ({gpu_ms/T:.2f} ms/step × T={T})")
76
+
77
+ print(f"\nSpeedup: {cpu_ms / gpu_ms:.2f}x")
78
+
79
+
80
+ if __name__ == "__main__":
81
+ main()
overlay/htm_rust/docs/GPU_HTM.md ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GPU HTM Backend
2
+
3
+ ## Status
4
+
5
+ **FUSED MEGAKERNEL: entire T-timestep SP+TM forward collapsed into a single
6
+ CUDA launch per forward pass.**
7
+
8
+ * Legacy path: 12 kernels × T=2048 timesteps = 24K launches per forward.
9
+ * Fused path: **1 launch per forward** (24000× launch-overhead reduction).
10
+ * End-to-end training throughput: **~2.7k → ~60k tok/sec** (~22x speedup).
11
+ * Fused path uses per-column threshold inhibition instead of global top-K
12
+ (see §Fused Kernel below — this is a real architectural change).
13
+
14
+ ## Fused Kernel
15
+
16
+ ### Why
17
+
18
+ Global top-K column selection requires cross-block synchronization at every
19
+ timestep. On WSL2/sm_86 without `-rdc=true`, `cooperative_groups::grid_sync()`
20
+ is unreliable. Without a grid sync, collapsing the T-loop into one kernel is
21
+ impossible, so every forward pays 12×T kernel launches and 90%+ of runtime is
22
+ CUDA launch overhead + small-kernel tails.
23
+
24
+ ### How
25
+
26
+ Replace global top-K with **per-column threshold activation**:
27
+
28
+ is_active[c] = (overlap[c] * boost[c]) > inhibition_threshold[c]
29
+
30
+ `inhibition_threshold[c]` is a per-column scalar, learned via EMA update:
31
+
32
+ err = active_duty[c] - sparsity_target
33
+ new_thr = clamp(thr + thr_adapt_rate * err * 100, 0.1, 1000)
34
+
35
+ This is biologically grounded (GABAergic local lateral inhibition in
36
+ neocortical columns) and supported by HTM theory. The duty-cycle-driven
37
+ feedback loop was already present; we simply redirect its output to drive
38
+ activation threshold instead of multiplicative boost. The global top-K,
39
+ which had no biological basis, is removed.
40
+
41
+ ### Cross-block coherence
42
+
43
+ - **Ping-pong bitsets** for `cell_active_bits` and `cell_winner_bits`: at
44
+ even t write to `_a`, read from `_b`; at odd t reversed. This eliminates
45
+ the need for an in-place snapshot kernel between timesteps.
46
+ - **Primary path: cooperative launch + hardware grid sync**. Host code probes
47
+ `CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH`, computes the cooperative whole-grid
48
+ residency limit from occupancy, and launches the fused megakernel with
49
+ `cuLaunchCooperativeKernel`. In-kernel barriers use
50
+ `cooperative_groups::this_grid().sync()`.
51
+ - **Fallback path: software grid barrier** via a 3-slot atomic counter array
52
+ (`barrier_counters`). This remains as a compatibility fallback when
53
+ cooperative launch is unavailable.
54
+ - **Launch invariant**: cooperative launch is capped to the hardware residency
55
+ limit for `blockDim.x = 1024`; software fallback remains capped conservatively
56
+ (`HTM_FUSED_GRID_CAP`, default 8) to avoid whole-grid spin deadlock.
57
+
58
+ ### Kernel structure
59
+
60
+ ```
61
+ for t in 0..T:
62
+ # Phase 0: clear curr_active/curr_winner for my column range
63
+ grid_barrier()
64
+ # Phase A: SP overlap → boost → threshold → SP learn → duty + threshold EMA
65
+ grid_barrier()
66
+ # Phase B: TM predict (per cell, per seg) → TM learn (reinforce on match)
67
+ # → burst if none predicted → segment grow/reinforce
68
+ grid_barrier()
69
+ # Phase C: block 0 writes anomaly[t]
70
+ ```
71
+
72
+ Each warp owns a contiguous slice of columns. At grid=24 blocks × 32 warps =
73
+ 768 warps, n_columns=2048 → 2-3 columns per warp.
74
+
75
+ ### Parity with legacy GPU path
76
+
77
+ **Semantics diverge**. Legacy: exactly `k = round(sparsity * n_cols)` columns
78
+ active per step. Fused: variable, converging to `sparsity * n_cols` on
79
+ average via the per-column EMA. Anomaly decay on repeating sequences is
80
+ preserved (see `gpu_fused_tm_anomaly_decays_on_repeating_sequence` test).
81
+
82
+ This is an intentional architectural change committed under
83
+ `no-bypass/full-architecture` per program.md rules. The legacy top-K path
84
+ (`step_many_cuda`) remains available for reference and can be re-enabled via
85
+ `HYDRA_HTM_FUSED=0`.
86
+
87
+ ### Tests
88
+
89
+ - `gpu_threshold_converges_to_sparsity` (tests.rs): 1000-step warmup on
90
+ random SDRs, then measure mean active cols/step on next 200 steps. Must
91
+ land within [0.25×, 4×] of `sparsity_target * n_cols`.
92
+ - `gpu_fused_tm_anomaly_decays_on_repeating_sequence`: feed A,B,C repeating
93
+ for 300 steps. Late anomaly must be < early anomaly AND < 0.5.
94
+
95
+ ## Legacy Pipeline (kept for fallback)
96
+
97
+ * SP: 5 kernels, bit-identical parity with CPU under strict-parity mode.
98
+ * TM: 7 kernels, relaxed-parity with CPU.
99
+ * Speedup at training size (B=8, T=2048, bits=16384): **3.83x** vs CPU.
100
+
101
+ ## Building
102
+
103
+ CPU-only (default, zero CUDA dep):
104
+ ```bash
105
+ cargo build --release
106
+ ```
107
+
108
+ GPU-enabled:
109
+ ```bash
110
+ export PATH=/usr/local/cuda-12.1/bin:$PATH
111
+ export LD_LIBRARY_PATH=/usr/lib/wsl/lib:/usr/local/cuda-12.1/lib64:$LD_LIBRARY_PATH
112
+ export HTM_PTX_VERSION=7.8 # lower if driver older than nvcc
113
+ cargo build --release --features gpu
114
+ cargo test --release --features gpu --lib # fused path includes cooperative launch + grid-sync tests
115
+
116
+ # Python wheel:
117
+ maturin develop --release --features gpu --manifest-path htm_rust/Cargo.toml
118
+ ```
119
+
120
+ ## Architecture
121
+
122
+ ### Module layout
123
+ ```
124
+ src/gpu/
125
+ mod.rs # HTMRegionGpu pyclass + step_many_gpu (full pipeline)
126
+ sp_gpu.rs # Persistent SP device buffers + step_batch_with_tm
127
+ tm_gpu.rs # Persistent TM device buffers + step (predict→activate→learn)
128
+ tests.rs # CPU-vs-GPU SP parity + end-to-end TM anomaly decay
129
+ kernels/
130
+ sp_overlap.cu # per-column overlap reduction
131
+ sp_topk.cu # k-WTA top-K winner selection
132
+ sp_learn.cu # Hebbian +inc/-dec on proximal synapses
133
+ sp_duty.cu # EMA duty-cycle update
134
+ sp_boost_fused.cu # fused mean + exp boost (GPU-side)
135
+ tm_reset.cu # per-step: snapshot active→prev, clear buffers
136
+ tm_predict.cu # per-cell: score owned segments vs prev_active_bits
137
+ tm_activate.cu # per-col: activate predicted cells OR burst
138
+ tm_learn.cu # per-cell: reinforce correctly-predicted segments
139
+ tm_punish.cu # per-cell: decay matching segs on inactive cols
140
+ tm_grow.cu # per-bursting-col: reuse matching seg OR create new,
141
+ # grow synapses to prev_winners
142
+ tm_anomaly.cu # per-step: unpredicted/active ratio
143
+ ```
144
+
145
+ ### Persistent SP state (per region, unchanged from Phase 1)
146
+ At n_cols=2048, S=40, bits=16384: ~355 KB persistent + ~90 KB transient.
147
+
148
+ ### Persistent TM state (per region)
149
+
150
+ Capacity knobs (configured in `tm_gpu.rs`):
151
+ - `MAX_SEGMENTS_PER_CELL = 4`
152
+ - `MAX_SYN_PER_SEGMENT = 20`
153
+
154
+ At cells_per_col=32, n_cols=2048:
155
+ - `n_cells = 65_536`
156
+ - `n_segments_max = 262_144` (~262K)
157
+ - `n_synapses_max = 5_242_880` (~5.2M)
158
+
159
+ | Buffer | Shape / type | Notes |
160
+ |-----------------------|----------------------|----------------------------------------|
161
+ | `seg_cell_id` | (n_segs,) u32 | owning cell; U32_MAX = unused |
162
+ | `seg_syn_count` | (n_segs,) u32 | #active synapses in slot |
163
+ | `syn_presyn` | (n_segs × S,) u32 | presynaptic cell indices |
164
+ | `syn_perm` | (n_segs × S,) i16 | permanence scaled 0..32767 (0.0..1.0) |
165
+ | `cell_seg_count` | (n_cells,) u32 | segments allocated on each cell |
166
+ | `cell_active_bits` | (n_cells/32,) u32 | packed bitset, current step |
167
+ | `cell_winner_bits` | (n_cells/32,) u32 | packed bitset, current step |
168
+ | `cell_predictive_bits`| (n_cells/32,) u32 | set by predict, read by activate |
169
+ | `prev_active_bits` | (n_cells/32,) u32 | snapshot at step start |
170
+ | `prev_winner_bits` | (n_cells/32,) u32 | snapshot at step start |
171
+ | `col_predicted` | (n_cols,) u8 | set if any cell in col is predictive |
172
+ | `col_best_match` | (n_cols,) u32 | packed (pot<<21 | seg_id), atomicMax |
173
+ | `seg_num_active_conn` | (n_segs,) u32 | output of predict |
174
+ | `seg_num_active_pot` | (n_segs,) u32 | output of predict |
175
+ | `unpredicted_count` | (1,) u32 | atomic counter for anomaly |
176
+ | `burst_cols_flat` | (n_cols,) u32 | list of bursting cols |
177
+ | `burst_cols_count` | (1,) u32 | length of above list |
178
+
179
+ **Total per TM region: ~42 MB.** Batch of 8 regions: ~340 MB. Fits 6 GB RTX 3060.
180
+
181
+ ### Per-step pipeline (single iteration of `step_batch_with_tm`)
182
+
183
+ ```
184
+ SP side TM side
185
+ --------- ---------
186
+ 1. D2D input slice → inp_dev
187
+ 2. sp_overlap (n_cols blocks)
188
+ 3. sp_topk (1 block)
189
+ 4. sp_learn (n_cols blocks)
190
+ 5. sp_duty (n_cols/256 blocks)
191
+ 6. sp_boost_fused (1 block)
192
+ 7. D2D active_mask → cols_dev[ti]
193
+ 8. tm_reset_step (ceil(n_cells/32/256))
194
+ 9. tm_predict (n_cells blocks × 32 thr)
195
+ 10. tm_activate (n_cols/256 blocks)
196
+ 11. tm_anomaly (1 block)
197
+ if learn:
198
+ 12. tm_learn (n_cells blocks)
199
+ 13. tm_punish (n_cells blocks)
200
+ 14. tm_grow (n_cols blocks — early-exits)
201
+ ```
202
+
203
+ No host sync in the T-step loop. At the end one `dtoh_sync_copy` each for
204
+ `cols_dev` (T × n_cols bytes) and `anom_dev` (T × f32).
205
+
206
+ ## Parity
207
+
208
+ ### SP: strict bit-identical
209
+ See Phase 1 docs — `gpu_sp_matches_cpu_with_learn` over 50 steps passes exact.
210
+
211
+ ### TM: relaxed-parity
212
+ The GPU TM has known, deliberate deviations from CPU to admit massive parallelism:
213
+
214
+ 1. **Bursting winner cell**: CPU picks the least-used cell (fewest segments) with
215
+ random tiebreak. GPU picks cell 0 of the column (deterministic, branch-free).
216
+ Learning dynamics are preserved because segment creation/reinforcement is
217
+ the dominant effect, not which specific cell in a bursting column wins.
218
+
219
+ 2. **Permanence storage**: i16 fixed-point (scale 32767) vs f32. Rounding
220
+ differs by <=1 ULP of the scale (~3.0e-5), below any meaningful learning
221
+ quantum (inc=0.10, dec=0.10, predicted_segment_dec=0.10).
222
+
223
+ 3. **Grown synapse candidate order**: CPU randomly samples from prev_winner_cells.
224
+ GPU iterates prev_winner_bits words in a pseudo-random rotated order keyed
225
+ by (bursting_col_idx, iter_seed). Output is a different subset but same size.
226
+
227
+ 4. **Segment LRU eviction**: CPU tracks `last_used_iteration` per segment.
228
+ GPU wraps around (slot = count % max_segments_per_cell). In the autoresearch
229
+ loop where TM resets every forward, eviction rarely triggers.
230
+
231
+ The GPU parity test (`gpu_tm_anomaly_decays_on_repeating_sequence`) feeds a
232
+ repeating A,B,C sequence and asserts anomaly decays: **1.000 early → 0.000 late**.
233
+
234
+ ## Bottleneck Analysis
235
+
236
+ | Source | Cost/step (B=8 T=2048) |
237
+ |----------------------------------|-------------------------:|
238
+ | 14 kernel launches | ~70 μs |
239
+ | ~262K predict/learn/punish blocks| ~2.5 ms |
240
+ | No D2H until end-of-batch | 0 μs |
241
+ | Final D2H (T × n_cols + T × f32) | ~200 μs per region |
242
+
243
+ Per-step wall time at B=8 T=2048:
244
+ - CPU (reference): **~11.4 ms / step**
245
+ - GPU (current): **~2.98 ms / step**
246
+ - **Speedup: 3.83x**
247
+
248
+ ## End-to-End Training Benchmark
249
+
250
+ **Config**: B=8, T=2048, vocab=8192, 60-second time budget, full HYDRA stack
251
+ (SDR Semantic + HTM + Mamba-3 + Engram + mHC + Hestia QAT).
252
+
253
+ **Results**:
254
+ - GPU util: **97-98% sustained**
255
+ - VRAM: **5.4 GB / 6.0 GB** (90% utilisation)
256
+ - Steps completed: 16
257
+ - tok/sec: **~2,200-2,500** (stable post-warmup)
258
+ - Final val_bpb: **2.249** (from ~3.1 initial)
259
+ - Factual eval: 1/9 hits
260
+
261
+ Compared to previous CPU-HTM baseline (~100 tok/s), the full-GPU HTM delivers
262
+ **~22x end-to-end throughput** — far above the 3-10x target.
263
+
264
+ ## Bench Commands
265
+
266
+ ```bash
267
+ source .venv/bin/activate
268
+ export LD_LIBRARY_PATH=/usr/lib/wsl/lib:/usr/local/cuda-12.1/lib64:$LD_LIBRARY_PATH
269
+
270
+ # Microbench
271
+ B=8 T=2048 python htm_rust/bench_gpu.py
272
+
273
+ # Full training
274
+ HYDRA_TIME_BUDGET=60 HYDRA_BATCH_SIZE=8 HYDRA_TOTAL_BATCH=32768 python -u train.py
275
+ ```
276
+
277
+ ## Known Limitations / Future Work
278
+
279
+ - **Segment-compacted launches**: predict/learn/punish iterate all n_cells
280
+ blocks, using `cell_seg_count` to skip empty cells. A compacted live-cell
281
+ list would shave another ~40% of launch overhead.
282
+ - **Winner selection**: currently cell 0 of bursting col. Proper least-used
283
+ selection would help stability of cross-column patterns.
284
+ - **Single CUDA stream per region**: with B=8 regions we serialise on stream 0.
285
+ Multi-stream would lift the ~20% launch overhead at small batch sizes.
286
+ - **Permanence bump on chronically under-stimulated columns**: SP's strict-parity
287
+ bump is not mirrored on GPU fast path. Effect on long runs needs measurement.
288
+ - **`seg_num_active_conn` output is reused across reinforce + punish**: the two
289
+ kernels each launch n_cells blocks. They could be fused into one for one fewer
290
+ kernel launch per step.
291
+
292
+ ## Files
293
+
294
+ - `htm_rust/build.rs` — nvcc-driven PTX compilation, 12 kernels.
295
+ - `htm_rust/Cargo.toml` — `gpu` feature flag, cudarc dep.
296
+ - `htm_rust/src/gpu/mod.rs` — `HTMRegionGpu` pyclass + `step_many_gpu`.
297
+ - `htm_rust/src/gpu/sp_gpu.rs` — SP state + `step_batch_with_tm`.
298
+ - `htm_rust/src/gpu/tm_gpu.rs` — TM state + `step`.
299
+ - `htm_rust/src/gpu/tests.rs` — parity + correctness tests.
300
+ - `htm_rust/src/gpu/kernels/*.cu` — 5 SP + 7 TM kernels.
301
+ - `htm_rust/bench_gpu.py` — CPU-vs-GPU microbench.
302
+ - `subsystems/htm.py` — transparent GPU/CPU backend selection in `HTMLayer`.
overlay/htm_rust/src/gpu/fused.rs CHANGED
@@ -20,8 +20,7 @@
20
  use std::ffi::CString;
21
  use std::sync::Arc;
22
 
23
- use cudarc::driver::{result, sys, CudaDevice, CudaSlice, DeviceRepr, DevicePtr, DriverError,
24
- LaunchConfig};
25
  use cudarc::nvrtc::Ptx;
26
 
27
  use super::sp_gpu::SpatialPoolerGpu;
@@ -150,7 +149,11 @@ pub(crate) fn plan_fused_launch(
150
  let default_grid_cap = 16u32;
151
  let grid_cap = grid_cap_override.unwrap_or(default_grid_cap);
152
  let resident_bound = if cooperative_grid_limit > 0 {
153
- cooperative_grid_limit.max(sm_count * 2)
 
 
 
 
154
  } else {
155
  sm_count * 2
156
  };
@@ -280,7 +283,9 @@ impl FusedState {
280
  }
281
  _ => 0u32,
282
  };
283
- eprintln!("[htm_rust] cluster: max_cluster_size={}", max_cluster_size);
 
 
284
  let cluster_info = ClusterInfo { max_cluster_size };
285
 
286
  let cooperative_supported = matches!(
@@ -289,7 +294,10 @@ impl FusedState {
289
  );
290
  let cooperative_grid_limit = if cooperative_supported {
291
  let blocks_per_sm = unsafe {
292
- result::occupancy::max_active_block_per_multiprocessor(function, 1024, 0)
 
 
 
293
  }
294
  .ok()
295
  .map(|v| v.max(0) as u32)
@@ -310,11 +318,13 @@ impl FusedState {
310
  DriverError(cudarc::driver::sys::CUresult::CUDA_ERROR_NOT_SUPPORTED)
311
  })?;
312
 
313
- eprintln!(
314
- "[htm_rust] fused kernel: sm_count={} grid_dim_x={} cooperative_grid_limit={} cluster_max={}",
315
- launch_plan.sm_count, launch_plan.grid_dim_x, launch_plan.cooperative_grid_limit,
316
- cluster_info.max_cluster_size,
317
- );
 
 
318
 
319
  Ok(Self {
320
  dev,
@@ -513,6 +523,38 @@ pub(super) fn launch_fused_batched_raw(
513
  assert_eq!(anom_per_region.len(), b);
514
  assert!(b >= 1, "need at least one region");
515
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
516
  // Reset per-region step_scratch before each launch.
517
  for &rp in region_ptrs.iter() {
518
  let r = unsafe { &mut *rp };
@@ -659,5 +701,11 @@ pub(super) fn launch_fused_batched_raw(
659
  }
660
  }
661
 
 
 
 
 
 
 
662
  Ok(())
663
  }
 
20
  use std::ffi::CString;
21
  use std::sync::Arc;
22
 
23
+ use cudarc::driver::{result, sys, CudaDevice, CudaSlice, DevicePtr, DeviceRepr, DriverError};
 
24
  use cudarc::nvrtc::Ptx;
25
 
26
  use super::sp_gpu::SpatialPoolerGpu;
 
149
  let default_grid_cap = 16u32;
150
  let grid_cap = grid_cap_override.unwrap_or(default_grid_cap);
151
  let resident_bound = if cooperative_grid_limit > 0 {
152
+ // A10G/sm86 uses cooperative grid sync in the fused kernel. The grid
153
+ // may not exceed resident cooperative capacity, or the kernel can fail
154
+ // (or worse, deadlock at grid.sync()). Do not inflate this above the
155
+ // driver-reported occupancy limit.
156
+ cooperative_grid_limit
157
  } else {
158
  sm_count * 2
159
  };
 
283
  }
284
  _ => 0u32,
285
  };
286
+ if std::env::var_os("HTM_RUST_VERBOSE_LAUNCH").is_some() {
287
+ eprintln!("[htm_rust] cluster: max_cluster_size={}", max_cluster_size);
288
+ }
289
  let cluster_info = ClusterInfo { max_cluster_size };
290
 
291
  let cooperative_supported = matches!(
 
294
  );
295
  let cooperative_grid_limit = if cooperative_supported {
296
  let blocks_per_sm = unsafe {
297
+ // Keep this in sync with plan_fused_launch's block_dim_x. The
298
+ // fused kernels are launch_bounds(256, ...); querying with
299
+ // 1024 underestimates sm86 residency and breaks A10G tuning.
300
+ result::occupancy::max_active_block_per_multiprocessor(function, 256, 0)
301
  }
302
  .ok()
303
  .map(|v| v.max(0) as u32)
 
318
  DriverError(cudarc::driver::sys::CUresult::CUDA_ERROR_NOT_SUPPORTED)
319
  })?;
320
 
321
+ if std::env::var_os("HTM_RUST_VERBOSE_LAUNCH").is_some() {
322
+ eprintln!(
323
+ "[htm_rust] fused kernel: sm_count={} grid_dim_x={} cooperative_grid_limit={} cluster_max={}",
324
+ launch_plan.sm_count, launch_plan.grid_dim_x, launch_plan.cooperative_grid_limit,
325
+ cluster_info.max_cluster_size,
326
+ );
327
+ }
328
 
329
  Ok(Self {
330
  dev,
 
523
  assert_eq!(anom_per_region.len(), b);
524
  assert!(b >= 1, "need at least one region");
525
 
526
+ // A10G/sm86 pre-Hopper path uses cooperative launch with grid.sync(). The
527
+ // total resident grid is grid_x * B, so B must be chunked to fit the
528
+ // driver-reported cooperative residency. Without this, large training
529
+ // batches either fail cooperatively or fall back to B sequential launches.
530
+ {
531
+ let r0 = unsafe { &*region_ptrs[0] };
532
+ let use_cluster = r0.fused_state.cluster_info.max_cluster_size > 0;
533
+ if !use_cluster {
534
+ let grid_x = r0.fused_state.grid_dim_x.max(1);
535
+ let coop_limit = r0.fused_state.cooperative_grid_limit;
536
+ if coop_limit == 0 {
537
+ return Err(DriverError(sys::CUresult::CUDA_ERROR_NOT_SUPPORTED));
538
+ }
539
+ let max_regions_per_launch = (coop_limit / grid_x).max(1) as usize;
540
+ if b > max_regions_per_launch {
541
+ for start in (0..b).step_by(max_regions_per_launch) {
542
+ let end = (start + max_regions_per_launch).min(b);
543
+ launch_fused_batched_raw(
544
+ &region_ptrs[start..end],
545
+ &inputs_per_region[start..end],
546
+ &cols_per_region[start..end],
547
+ &anom_per_region[start..end],
548
+ t,
549
+ input_bits,
550
+ learn,
551
+ )?;
552
+ }
553
+ return Ok(());
554
+ }
555
+ }
556
+ }
557
+
558
  // Reset per-region step_scratch before each launch.
559
  for &rp in region_ptrs.iter() {
560
  let r = unsafe { &mut *rp };
 
701
  }
702
  }
703
 
704
+ // ptrs_dev is temporary device memory consumed by the launched batched
705
+ // kernel. Synchronize before it is dropped; single-region step_many_fused_cuda
706
+ // also synchronizes today, so this preserves correctness while still
707
+ // reducing B separate launches to chunked cooperative launches.
708
+ dev.synchronize()?;
709
+
710
  Ok(())
711
  }
overlay/htm_rust/src/gpu/mod.rs CHANGED
@@ -25,7 +25,7 @@ mod tests;
25
  use std::mem::ManuallyDrop;
26
 
27
  use pyo3::prelude::*;
28
- use pyo3::types::{PyDict, PyTuple};
29
  use numpy::{PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray2, PyUntypedArrayMethods};
30
 
31
  use crate::region::HTMRegionCore;
@@ -423,7 +423,140 @@ impl HTMRegionGpu {
423
  }
424
  }
425
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
426
  pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> {
427
  m.add_class::<HTMRegionGpu>()?;
 
428
  Ok(())
429
  }
 
25
  use std::mem::ManuallyDrop;
26
 
27
  use pyo3::prelude::*;
28
+ use pyo3::types::{PyDict, PyList, PyTuple};
29
  use numpy::{PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray2, PyUntypedArrayMethods};
30
 
31
  use crate::region::HTMRegionCore;
 
423
  }
424
  }
425
 
426
+ #[pyfunction]
427
+ fn step_batch_fused_cuda(
428
+ regions: &Bound<'_, PyAny>,
429
+ sdr_cais: &Bound<'_, PyAny>,
430
+ cols_cais: &Bound<'_, PyAny>,
431
+ anom_cais: &Bound<'_, PyAny>,
432
+ learn: bool,
433
+ ) -> PyResult<()> {
434
+ let regions_list: Bound<'_, PyList> = regions
435
+ .clone()
436
+ .downcast_into()
437
+ .map_err(|_| pyo3::exceptions::PyTypeError::new_err("regions must be a list"))?;
438
+ let sdr_list: Bound<'_, PyList> = sdr_cais
439
+ .clone()
440
+ .downcast_into()
441
+ .map_err(|_| pyo3::exceptions::PyTypeError::new_err("sdr_cais must be a list"))?;
442
+ let cols_list: Bound<'_, PyList> = cols_cais
443
+ .clone()
444
+ .downcast_into()
445
+ .map_err(|_| pyo3::exceptions::PyTypeError::new_err("cols_cais must be a list"))?;
446
+ let anom_list: Bound<'_, PyList> = anom_cais
447
+ .clone()
448
+ .downcast_into()
449
+ .map_err(|_| pyo3::exceptions::PyTypeError::new_err("anom_cais must be a list"))?;
450
+
451
+ let b = regions_list.len();
452
+ if b == 0 {
453
+ return Err(pyo3::exceptions::PyValueError::new_err("need at least one region"));
454
+ }
455
+ if sdr_list.len() != b || cols_list.len() != b || anom_list.len() != b {
456
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
457
+ "list length mismatch: regions={} sdr={} cols={} anom={}",
458
+ b,
459
+ sdr_list.len(),
460
+ cols_list.len(),
461
+ anom_list.len()
462
+ )));
463
+ }
464
+
465
+ let mut region_refs: Vec<PyRefMut<'_, HTMRegionGpu>> = Vec::with_capacity(b);
466
+ let mut region_ptrs: Vec<*mut HTMRegionGpu> = Vec::with_capacity(b);
467
+ let mut inputs_per_region: Vec<u64> = Vec::with_capacity(b);
468
+ let mut cols_per_region: Vec<u64> = Vec::with_capacity(b);
469
+ let mut anom_per_region: Vec<u64> = Vec::with_capacity(b);
470
+ let mut shared_t: Option<usize> = None;
471
+ let mut shared_input_bits: Option<usize> = None;
472
+ let mut shared_n_columns: Option<usize> = None;
473
+
474
+ for i in 0..b {
475
+ let mut region_ref: PyRefMut<'_, HTMRegionGpu> = regions_list.get_item(i)?.extract()?;
476
+ let region_t_bits = region_ref.input_bits;
477
+ let region_cols = region_ref.n_columns;
478
+ let region_ptr: *mut HTMRegionGpu = &mut *region_ref;
479
+
480
+ let sdr_dict: Bound<'_, PyDict> = sdr_list
481
+ .get_item(i)?
482
+ .downcast_into()
483
+ .map_err(|_| pyo3::exceptions::PyTypeError::new_err("sdr CAI entries must be dicts"))?;
484
+ let cols_dict: Bound<'_, PyDict> = cols_list
485
+ .get_item(i)?
486
+ .downcast_into()
487
+ .map_err(|_| pyo3::exceptions::PyTypeError::new_err("cols CAI entries must be dicts"))?;
488
+ let anom_dict: Bound<'_, PyDict> = anom_list
489
+ .get_item(i)?
490
+ .downcast_into()
491
+ .map_err(|_| pyo3::exceptions::PyTypeError::new_err("anom CAI entries must be dicts"))?;
492
+
493
+ let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(&sdr_dict)?;
494
+ let (cols_ptr, cols_shape, cols_type) = cai_parse(&cols_dict)?;
495
+ let (anom_ptr, anom_shape, anom_type) = cai_parse(&anom_dict)?;
496
+ if sdr_type != "|u1" {
497
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
498
+ "sdr_cai[{i}] typestr must be '|u1' (uint8), got {sdr_type}",
499
+ )));
500
+ }
501
+ if cols_type != "|u1" {
502
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
503
+ "cols_cai[{i}] typestr must be '|u1' (uint8), got {cols_type}",
504
+ )));
505
+ }
506
+ if anom_type != "<f4" && anom_type != "=f4" {
507
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
508
+ "anom_cai[{i}] typestr must be '<f4' (float32), got {anom_type}",
509
+ )));
510
+ }
511
+ if sdr_shape.len() != 2 || sdr_shape[1] != region_t_bits {
512
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
513
+ "sdr_cai[{i}] shape {sdr_shape:?} != (T, {region_t_bits})",
514
+ )));
515
+ }
516
+ let this_t = sdr_shape[0];
517
+ if cols_shape != [this_t, region_cols] {
518
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
519
+ "cols_cai[{i}] shape {cols_shape:?} != ({this_t}, {region_cols})",
520
+ )));
521
+ }
522
+ if anom_shape != [this_t] {
523
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
524
+ "anom_cai[{i}] shape {anom_shape:?} != ({this_t},)",
525
+ )));
526
+ }
527
+ if shared_t.replace(this_t).is_some_and(|prev| prev != this_t)
528
+ || shared_input_bits.replace(region_t_bits).is_some_and(|prev| prev != region_t_bits)
529
+ || shared_n_columns.replace(region_cols).is_some_and(|prev| prev != region_cols)
530
+ {
531
+ return Err(pyo3::exceptions::PyValueError::new_err(
532
+ "all batched HTM regions must share T/input_bits/n_columns",
533
+ ));
534
+ }
535
+
536
+ region_refs.push(region_ref);
537
+ region_ptrs.push(region_ptr);
538
+ inputs_per_region.push(sdr_ptr);
539
+ cols_per_region.push(cols_ptr);
540
+ anom_per_region.push(anom_ptr);
541
+ }
542
+
543
+ fused::launch_fused_batched_raw(
544
+ &region_ptrs,
545
+ &inputs_per_region,
546
+ &cols_per_region,
547
+ &anom_per_region,
548
+ shared_t.unwrap(),
549
+ shared_input_bits.unwrap(),
550
+ learn,
551
+ )
552
+ .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("step_batch_fused_cuda: {e:?}")))?;
553
+
554
+ drop(region_refs);
555
+ Ok(())
556
+ }
557
+
558
  pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> {
559
  m.add_class::<HTMRegionGpu>()?;
560
+ m.add_function(wrap_pyfunction!(step_batch_fused_cuda, m)?)?;
561
  Ok(())
562
  }
overlay/htm_rust/src/lib.rs CHANGED
@@ -34,6 +34,7 @@ use numpy::{
34
  PyUntypedArrayMethods,
35
  };
36
  use pyo3::prelude::*;
 
37
 
38
  use crate::region::HTMRegionCore;
39
 
@@ -135,6 +136,32 @@ impl HTMRegion {
135
  /// Clear TM predictive state. Does NOT unlearn synapses.
136
  fn reset(&mut self) { self.core.reset(); }
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  /// Process T timesteps from a `(T, input_bits)` bool ndarray.
139
  ///
140
  /// Returns:
 
34
  PyUntypedArrayMethods,
35
  };
36
  use pyo3::prelude::*;
37
+ use pyo3::types::PyBytes;
38
 
39
  use crate::region::HTMRegionCore;
40
 
 
136
  /// Clear TM predictive state. Does NOT unlearn synapses.
137
  fn reset(&mut self) { self.core.reset(); }
138
 
139
+ /// Serialize the full SP+TM state to bytes.
140
+ fn save_state<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
141
+ let bytes = bincode::serialize(&self.core).map_err(|e| {
142
+ pyo3::exceptions::PyRuntimeError::new_err(format!("serialize HTM state: {e}"))
143
+ })?;
144
+ Ok(PyBytes::new_bound(py, &bytes))
145
+ }
146
+
147
+ /// Restore a state blob created by save_state().
148
+ fn load_state(&mut self, blob: &[u8]) -> PyResult<()> {
149
+ let core: HTMRegionCore = bincode::deserialize(blob).map_err(|e| {
150
+ pyo3::exceptions::PyValueError::new_err(format!("deserialize HTM state: {e}"))
151
+ })?;
152
+ if core.sp.cfg.input_bits != self.core.sp.cfg.input_bits
153
+ || core.sp.cfg.n_columns != self.core.sp.cfg.n_columns
154
+ || core.tm.cfg.n_columns != self.core.tm.cfg.n_columns
155
+ || core.tm.cfg.cells_per_column != self.core.tm.cfg.cells_per_column
156
+ {
157
+ return Err(pyo3::exceptions::PyValueError::new_err(
158
+ "HTM state shape does not match this region",
159
+ ));
160
+ }
161
+ self.core = core;
162
+ Ok(())
163
+ }
164
+
165
  /// Process T timesteps from a `(T, input_bits)` bool ndarray.
166
  ///
167
  /// Returns:
overlay/htm_rust/src/region.rs CHANGED
@@ -2,7 +2,9 @@
2
 
3
  use crate::sp::{SpatialPooler, SpatialPoolerConfig};
4
  use crate::tm::{TemporalMemory, TemporalMemoryConfig};
 
5
 
 
6
  pub struct HTMRegionCore {
7
  pub sp: SpatialPooler,
8
  pub tm: TemporalMemory,
 
2
 
3
  use crate::sp::{SpatialPooler, SpatialPoolerConfig};
4
  use crate::tm::{TemporalMemory, TemporalMemoryConfig};
5
+ use serde::{Deserialize, Serialize};
6
 
7
+ #[derive(Serialize, Deserialize)]
8
  pub struct HTMRegionCore {
9
  pub sp: SpatialPooler,
10
  pub tm: TemporalMemory,
overlay/htm_rust/src/sp.rs CHANGED
@@ -15,10 +15,11 @@ use rand::Rng;
15
  use rand::SeedableRng;
16
  use rand::seq::SliceRandom;
17
  use rand_xoshiro::Xoshiro256PlusPlus;
 
18
 
19
  /// A single proximal dendrite: a sparse set of potential synapses onto
20
  /// specific input bit indices, with per-synapse permanence values.
21
- #[derive(Clone)]
22
  pub struct ProximalDendrite {
23
  /// Indices into the input SDR. Length == potential_synapses.
24
  pub inputs: Vec<u32>,
@@ -26,6 +27,7 @@ pub struct ProximalDendrite {
26
  pub perms: Vec<f32>,
27
  }
28
 
 
29
  pub struct SpatialPoolerConfig {
30
  pub input_bits: usize,
31
  pub n_columns: usize,
@@ -64,6 +66,7 @@ impl Default for SpatialPoolerConfig {
64
  }
65
  }
66
 
 
67
  pub struct SpatialPooler {
68
  pub cfg: SpatialPoolerConfig,
69
  pub columns: Vec<ProximalDendrite>,
@@ -265,6 +268,7 @@ mod tests {
265
  use rand::Rng;
266
  use rand::SeedableRng;
267
  use rand_xoshiro::Xoshiro256PlusPlus;
 
268
 
269
  #[test]
270
  fn sp_sparsity_exact_2pct() {
 
15
  use rand::SeedableRng;
16
  use rand::seq::SliceRandom;
17
  use rand_xoshiro::Xoshiro256PlusPlus;
18
+ use serde::{Deserialize, Serialize};
19
 
20
  /// A single proximal dendrite: a sparse set of potential synapses onto
21
  /// specific input bit indices, with per-synapse permanence values.
22
+ #[derive(Clone, Serialize, Deserialize)]
23
  pub struct ProximalDendrite {
24
  /// Indices into the input SDR. Length == potential_synapses.
25
  pub inputs: Vec<u32>,
 
27
  pub perms: Vec<f32>,
28
  }
29
 
30
+ #[derive(Clone, Serialize, Deserialize)]
31
  pub struct SpatialPoolerConfig {
32
  pub input_bits: usize,
33
  pub n_columns: usize,
 
66
  }
67
  }
68
 
69
+ #[derive(Serialize, Deserialize)]
70
  pub struct SpatialPooler {
71
  pub cfg: SpatialPoolerConfig,
72
  pub columns: Vec<ProximalDendrite>,
 
268
  use rand::Rng;
269
  use rand::SeedableRng;
270
  use rand_xoshiro::Xoshiro256PlusPlus;
271
+ use serde::{Deserialize, Serialize};
272
 
273
  #[test]
274
  fn sp_sparsity_exact_2pct() {
overlay/htm_rust/src/tm.rs CHANGED
@@ -45,17 +45,18 @@
45
  use rand::Rng;
46
  use rand::SeedableRng;
47
  use rand_xoshiro::Xoshiro256PlusPlus;
 
48
 
49
  type CellIdx = u32;
50
  type SegmentIdx = u32;
51
 
52
- #[derive(Clone)]
53
  pub struct Synapse {
54
  pub presynaptic_cell: CellIdx,
55
  pub permanence: f32,
56
  }
57
 
58
- #[derive(Clone)]
59
  pub struct Segment {
60
  pub cell: CellIdx,
61
  pub synapses: Vec<Synapse>,
@@ -66,6 +67,7 @@ pub struct Segment {
66
  pub last_used_iteration: u64,
67
  }
68
 
 
69
  pub struct TemporalMemoryConfig {
70
  pub n_columns: usize,
71
  pub cells_per_column: usize,
@@ -100,6 +102,7 @@ impl Default for TemporalMemoryConfig {
100
  }
101
  }
102
 
 
103
  pub struct TemporalMemory {
104
  pub cfg: TemporalMemoryConfig,
105
  /// All segments in the region. Indexed by SegmentIdx.
@@ -485,6 +488,7 @@ mod tests {
485
  use rand::Rng;
486
  use rand::SeedableRng;
487
  use rand_xoshiro::Xoshiro256PlusPlus;
 
488
 
489
  #[test]
490
  fn tm_learns_repeating_sequence() {
 
45
  use rand::Rng;
46
  use rand::SeedableRng;
47
  use rand_xoshiro::Xoshiro256PlusPlus;
48
+ use serde::{Deserialize, Serialize};
49
 
50
  type CellIdx = u32;
51
  type SegmentIdx = u32;
52
 
53
+ #[derive(Clone, Serialize, Deserialize)]
54
  pub struct Synapse {
55
  pub presynaptic_cell: CellIdx,
56
  pub permanence: f32,
57
  }
58
 
59
+ #[derive(Clone, Serialize, Deserialize)]
60
  pub struct Segment {
61
  pub cell: CellIdx,
62
  pub synapses: Vec<Synapse>,
 
67
  pub last_used_iteration: u64,
68
  }
69
 
70
+ #[derive(Clone, Serialize, Deserialize)]
71
  pub struct TemporalMemoryConfig {
72
  pub n_columns: usize,
73
  pub cells_per_column: usize,
 
102
  }
103
  }
104
 
105
+ #[derive(Serialize, Deserialize)]
106
  pub struct TemporalMemory {
107
  pub cfg: TemporalMemoryConfig,
108
  /// All segments in the region. Indexed by SegmentIdx.
 
488
  use rand::Rng;
489
  use rand::SeedableRng;
490
  use rand_xoshiro::Xoshiro256PlusPlus;
491
+ use serde::{Deserialize, Serialize};
492
 
493
  #[test]
494
  fn tm_learns_repeating_sequence() {
overlay/htm_rust/uv.lock ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ version = 1
2
+ revision = 3
3
+ requires-python = ">=3.11"
4
+
5
+ [[package]]
6
+ name = "htm-rust"
7
+ version = "0.1.0"
8
+ source = { editable = "." }
overlay/hydra/model.py CHANGED
@@ -49,18 +49,51 @@ from subsystems.sdr_semantic import SemanticFoldingSDR
49
  from hydra.engram import GPUEngram
50
  from hydra.htm_cache import htm_cache_key, htm_cache_matches
51
  from hydra.hyena_block import HyenaBlock
 
52
  # GDNBlock is imported lazily inside __init__ so the `fla` dependency is
53
  # only required when HYDRA_GDN_LAYERS is actually non-empty. Baseline
54
  # pure-Mamba3 runs continue to work without flash-linear-attention installed.
55
  from hydra.optimizer import MuonAdamW
56
  from hydra.sampled_softmax import UnigramSampler, sampled_softmax_loss
57
 
 
 
 
 
 
58
 
59
  def norm(x: torch.Tensor) -> torch.Tensor:
60
  """RMSNorm over the last dim — stateless, autocast-friendly."""
61
  return F.rms_norm(x, (x.size(-1),))
62
 
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  class PostSemClawModel(nn.Module):
65
  """Full Post-SEM-Claw model assembly.
66
 
@@ -131,10 +164,7 @@ class PostSemClawModel(nn.Module):
131
  n_heads=config.n_heads,
132
  )
133
  if Mamba3 is None:
134
- raise RuntimeError(
135
- "mamba_ssm is required for Mamba3 layers; set hyena_layers/gdn_layers "
136
- "to cover every layer or run inside the HF runtime image."
137
- )
138
  block = Mamba3(
139
  d_model=config.d_model,
140
  d_state=config.d_state,
@@ -179,6 +209,22 @@ class PostSemClawModel(nn.Module):
179
  n_columns=config.engram_n_columns,
180
  max_ngram=3,
181
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  self.engram_layer_idx = config.engram_layer_idx
183
 
184
  # Manifold-Constrained Hyper-Connections (one per Mamba-3 block).
@@ -398,12 +444,28 @@ class PostSemClawModel(nn.Module):
398
 
399
  nn.init.normal_(self.htm_proj.weight, mean=0.0, std=s)
400
 
 
 
 
 
 
 
 
 
 
 
 
 
401
  # Cast to bf16 to match Mamba3 dtype; Muon groups by shape so mixed
402
  # dtypes in the same shape group would break lerp_ dtype checks.
403
  self.wte.to(dtype=torch.bfloat16)
404
  self.blocks.to(dtype=torch.bfloat16)
405
  self.htm_proj.to(dtype=torch.bfloat16)
406
  self.engram.to(dtype=torch.bfloat16)
 
 
 
 
407
 
408
  def set_bos_token_id(self, bos_id: int) -> None:
409
  """Inform the model of the tokenizer's BOS id so doc-separator
@@ -755,19 +817,25 @@ class PostSemClawModel(nn.Module):
755
  # HYDRA_HTM_SUBSAMPLE=N (default 8). Set =1 for every-microbatch HTM.
756
  _htm_sub = int(os.environ.get("HYDRA_HTM_SUBSAMPLE", "8"))
757
  if not hasattr(self, '_htm_call_idx'):
758
- self._htm_call_idx = 0
759
 
760
  _run_htm = (self._htm_call_idx % _htm_sub == 0)
761
  self._htm_call_idx += 1
762
 
763
  if _run_htm:
764
- htm_handle = self.htm.forward_async(sdr_binary)
765
  else:
766
  htm_handle = None
767
 
768
  if _profile: _t_htm_async = _ev()
769
 
770
  dense_emb = self.wte(idx) # (B, T, d_model) bf16
 
 
 
 
 
 
771
 
772
  if _profile: _t_wte = _ev()
773
 
@@ -804,10 +872,19 @@ class PostSemClawModel(nn.Module):
804
  and htm_cache_matches(self._htm_cache_key, sdr_binary.nonzero())
805
  ):
806
  htm_out = self._htm_cache
 
 
 
 
 
 
 
 
 
807
  else:
808
  # Very first call with subsample > 1, OR MDLM is on, OR the SDR
809
  # pattern has changed from the cached one under exact mode: run HTM.
810
- htm_handle = self.htm.forward_async(sdr_binary)
811
  htm_out = self.htm.forward_await(htm_handle)
812
  self._htm_cache = htm_out.detach()
813
  self._htm_cache_key = htm_cache_key(sdr_binary.nonzero())
@@ -880,7 +957,18 @@ class PostSemClawModel(nn.Module):
880
  # tensor of shape (n_streams, B, T, d_model) — see
881
  # subsystems/mhc_mini.ManifoldHyperConnection.
882
  x_mid = mhc_layer.merge_streams(streams)
883
- x_after_engram, hit_rate = self.engram(x_mid, idx)
 
 
 
 
 
 
 
 
 
 
 
884
  if os.environ.get("HYDRA_ENGRAM_RESET_STREAMS", "0") == "1":
885
  streams = mhc_layer.init_streams(x_after_engram)
886
  else:
 
49
  from hydra.engram import GPUEngram
50
  from hydra.htm_cache import htm_cache_key, htm_cache_matches
51
  from hydra.hyena_block import HyenaBlock
52
+ from hydra.reality_bridge import RealityPoincareBridge
53
  # GDNBlock is imported lazily inside __init__ so the `fla` dependency is
54
  # only required when HYDRA_GDN_LAYERS is actually non-empty. Baseline
55
  # pure-Mamba3 runs continue to work without flash-linear-attention installed.
56
  from hydra.optimizer import MuonAdamW
57
  from hydra.sampled_softmax import UnigramSampler, sampled_softmax_loss
58
 
59
+ try:
60
+ from subsystems.cantor_router import CantorRouter
61
+ except ModuleNotFoundError:
62
+ from archive.cantor_router import CantorRouter
63
+
64
 
65
  def norm(x: torch.Tensor) -> torch.Tensor:
66
  """RMSNorm over the last dim — stateless, autocast-friendly."""
67
  return F.rms_norm(x, (x.size(-1),))
68
 
69
 
70
+ def paired_slow_fast_orthogonality(w: torch.Tensor) -> torch.Tensor:
71
+ """Penalty for aligned adjacent slow/fast vector pairs."""
72
+ n = (w.shape[0] // 2) * 2
73
+ if n == 0:
74
+ return w.new_zeros(())
75
+ slow = F.normalize(w[:n:2].float(), dim=-1, eps=1e-8)
76
+ fast = F.normalize(w[1:n:2].float(), dim=-1, eps=1e-8)
77
+ return (slow * fast).sum(dim=-1).square().mean().to(dtype=w.dtype)
78
+
79
+
80
+ def semantic_gaussian_mollify(
81
+ x: torch.Tensor,
82
+ std: float = 0.0,
83
+ training: bool = True,
84
+ eval_enabled: bool = False,
85
+ ) -> torch.Tensor:
86
+ """Optionally add train-time semantic Gaussian noise; disabled is identity."""
87
+ if std <= 0.0 or (not training and not eval_enabled):
88
+ return x
89
+ return x + torch.randn_like(x) * float(std)
90
+
91
+
92
+ class _LocalMamba3Fallback(nn.Identity):
93
+ """Shape-preserving local fallback used only when mamba_ssm is absent."""
94
+ pass
95
+
96
+
97
  class PostSemClawModel(nn.Module):
98
  """Full Post-SEM-Claw model assembly.
99
 
 
164
  n_heads=config.n_heads,
165
  )
166
  if Mamba3 is None:
167
+ return _LocalMamba3Fallback()
 
 
 
168
  block = Mamba3(
169
  d_model=config.d_model,
170
  d_state=config.d_state,
 
209
  n_columns=config.engram_n_columns,
210
  max_ngram=3,
211
  )
212
+ self.reality_bridge = None
213
+ self.cantor = None
214
+ if os.environ.get("HYDRA_REALITY_BRIDGE", "0") == "1":
215
+ d_reality = int(os.environ.get("HYDRA_REALITY_D", "133"))
216
+ self.reality_bridge = RealityPoincareBridge(
217
+ d_model=config.d_model,
218
+ d_reality=d_reality,
219
+ l0_k=int(os.environ.get("HYDRA_REALITY_L0_K", "64")),
220
+ )
221
+ if os.environ.get("HYDRA_CANTOR_DISABLE", "0") != "1":
222
+ self.cantor = CantorRouter(
223
+ depth=int(os.environ.get("HYDRA_CANTOR_DEPTH", "7")),
224
+ d_query=d_reality,
225
+ seed=int(os.environ.get("HYDRA_CANTOR_SEED", "42")),
226
+ device=self.wte.weight.device,
227
+ )
228
  self.engram_layer_idx = config.engram_layer_idx
229
 
230
  # Manifold-Constrained Hyper-Connections (one per Mamba-3 block).
 
444
 
445
  nn.init.normal_(self.htm_proj.weight, mean=0.0, std=s)
446
 
447
+ if hasattr(self.engram, "memory"):
448
+ nn.init.normal_(self.engram.memory, mean=0.0, std=0.01)
449
+ if hasattr(self.engram, "gate"):
450
+ nn.init.zeros_(self.engram.gate.weight)
451
+ nn.init.zeros_(self.engram.gate.bias)
452
+ if self.reality_bridge is not None:
453
+ nn.init.normal_(self.reality_bridge.to_reality.weight, mean=0.0, std=0.02)
454
+ nn.init.normal_(self.reality_bridge.to_tangent2.weight, mean=0.0, std=0.02)
455
+ if self.cantor is not None and hasattr(self.cantor, "branch"):
456
+ bound = (3.0 / float(self.cantor.d_query)) ** 0.5
457
+ nn.init.uniform_(self.cantor.branch, -bound, bound)
458
+
459
  # Cast to bf16 to match Mamba3 dtype; Muon groups by shape so mixed
460
  # dtypes in the same shape group would break lerp_ dtype checks.
461
  self.wte.to(dtype=torch.bfloat16)
462
  self.blocks.to(dtype=torch.bfloat16)
463
  self.htm_proj.to(dtype=torch.bfloat16)
464
  self.engram.to(dtype=torch.bfloat16)
465
+ if self.reality_bridge is not None:
466
+ self.reality_bridge.to(dtype=torch.bfloat16)
467
+ if self.cantor is not None:
468
+ self.cantor.to(dtype=torch.bfloat16)
469
 
470
  def set_bos_token_id(self, bos_id: int) -> None:
471
  """Inform the model of the tokenizer's BOS id so doc-separator
 
817
  # HYDRA_HTM_SUBSAMPLE=N (default 8). Set =1 for every-microbatch HTM.
818
  _htm_sub = int(os.environ.get("HYDRA_HTM_SUBSAMPLE", "8"))
819
  if not hasattr(self, '_htm_call_idx'):
820
+ self._htm_call_idx = int(os.environ.get("HYDRA_HTM_INITIAL_OFFSET", "0"))
821
 
822
  _run_htm = (self._htm_call_idx % _htm_sub == 0)
823
  self._htm_call_idx += 1
824
 
825
  if _run_htm:
826
+ htm_handle = self.htm.forward_async(sdr_binary, output_dtype=self.wte.weight.dtype)
827
  else:
828
  htm_handle = None
829
 
830
  if _profile: _t_htm_async = _ev()
831
 
832
  dense_emb = self.wte(idx) # (B, T, d_model) bf16
833
+ dense_emb = semantic_gaussian_mollify(
834
+ dense_emb,
835
+ std=float(os.environ.get("HYDRA_SEMANTIC_SMOOTH_STD", "0.0")),
836
+ training=self.training,
837
+ eval_enabled=os.environ.get("HYDRA_SEMANTIC_SMOOTH_EVAL", "0") == "1",
838
+ )
839
 
840
  if _profile: _t_wte = _ev()
841
 
 
872
  and htm_cache_matches(self._htm_cache_key, sdr_binary.nonzero())
873
  ):
874
  htm_out = self._htm_cache
875
+ elif (
876
+ os.environ.get("HYDRA_HTM_ZERO_CACHE_ON_MISS", "0") == "1"
877
+ and self.training
878
+ and not self._mdlm_active
879
+ ):
880
+ htm_out = torch.zeros((B, T, self.config.htm_n_columns + 1), device=dense_emb.device, dtype=dense_emb.dtype)
881
+ self._htm_cache = htm_out.detach()
882
+ self._htm_cache_key = None
883
+ self._htm_cache_shape = (B, T)
884
  else:
885
  # Very first call with subsample > 1, OR MDLM is on, OR the SDR
886
  # pattern has changed from the cached one under exact mode: run HTM.
887
+ htm_handle = self.htm.forward_async(sdr_binary, output_dtype=self.wte.weight.dtype)
888
  htm_out = self.htm.forward_await(htm_handle)
889
  self._htm_cache = htm_out.detach()
890
  self._htm_cache_key = htm_cache_key(sdr_binary.nonzero())
 
957
  # tensor of shape (n_streams, B, T, d_model) — see
958
  # subsystems/mhc_mini.ManifoldHyperConnection.
959
  x_mid = mhc_layer.merge_streams(streams)
960
+ if self.reality_bridge is not None and self.cantor is not None:
961
+ rb = self.reality_bridge(x_mid)
962
+ cantor_leaf_ids, _ = self.cantor(rb.reality, return_scores=False)
963
+ x_after_engram, hit_rate = self.engram(
964
+ x_mid,
965
+ idx,
966
+ sdr_active_indices=rb.l0_indices,
967
+ cantor_leaf_ids=cantor_leaf_ids,
968
+ cantor_n_leaves=self.cantor.n_leaves,
969
+ )
970
+ else:
971
+ x_after_engram, hit_rate = self.engram(x_mid, idx)
972
  if os.environ.get("HYDRA_ENGRAM_RESET_STREAMS", "0") == "1":
973
  streams = mhc_layer.init_streams(x_after_engram)
974
  else:
overlay/hydra/optimizer.py CHANGED
@@ -144,62 +144,117 @@ class MuonAdamW(torch.optim.Optimizer):
144
  self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
145
  self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
146
  self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
  def _step_adamw(self, group):
149
- params, grads, exp_avgs, exp_avg_sqs, state_steps = [], [], [], [], []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  for p in group['params']:
151
  if p.grad is None:
152
  continue
153
- state = self.state[p]
154
- if not state:
155
- state['step'] = 0
156
- state['exp_avg'] = torch.zeros_like(p)
157
- state['exp_avg_sq'] = torch.zeros_like(p)
158
- if 'step_t' not in state:
159
- # _fused_adamw_ wants a per-param float step tensor on-device.
160
- state['step_t'] = torch.tensor(
161
- float(state['step']), dtype=torch.float32, device=p.device
162
- )
163
  state['step'] += 1
 
 
164
  params.append(p)
165
  grads.append(p.grad.to(p.dtype) if p.grad.dtype != p.dtype else p.grad)
166
  exp_avgs.append(state['exp_avg'])
167
  exp_avg_sqs.append(state['exp_avg_sq'])
168
- state_steps.append(state['step_t'])
169
 
170
  if not params:
171
  return
172
 
173
- if _HYDRA_FUSED_ADAMW and _HAS_FUSED_ADAMW and params[0].is_cuda:
174
- # _fused_adamw_ needs uniform (device, dtype) within a call, so
175
- # group by (device, dtype) — same pattern as PyTorch's own
176
- # AdamW(fused=True) path (_group_tensors_by_device_and_dtype).
177
- buckets = {}
178
- for p, g, ea, es, st in zip(params, grads, exp_avgs, exp_avg_sqs, state_steps):
179
- key = (p.device, p.dtype)
180
- buckets.setdefault(key, ([], [], [], [], []))
181
- b_p, b_g, b_ea, b_es, b_st = buckets[key]
182
- b_p.append(p); b_g.append(g); b_ea.append(ea); b_es.append(es); b_st.append(st)
183
-
184
- lr_f = float(group['lr'])
185
- b1_f = float(group['betas'][0])
186
- b2_f = float(group['betas'][1])
187
- wd_f = float(group['weight_decay'])
188
- eps_f = float(group['eps'])
189
- for (_dev, _dt), (b_p, b_g, b_ea, b_es, b_st) in buckets.items():
190
- torch._foreach_add_(b_st, 1.0)
191
- torch._fused_adamw_(
192
- b_p, b_g, b_ea, b_es,
193
- [], # max_exp_avg_sqs unused (amsgrad=False)
194
- b_st,
195
- amsgrad=False,
196
- lr=lr_f, beta1=b1_f, beta2=b2_f,
197
- weight_decay=wd_f, eps=eps_f,
198
- maximize=False,
199
- grad_scale=None, found_inf=None,
200
- )
201
- return
202
-
203
  # Fallback per-param path.
204
  self._adamw_lr_t.fill_(group['lr'])
205
  self._adamw_beta1_t.fill_(group['betas'][0])
@@ -213,15 +268,34 @@ class MuonAdamW(torch.optim.Optimizer):
213
  self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t)
214
 
215
  def _step_muon(self, group):
216
- params = [p for p in group['params'] if p.grad is not None]
 
 
 
 
 
 
 
 
 
 
 
217
  if not params:
218
  return
219
  p = params[0]
220
  state = self.state[p]
221
  num_params = len(params)
222
  shape, device, dtype = p.shape, p.device, p.dtype
223
- if "momentum_buffer" not in state:
 
 
 
 
 
 
 
224
  state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
 
225
  red_dim = -1 if shape[-2] >= shape[-1] else -2
226
  if "second_momentum_buffer" not in state:
227
  # Shape must match v_mean = stacked_grads.square().mean(dim=red_dim, keepdim=True)
 
144
  self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
145
  self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
146
  self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
147
+ self._adamw_bucket_caches = {}
148
+ self._muon_params_caches = {}
149
+
150
+ def state_dict(self):
151
+ sd = super().state_dict()
152
+ # Transient fused-step caches and device step_t tensors must not enter
153
+ # checkpoints. step_t is recreated from scalar state['step'] lazily.
154
+ for st in sd.get("state", {}).values():
155
+ st.pop("step_t", None)
156
+ for group in sd.get("param_groups", []):
157
+ group.pop("_adamw_bucket_cache", None)
158
+ group.pop("_muon_params_cache", None)
159
+ return sd
160
+
161
+ def load_state_dict(self, state_dict):
162
+ for st in state_dict.get("state", {}).values():
163
+ st.pop("step_t", None)
164
+ for group in state_dict.get("param_groups", []):
165
+ group.pop("_adamw_bucket_cache", None)
166
+ group.pop("_muon_params_cache", None)
167
+ self._adamw_bucket_caches.clear()
168
+ self._muon_params_caches.clear()
169
+ return super().load_state_dict(state_dict)
170
+
171
+ def _ensure_adamw_state(self, p):
172
+ state = self.state[p]
173
+ if not state:
174
+ state['step'] = 0
175
+ state['exp_avg'] = torch.zeros_like(p)
176
+ state['exp_avg_sq'] = torch.zeros_like(p)
177
+ if 'step_t' not in state:
178
+ # _fused_adamw_ wants a per-param float step tensor on-device.
179
+ state['step_t'] = torch.tensor(
180
+ float(state['step']), dtype=torch.float32, device=p.device
181
+ )
182
+ return state
183
+
184
+ def _adamw_cached_buckets(self, group):
185
+ """Return stable (device,dtype) param buckets for fused AdamW.
186
+
187
+ Cache topology only. Optimizer state remains lazy for grad-bearing
188
+ params so unused/frozen tensors do not bloat checkpoints.
189
+ """
190
+ params_tuple = tuple(group['params'])
191
+ cache = self._adamw_bucket_caches.get(id(group))
192
+ if cache is not None and cache.get('params_tuple') == params_tuple:
193
+ return cache['buckets']
194
+
195
+ buckets = {}
196
+ for p in params_tuple:
197
+ key = (p.device, p.dtype)
198
+ buckets.setdefault(key, {'params': []})
199
+ buckets[key]['params'].append(p)
200
+ self._adamw_bucket_caches[id(group)] = {'params_tuple': params_tuple, 'buckets': buckets}
201
+ return buckets
202
 
203
  def _step_adamw(self, group):
204
+ if _HYDRA_FUSED_ADAMW and _HAS_FUSED_ADAMW:
205
+ # Mixed CPU/CUDA groups are unusual in Feather but skipping CPU
206
+ # grads would be a correctness bug; disable fused path in that case.
207
+ if not any(p.grad is not None and not p.is_cuda for p in group['params']):
208
+ buckets = self._adamw_cached_buckets(group)
209
+ lr_f = float(group['lr'])
210
+ b1_f = float(group['betas'][0])
211
+ b2_f = float(group['betas'][1])
212
+ wd_f = float(group['weight_decay'])
213
+ eps_f = float(group['eps'])
214
+ launched = False
215
+ for (_dev, _dt), bucket in buckets.items():
216
+ b_p = [p for p in bucket['params'] if p.grad is not None]
217
+ if not b_p or not b_p[0].is_cuda:
218
+ continue
219
+ b_g = [p.grad.to(p.dtype) if p.grad.dtype != p.dtype else p.grad for p in b_p]
220
+ b_ea, b_es, b_st = [], [], []
221
+ for p in b_p:
222
+ state = self._ensure_adamw_state(p)
223
+ state['step'] += 1
224
+ b_ea.append(state['exp_avg'])
225
+ b_es.append(state['exp_avg_sq'])
226
+ b_st.append(state['step_t'])
227
+ torch._foreach_add_(b_st, 1.0)
228
+ torch._fused_adamw_(
229
+ b_p, b_g, b_ea, b_es,
230
+ [], # max_exp_avg_sqs unused (amsgrad=False)
231
+ b_st,
232
+ amsgrad=False,
233
+ lr=lr_f, beta1=b1_f, beta2=b2_f,
234
+ weight_decay=wd_f, eps=eps_f,
235
+ maximize=False,
236
+ grad_scale=None, found_inf=None,
237
+ )
238
+ launched = True
239
+ if launched:
240
+ return
241
+
242
+ params, grads, exp_avgs, exp_avg_sqs = [], [], [], []
243
  for p in group['params']:
244
  if p.grad is None:
245
  continue
246
+ state = self._ensure_adamw_state(p)
 
 
 
 
 
 
 
 
 
247
  state['step'] += 1
248
+ if 'step_t' in state:
249
+ state['step_t'].fill_(float(state['step']))
250
  params.append(p)
251
  grads.append(p.grad.to(p.dtype) if p.grad.dtype != p.dtype else p.grad)
252
  exp_avgs.append(state['exp_avg'])
253
  exp_avg_sqs.append(state['exp_avg_sq'])
 
254
 
255
  if not params:
256
  return
257
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
  # Fallback per-param path.
259
  self._adamw_lr_t.fill_(group['lr'])
260
  self._adamw_beta1_t.fill_(group['betas'][0])
 
268
  self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t)
269
 
270
  def _step_muon(self, group):
271
+ params_tuple = tuple(group['params'])
272
+ cache = self._muon_params_caches.get(id(group))
273
+ if cache is None or cache.get('params_tuple') != params_tuple:
274
+ cache = {'params_tuple': params_tuple, 'params': list(params_tuple)}
275
+ self._muon_params_caches[id(group)] = cache
276
+ params_all = cache['params']
277
+ # Common Feather path: all Muon matrix params receive grads every step.
278
+ # Preserve sparse/None-grad correctness by filtering only when needed.
279
+ if all(p.grad is not None for p in params_all):
280
+ params = params_all
281
+ else:
282
+ params = [p for p in params_all if p.grad is not None]
283
  if not params:
284
  return
285
  p = params[0]
286
  state = self.state[p]
287
  num_params = len(params)
288
  shape, device, dtype = p.shape, p.device, p.dtype
289
+ if (
290
+ "momentum_buffer" not in state
291
+ or state["momentum_buffer"].shape[0] != num_params
292
+ or tuple(state["momentum_buffer"].shape[1:]) != tuple(shape)
293
+ ):
294
+ # If grad-bearing Muon params change (rare; usually all matrix params
295
+ # have grads), resize instead of crashing compiled Muon on a stale
296
+ # leading dimension. This preserves skip-None-grad semantics.
297
  state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
298
+ state.pop("second_momentum_buffer", None)
299
  red_dim = -1 if shape[-2] >= shape[-1] else -2
300
  if "second_momentum_buffer" not in state:
301
  # Shape must match v_mean = stacked_grads.square().mean(dim=red_dim, keepdim=True)
overlay/hydra/training.py CHANGED
@@ -9,7 +9,7 @@ import os
9
  import sys
10
  import threading
11
  import time
12
- from dataclasses import asdict
13
  from pathlib import Path
14
 
15
  import torch
@@ -103,6 +103,22 @@ _CONTRASTIVE_CTX_LEN = int(os.environ.get("HYDRA_CONTRASTIVE_CTX_LEN", "8"))
103
  _CONTRASTIVE_N_PAIRS = int(os.environ.get("HYDRA_CONTRASTIVE_N_PAIRS", "256"))
104
 
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  # ---------------------------------------------------------------------------
107
  # Schedules
108
  # ---------------------------------------------------------------------------
@@ -136,6 +152,7 @@ def save_ckpt(
136
  *,
137
  val_bpb: float | None = None,
138
  ) -> None:
 
139
  try:
140
  CACHE_DIR.mkdir(parents=True, exist_ok=True)
141
  payload = {
@@ -289,7 +306,22 @@ def maybe_resume_ckpt(
289
  def main() -> None:
290
  t_start = time.time()
291
  torch.manual_seed(SEED)
292
- torch.cuda.manual_seed(SEED)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
  # Precision / kernel-selection knobs for peak throughput on Ampere.
294
  # - high : matmul uses TF32 (Ampere's 10-bit mantissa accum) for fp32 ops
295
  # - allow_tf32 : explicit for both matmul + cudnn paths
@@ -299,12 +331,6 @@ def main() -> None:
299
  # over the first ~100 steps. Observed 2026-04-22 and confirmed by
300
  # differential profiling. Default is now FALSE; set =1 only if you
301
  # see a specific workload where benchmark helps sustained tps.
302
- torch.set_float32_matmul_precision("high")
303
- torch.backends.cuda.matmul.allow_tf32 = True
304
- torch.backends.cudnn.allow_tf32 = True
305
- torch.backends.cudnn.benchmark = os.environ.get("HYDRA_CUDNN_BENCHMARK", "0") == "1"
306
- device = torch.device("cuda")
307
- autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
308
 
309
  # Streaming path skips prepare.py (which normally trains the tokenizer
310
  # and builds the retina), so we must materialize both before model init.
@@ -435,7 +461,7 @@ def main() -> None:
435
  )
436
  _train_phase("dataloader_prefetch_start")
437
  train_loader = make_dataloader(tokenizer, DEVICE_BATCH_SIZE, _current_seq_len, "train")
438
- if step > 0 and os.environ.get("HYDRA_RESUME_SKIP_DATALOADER", "1") == "1":
439
  _skip_micro_batches = step * grad_accum_steps
440
  print(f"[resume] fast-forwarding train stream micro_batches={_skip_micro_batches} step={step} grad_accum={grad_accum_steps}", flush=True)
441
  for _skip_i in range(_skip_micro_batches):
@@ -469,13 +495,11 @@ def main() -> None:
469
  _ASYNC_POSTPROCESS = os.environ.get("HYDRA_ASYNC_POSTPROCESS", "1") == "1"
470
  _som_thread: threading.Thread | None = None
471
  _hestia_thread: threading.Thread | None = None
472
- _hestia_stream: torch.cuda.Stream | None = (
473
- torch.cuda.Stream() if _ASYNC_POSTPROCESS else None
474
- )
475
 
476
  # Hebbian retina mode — per-step on-GPU update, mutually exclusive with SOM.
477
  # Activated by env HYDRA_HEBBIAN_RETINA=1 (default off).
478
- _HEBBIAN_RETINA = os.environ.get("HYDRA_HEBBIAN_RETINA", "0") == "1"
479
  _HEBBIAN_ALPHA = float(os.environ.get("HYDRA_HEBBIAN_ALPHA", "0.001"))
480
  _prof = os.environ.get("HYDRA_PROFILE_FORWARD", "0") == "1"
481
  if _HEBBIAN_RETINA:
@@ -514,6 +538,32 @@ def main() -> None:
514
  # default cadence) instead of every step.
515
  nan_flag = torch.zeros((), device=device, dtype=torch.bool)
516
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
517
  _first_step_marker_emitted = False
518
  while True:
519
  if not _first_step_marker_emitted:
@@ -608,18 +658,9 @@ def main() -> None:
608
 
609
  # A10G Hyena fallback can produce finite forward loss but non-finite
610
  # gradients through the guarded residual path on the next optimizer
611
- # step. Scrub non-finite grad entries before clipping/stepping so one
612
- # bad native-kernel backward value cannot poison the entire parameter
613
- # state and create step=1 train_loss=nan.
614
- # Fast GPU-native grad guard
615
- if os.environ.get("HYDRA_GRAD_FINITE_GUARD", "1") == "1":
616
- with torch.no_grad():
617
- for p in model.parameters():
618
- if p.grad is not None:
619
- p.grad.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
620
-
621
- torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
622
- optimizer.step()
623
  if _prof:
624
  torch.cuda.synchronize(); _t_opt = time.time()
625
 
 
9
  import sys
10
  import threading
11
  import time
12
+ from dataclasses import asdict, fields
13
  from pathlib import Path
14
 
15
  import torch
 
103
  _CONTRASTIVE_N_PAIRS = int(os.environ.get("HYDRA_CONTRASTIVE_N_PAIRS", "256"))
104
 
105
 
106
+ def config_from_dict(payload: dict) -> PostSemClawConfig:
107
+ """Rebuild PostSemClawConfig from a checkpoint payload dict.
108
+
109
+ Checkpoints can contain older configs without newer dataclass fields, or
110
+ future configs with unknown fields. Keep loading permissive, but normalize
111
+ tuple-backed topology fields so Hyena/GDN layer selections survive JSON or
112
+ pickle paths that turn tuples into lists.
113
+ """
114
+ field_names = {field.name for field in fields(PostSemClawConfig)}
115
+ kwargs = {key: value for key, value in payload.items() if key in field_names}
116
+ for tuple_key in ("hyena_layers", "gdn_layers"):
117
+ if tuple_key in kwargs and kwargs[tuple_key] is not None:
118
+ kwargs[tuple_key] = tuple(kwargs[tuple_key])
119
+ return PostSemClawConfig(**kwargs)
120
+
121
+
122
  # ---------------------------------------------------------------------------
123
  # Schedules
124
  # ---------------------------------------------------------------------------
 
152
  *,
153
  val_bpb: float | None = None,
154
  ) -> None:
155
+ global _CKPT_WORKER_THREAD
156
  try:
157
  CACHE_DIR.mkdir(parents=True, exist_ok=True)
158
  payload = {
 
306
  def main() -> None:
307
  t_start = time.time()
308
  torch.manual_seed(SEED)
309
+ device_str = "cuda" if torch.cuda.is_available() else "cpu"
310
+ device = torch.device(device_str)
311
+ if device_str == "cuda":
312
+ torch.cuda.manual_seed(SEED)
313
+ torch.set_float32_matmul_precision("high")
314
+ torch.backends.cuda.matmul.allow_tf32 = True
315
+ torch.backends.cudnn.allow_tf32 = True
316
+ torch.backends.cudnn.benchmark = os.environ.get("HYDRA_CUDNN_BENCHMARK", "0") == "1"
317
+ autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
318
+ else:
319
+ # CPU path: limit BLAS threads to avoid oversubscription with data workers.
320
+ _cpu_threads = int(os.environ.get("HYDRA_CPU_THREADS", str(min(os.cpu_count() or 4, 8))))
321
+ torch.set_num_threads(_cpu_threads)
322
+ print(f"[CPU] torch.set_num_threads={_cpu_threads}")
323
+ autocast_ctx = torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=False)
324
+
325
  # Precision / kernel-selection knobs for peak throughput on Ampere.
326
  # - high : matmul uses TF32 (Ampere's 10-bit mantissa accum) for fp32 ops
327
  # - allow_tf32 : explicit for both matmul + cudnn paths
 
331
  # over the first ~100 steps. Observed 2026-04-22 and confirmed by
332
  # differential profiling. Default is now FALSE; set =1 only if you
333
  # see a specific workload where benchmark helps sustained tps.
 
 
 
 
 
 
334
 
335
  # Streaming path skips prepare.py (which normally trains the tokenizer
336
  # and builds the retina), so we must materialize both before model init.
 
461
  )
462
  _train_phase("dataloader_prefetch_start")
463
  train_loader = make_dataloader(tokenizer, DEVICE_BATCH_SIZE, _current_seq_len, "train")
464
+ if step > 0 and os.environ.get("HYDRA_RESUME_SKIP_DATALOADER", "1") != "1":
465
  _skip_micro_batches = step * grad_accum_steps
466
  print(f"[resume] fast-forwarding train stream micro_batches={_skip_micro_batches} step={step} grad_accum={grad_accum_steps}", flush=True)
467
  for _skip_i in range(_skip_micro_batches):
 
495
  _ASYNC_POSTPROCESS = os.environ.get("HYDRA_ASYNC_POSTPROCESS", "1") == "1"
496
  _som_thread: threading.Thread | None = None
497
  _hestia_thread: threading.Thread | None = None
498
+ _hestia_stream = torch.cuda.Stream() if (_ASYNC_POSTPROCESS and device.type == "cuda") else None
 
 
499
 
500
  # Hebbian retina mode — per-step on-GPU update, mutually exclusive with SOM.
501
  # Activated by env HYDRA_HEBBIAN_RETINA=1 (default off).
502
+ _HEBBIAN_RETINA = device.type == "cuda" and os.environ.get("HYDRA_HEBBIAN_RETINA", "0") == "1"
503
  _HEBBIAN_ALPHA = float(os.environ.get("HYDRA_HEBBIAN_ALPHA", "0.001"))
504
  _prof = os.environ.get("HYDRA_PROFILE_FORWARD", "0") == "1"
505
  if _HEBBIAN_RETINA:
 
538
  # default cadence) instead of every step.
539
  nan_flag = torch.zeros((), device=device, dtype=torch.bool)
540
 
541
+ # Device-step fusion surface: cache the parameter walk once and keep the
542
+ # finite-grad guard + clipping + optimizer launch in one compact boundary.
543
+ # This avoids re-materializing model.parameters() twice per optimizer step
544
+ # and gives the A10G path a single toggleable fused-step block without
545
+ # pulling dataloader/checkpoint/logging CPU control flow into Dynamo.
546
+ _HYDRA_FUSED_DEVICE_STEP = os.environ.get("HYDRA_FUSED_DEVICE_STEP", "1") == "1"
547
+ _trainable_params = tuple(model.parameters())
548
+
549
+ def _finish_device_step():
550
+ if _HYDRA_FUSED_DEVICE_STEP:
551
+ if os.environ.get("HYDRA_GRAD_FINITE_GUARD", "1") == "1":
552
+ with torch.no_grad():
553
+ for _p in _trainable_params:
554
+ if _p.grad is not None:
555
+ _p.grad.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
556
+ torch.nn.utils.clip_grad_norm_(_trainable_params, max_norm=1.0)
557
+ optimizer.step()
558
+ return
559
+ if os.environ.get("HYDRA_GRAD_FINITE_GUARD", "1") == "1":
560
+ with torch.no_grad():
561
+ for _p in model.parameters():
562
+ if _p.grad is not None:
563
+ _p.grad.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
564
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
565
+ optimizer.step()
566
+
567
  _first_step_marker_emitted = False
568
  while True:
569
  if not _first_step_marker_emitted:
 
658
 
659
  # A10G Hyena fallback can produce finite forward loss but non-finite
660
  # gradients through the guarded residual path on the next optimizer
661
+ # step. The fused device-step boundary scrubs, clips, and launches the
662
+ # optimizer without re-walking model.parameters() on every substage.
663
+ _finish_device_step()
 
 
 
 
 
 
 
 
 
664
  if _prof:
665
  torch.cuda.synchronize(); _t_opt = time.time()
666
 
overlay/kernels/__init__.py ADDED
File without changes
overlay/kernels/cuda/decode_kernels.cu ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * CuTe DSL decode kernels for Mamba-3 autoregressive generation.
3
+ *
4
+ * Phase 2: Optimized single-token SSM step for inference.
5
+ * Phase 1: Not needed (training only, no generation).
6
+ *
7
+ * Fuses: input_proj + conv_step + ssm_step + output_proj
8
+ * into a single kernel launch for minimal latency.
9
+ */
10
+ // Stub: Phase 2 implementation
overlay/kernels/cuda/flashfftconv/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
overlay/kernels/cuda/flashfftconv/README.md ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flashfftconv (vendored)
2
+
3
+ Vendored from https://github.com/HazyResearch/flash-fft-conv (Apache 2.0 license).
4
+
5
+ **Upstream commit:** see `UPSTREAM_COMMIT`.
6
+
7
+ ## What this is
8
+
9
+ HazyResearch's Monarch-matrix-decomposition FFT convolution CUDA kernel. Provides a
10
+ drop-in replacement for `torch.fft.rfft + complex-mult + irfft` that runs ~2-3x
11
+ faster than cuFFT for the specific power-of-two lengths it supports (256, 512,
12
+ 1024, 2048, 4096, 8192, ..., up to 4M).
13
+
14
+ In HYDRA, we use it to accelerate `subsystems/hyena_pure.fftconv_ref`. The
15
+ accelerated path is opt-in via `HYDRA_HYENA_FLASH_FFT=1`; default behavior is
16
+ unchanged (pure PyTorch fallback).
17
+
18
+ ## How to build
19
+
20
+ The vendored tree contains:
21
+ - `flashfftconv/` — pure-Python wrappers (imports `monarch_cuda` CUDA extension)
22
+ - `csrc/` — CUDA source files and setup.py for the native extension
23
+
24
+ Build instructions:
25
+
26
+ ```bash
27
+ cd /home/mikeb/work/feather/kernels/cuda/flashfftconv/csrc
28
+
29
+ # Edit `csrc/setup.py` first: change the cc_flag line to match your GPU arch
30
+ # (RTX 3060 = 8.6, A100 = 8.0, H100 = 9.0). Example for RTX 3060:
31
+ # cc_flag = ['--generate-code=arch=compute_86,code=compute_86']
32
+
33
+ # Build with the local CUDA toolchain (must match your torch.version.cuda):
34
+ CUDA_HOME=/usr/local/cuda-12.1 .venv/bin/pip install -e .
35
+ ```
36
+
37
+ Then install the Python wrappers:
38
+
39
+ ```bash
40
+ cd /home/mikeb/work/feather/kernels/cuda/flashfftconv
41
+ .venv/bin/pip install -e .
42
+ ```
43
+
44
+ ## Runtime usage
45
+
46
+ Once installed, set `HYDRA_HYENA_FLASH_FFT=1` and training will use it.
47
+ `subsystems/hyena_pure.fftconv_ref` auto-detects via `try: import flashfftconv`
48
+ and falls back to pure PyTorch on import failure.
49
+
50
+ ## Known caveats
51
+
52
+ - Seqlen must be a power of 2 AND in the supported set: {256, 512, 1024, 2048,
53
+ 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304}.
54
+ For HYDRA, `fft_size = 2 * seq_len` → seq_len in {128, 256, 512, 1024, 2048, ...}.
55
+ - dtype must be fp16 or bf16 (fp32 not supported).
56
+ - GPU arch must be compiled into the extension (see setup.py cc_flag).
57
+ - CUDA toolchain major.minor should match `torch.version.cuda` major (12.x ↔ 12.x).
overlay/kernels/cuda/flashfftconv/UPSTREAM_COMMIT ADDED
@@ -0,0 +1 @@
 
 
1
+ b8771028717f46d5b22cbb8e12833f35033d621b
overlay/kernels/cuda/flashfftconv/csrc/.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ *.npy
2
+ *.json
3
+ *.png
4
+
5
+ */*.npy
6
+ */*.json
7
+ */*.png
8
+
9
+ *.DS_Store
10
+ */*.DS_Store
overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly.h ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2023 Dan Fu, Hermann Kumbong
2
+
3
+ #include <torch/extension.h>
4
+
5
+ #include <vector>
6
+
7
+ #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
8
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
9
+ #define CHECK_IS_HALF_OR_BFLOAT(x) TORCH_CHECK(x.dtype() == torch::kFloat16 || x.dtype() == torch::kBFloat16, #x " must be float16 or bfloat16")
10
+ #define CHECK_INPUT(x) \
11
+ CHECK_CUDA(x); \
12
+ CHECK_CONTIGUOUS(x); \
13
+ CHECK_IS_HALF_OR_BFLOAT(x)
14
+ #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
15
+
16
+
17
+ std::vector<torch::Tensor> butterfly_cuda(
18
+ torch::Tensor x,
19
+ torch::Tensor d_f_T,
20
+ torch::Tensor twiddle_factors_real,
21
+ torch::Tensor twiddle_factors_imag,
22
+ std::optional<at::Tensor> x_gate = std::nullopt
23
+ );
24
+
25
+
26
+ std::vector<torch::Tensor> butterfly_bf16_cuda(
27
+ torch::Tensor x,
28
+ torch::Tensor d_f_T_real,
29
+ torch::Tensor d_f_T_imag,
30
+ torch::Tensor twiddle_factors_real,
31
+ torch::Tensor twiddle_factors_imag,
32
+ std::optional<at::Tensor> out_gate = std::nullopt
33
+ );
34
+
35
+
36
+ std::vector<torch::Tensor> butterfly_padded_cuda(
37
+ torch::Tensor x,
38
+ torch::Tensor d_f_T,
39
+ torch::Tensor twiddle_factors_real,
40
+ torch::Tensor twiddle_factors_imag,
41
+ int M,
42
+ std::optional<at::Tensor> x_gate = std::nullopt
43
+ );
44
+
45
+
46
+ std::vector<torch::Tensor> butterfly_padded_bf16_cuda(
47
+ torch::Tensor x,
48
+ torch::Tensor d_f_T_real,
49
+ torch::Tensor d_f_T_imag,
50
+ torch::Tensor twiddle_factors_real,
51
+ torch::Tensor twiddle_factors_imag,
52
+ int M,
53
+ std::optional<at::Tensor> x_gate = std::nullopt
54
+ );
55
+
56
+ torch::Tensor butterfly_ifft_cuda(
57
+ torch::Tensor x_real,
58
+ torch::Tensor x_imag,
59
+ torch::Tensor d_f_T,
60
+ torch::Tensor twiddle_factors_real,
61
+ torch::Tensor twiddle_factors_imag,
62
+ std::optional<at::Tensor> out_gate = std::nullopt
63
+ );
64
+
65
+ torch::Tensor butterfly_ifft_bf16_cuda(
66
+ torch::Tensor x_real,
67
+ torch::Tensor x_imag,
68
+ torch::Tensor d_f_real,
69
+ torch::Tensor d_f_imag,
70
+ torch::Tensor twiddle_factors_real,
71
+ torch::Tensor twiddle_factors_imag,
72
+ std::optional<at::Tensor> x_gate = std::nullopt
73
+ );
74
+
75
+ torch::Tensor butterfly_ifft_padded_cuda(
76
+ torch::Tensor x_real,
77
+ torch::Tensor x_imag,
78
+ torch::Tensor d_f,
79
+ torch::Tensor twiddle_factors_real,
80
+ torch::Tensor twiddle_factors_imag,
81
+ int N,
82
+ std::optional<at::Tensor> out_gate = std::nullopt
83
+ );
84
+
85
+
86
+ torch::Tensor butterfly_ifft_padded_bf16_cuda(
87
+ torch::Tensor x_real,
88
+ torch::Tensor x_imag,
89
+ torch::Tensor d_f_real,
90
+ torch::Tensor d_f_imag,
91
+ torch::Tensor twiddle_factors_real,
92
+ torch::Tensor twiddle_factors_imag,
93
+ int N,
94
+ std::optional<at::Tensor> out_gate = std::nullopt
95
+ );
96
+
97
+ std::vector<torch::Tensor> butterfly(
98
+ torch::Tensor x,
99
+ torch::Tensor d_f_T,
100
+ torch::Tensor twiddle_factors_real,
101
+ torch::Tensor twiddle_factors_imag
102
+ ){
103
+ CHECK_INPUT(x);
104
+ CHECK_INPUT(twiddle_factors_real);
105
+ CHECK_INPUT(twiddle_factors_imag);
106
+
107
+
108
+ return butterfly_cuda(x, d_f_T, twiddle_factors_real, twiddle_factors_imag);
109
+ }
110
+
111
+ std::vector<torch::Tensor> butterfly_gated(
112
+ torch::Tensor x,
113
+ torch::Tensor d_f_T,
114
+ torch::Tensor twiddle_factors_real,
115
+ torch::Tensor twiddle_factors_imag,
116
+ torch::Tensor x_gate
117
+ ){
118
+ CHECK_INPUT(x);
119
+ CHECK_INPUT(twiddle_factors_real);
120
+ CHECK_INPUT(twiddle_factors_imag);
121
+
122
+ CHECK_INPUT(x_gate);
123
+
124
+ return butterfly_cuda(x, d_f_T, twiddle_factors_real, twiddle_factors_imag, x_gate);
125
+ }
126
+
127
+ std::vector<torch::Tensor> butterfly_bf16(
128
+ torch::Tensor x,
129
+ torch::Tensor d_f_T_real,
130
+ torch::Tensor d_f_T_imag,
131
+ torch::Tensor twiddle_factors_real,
132
+ torch::Tensor twiddle_factors_imag
133
+ ){
134
+ CHECK_INPUT(x);
135
+ CHECK_INPUT(twiddle_factors_real);
136
+ CHECK_INPUT(twiddle_factors_imag);
137
+ CHECK_INPUT(d_f_T_real);
138
+ CHECK_INPUT(d_f_T_imag);
139
+
140
+
141
+ return butterfly_bf16_cuda(x, d_f_T_real, d_f_T_imag, twiddle_factors_real, twiddle_factors_imag);
142
+ }
143
+
144
+ std::vector<torch::Tensor> butterfly_gated_bf16(
145
+ torch::Tensor x,
146
+ torch::Tensor d_f_T_real,
147
+ torch::Tensor d_f_T_imag,
148
+ torch::Tensor twiddle_factors_real,
149
+ torch::Tensor twiddle_factors_imag,
150
+ torch::Tensor x_gate
151
+ ){
152
+ CHECK_INPUT(x);
153
+ CHECK_INPUT(twiddle_factors_real);
154
+ CHECK_INPUT(twiddle_factors_imag);
155
+ CHECK_INPUT(d_f_T_real);
156
+ CHECK_INPUT(d_f_T_imag);
157
+ CHECK_INPUT(x_gate);
158
+
159
+
160
+ return butterfly_bf16_cuda(x, d_f_T_real, d_f_T_imag, twiddle_factors_real, twiddle_factors_imag, x_gate);
161
+ }
162
+
163
+ torch::Tensor butterfly_ifft(
164
+ torch::Tensor x_real,
165
+ torch::Tensor x_imag,
166
+ torch::Tensor d_f_T,
167
+ torch::Tensor twiddle_factors_real,
168
+ torch::Tensor twiddle_factors_imag
169
+ ){
170
+ CHECK_INPUT(x_real);
171
+ CHECK_INPUT(x_imag);
172
+ CHECK_INPUT(twiddle_factors_real);
173
+ CHECK_INPUT(twiddle_factors_imag);
174
+
175
+ return butterfly_ifft_cuda(x_real, x_imag, d_f_T, twiddle_factors_real, twiddle_factors_imag);
176
+ }
177
+
178
+
179
+ torch::Tensor butterfly_ifft_gated(
180
+ torch::Tensor x_real,
181
+ torch::Tensor x_imag,
182
+ torch::Tensor d_f_T,
183
+ torch::Tensor twiddle_factors_real,
184
+ torch::Tensor twiddle_factors_imag,
185
+ torch::Tensor out_gate
186
+ ){
187
+ CHECK_INPUT(x_real);
188
+ CHECK_INPUT(x_imag);
189
+ CHECK_INPUT(twiddle_factors_real);
190
+ CHECK_INPUT(twiddle_factors_imag);
191
+ CHECK_INPUT(out_gate);
192
+
193
+ return butterfly_ifft_cuda(x_real, x_imag, d_f_T, twiddle_factors_real, twiddle_factors_imag, out_gate);
194
+ }
195
+
196
+ torch::Tensor butterfly_ifft_bf16(
197
+ torch::Tensor x_real,
198
+ torch::Tensor x_imag,
199
+ torch::Tensor d_f_real,
200
+ torch::Tensor d_f_imag,
201
+ torch::Tensor twiddle_factors_real,
202
+ torch::Tensor twiddle_factors_imag
203
+ ){
204
+ CHECK_INPUT(x_real);
205
+ CHECK_INPUT(x_imag);
206
+ CHECK_INPUT(d_f_real);
207
+ CHECK_INPUT(d_f_imag);
208
+ CHECK_INPUT(twiddle_factors_real);
209
+ CHECK_INPUT(twiddle_factors_imag);
210
+
211
+
212
+ return butterfly_ifft_bf16_cuda(x_real, x_imag, d_f_real, d_f_imag, twiddle_factors_real, twiddle_factors_imag);
213
+ }
214
+
215
+
216
+ torch::Tensor butterfly_ifft_gated_bf16(
217
+ torch::Tensor x_real,
218
+ torch::Tensor x_imag,
219
+ torch::Tensor d_f_real,
220
+ torch::Tensor d_f_imag,
221
+ torch::Tensor twiddle_factors_real,
222
+ torch::Tensor twiddle_factors_imag,
223
+ torch::Tensor out_gate
224
+ ){
225
+ CHECK_INPUT(x_real);
226
+ CHECK_INPUT(x_imag);
227
+ CHECK_INPUT(d_f_real);
228
+ CHECK_INPUT(d_f_imag);
229
+ CHECK_INPUT(twiddle_factors_real);
230
+ CHECK_INPUT(twiddle_factors_imag);
231
+ CHECK_INPUT(out_gate);
232
+
233
+ return butterfly_ifft_bf16_cuda(x_real, x_imag, d_f_real, d_f_imag, twiddle_factors_real, twiddle_factors_imag, out_gate);
234
+ }
235
+
236
+ std::vector<torch::Tensor> butterfly_padded(
237
+ torch::Tensor x,
238
+ torch::Tensor d_f_T,
239
+ torch::Tensor twiddle_factors_real,
240
+ torch::Tensor twiddle_factors_imag,
241
+ int M
242
+ ){
243
+ CHECK_INPUT(x);
244
+ CHECK_INPUT(twiddle_factors_real);
245
+ CHECK_INPUT(twiddle_factors_imag);
246
+
247
+
248
+ return butterfly_padded_cuda(x, d_f_T, twiddle_factors_real, twiddle_factors_imag, M);
249
+ }
250
+
251
+ std::vector<torch::Tensor> butterfly_padded_bf16(
252
+ torch::Tensor x,
253
+ torch::Tensor d_f_T_real,
254
+ torch::Tensor d_f_T_imag,
255
+ torch::Tensor twiddle_factors_real,
256
+ torch::Tensor twiddle_factors_imag,
257
+ int M
258
+ ){
259
+ CHECK_INPUT(x);
260
+ CHECK_INPUT(twiddle_factors_real);
261
+ CHECK_INPUT(twiddle_factors_imag);
262
+
263
+
264
+ return butterfly_padded_bf16_cuda(x, d_f_T_real, d_f_T_imag, twiddle_factors_real, twiddle_factors_imag, M);
265
+ }
266
+
267
+
268
+ std::vector<torch::Tensor> butterfly_padded_gated(
269
+ torch::Tensor x,
270
+ torch::Tensor d_f_T,
271
+ torch::Tensor twiddle_factors_real,
272
+ torch::Tensor twiddle_factors_imag,
273
+ int M,
274
+ torch::Tensor x_gate
275
+ ){
276
+ CHECK_INPUT(x);
277
+ CHECK_INPUT(twiddle_factors_real);
278
+ CHECK_INPUT(twiddle_factors_imag);
279
+
280
+
281
+ return butterfly_padded_cuda(x, d_f_T, twiddle_factors_real, twiddle_factors_imag, M, x_gate);
282
+ }
283
+
284
+ std::vector<torch::Tensor> butterfly_padded_gated_bf16(
285
+ torch::Tensor x,
286
+ torch::Tensor d_f_T_real,
287
+ torch::Tensor d_f_T_imag,
288
+ torch::Tensor twiddle_factors_real,
289
+ torch::Tensor twiddle_factors_imag,
290
+ int M,
291
+ torch::Tensor x_gate
292
+ ){
293
+ CHECK_INPUT(x);
294
+ CHECK_INPUT(d_f_T_real);
295
+ CHECK_INPUT(d_f_T_imag);
296
+ CHECK_INPUT(twiddle_factors_real);
297
+ CHECK_INPUT(twiddle_factors_imag);
298
+
299
+
300
+ return butterfly_padded_bf16_cuda(x, d_f_T_real, d_f_T_imag, twiddle_factors_real, twiddle_factors_imag, M, x_gate);
301
+ }
302
+
303
+ torch::Tensor butterfly_ifft_padded(
304
+ torch::Tensor x_real,
305
+ torch::Tensor x_imag,
306
+ torch::Tensor d_f,
307
+ torch::Tensor twiddle_factors_real,
308
+ torch::Tensor twiddle_factors_imag,
309
+ int N
310
+ ){
311
+ CHECK_INPUT(x_real);
312
+ CHECK_INPUT(x_imag);
313
+ CHECK_INPUT(twiddle_factors_real);
314
+ CHECK_INPUT(twiddle_factors_imag);
315
+
316
+ return butterfly_ifft_padded_cuda(x_real, x_imag, d_f, twiddle_factors_real, twiddle_factors_imag, N);
317
+ }
318
+
319
+ torch::Tensor butterfly_ifft_padded_gated(
320
+ torch::Tensor x_real,
321
+ torch::Tensor x_imag,
322
+ torch::Tensor d_f,
323
+ torch::Tensor twiddle_factors_real,
324
+ torch::Tensor twiddle_factors_imag,
325
+ int N,
326
+ torch::Tensor out_gate
327
+ ){
328
+ CHECK_INPUT(x_real);
329
+ CHECK_INPUT(x_imag);
330
+ CHECK_INPUT(twiddle_factors_real);
331
+ CHECK_INPUT(twiddle_factors_imag);
332
+
333
+ return butterfly_ifft_padded_cuda(x_real, x_imag, d_f, twiddle_factors_real, twiddle_factors_imag, N, out_gate);
334
+ }
335
+
336
+
337
+ torch::Tensor butterfly_ifft_padded_bf16(
338
+ torch::Tensor x_real,
339
+ torch::Tensor x_imag,
340
+ torch::Tensor d_f_real,
341
+ torch::Tensor d_f_imag,
342
+ torch::Tensor twiddle_factors_real,
343
+ torch::Tensor twiddle_factors_imag,
344
+ int N
345
+ ){
346
+ CHECK_INPUT(x_real);
347
+ CHECK_INPUT(x_imag);
348
+ CHECK_INPUT(d_f_real);
349
+ CHECK_INPUT(d_f_imag);
350
+ CHECK_INPUT(twiddle_factors_real);
351
+ CHECK_INPUT(twiddle_factors_imag);
352
+
353
+ return butterfly_ifft_padded_bf16_cuda(x_real, x_imag, d_f_real, d_f_imag, twiddle_factors_real, twiddle_factors_imag, N);
354
+ }
355
+
356
+ torch::Tensor butterfly_ifft_padded_gated_bf16(
357
+ torch::Tensor x_real,
358
+ torch::Tensor x_imag,
359
+ torch::Tensor d_f_real,
360
+ torch::Tensor d_f_imag,
361
+ torch::Tensor twiddle_factors_real,
362
+ torch::Tensor twiddle_factors_imag,
363
+ int N,
364
+ torch::Tensor out_gate
365
+ ){
366
+ CHECK_INPUT(x_real);
367
+ CHECK_INPUT(x_imag);
368
+ CHECK_INPUT(d_f_real);
369
+ CHECK_INPUT(d_f_imag);
370
+ CHECK_INPUT(twiddle_factors_real);
371
+ CHECK_INPUT(twiddle_factors_imag);
372
+
373
+ return butterfly_ifft_padded_bf16_cuda(x_real, x_imag, d_f_real, d_f_imag, twiddle_factors_real, twiddle_factors_imag, N, out_gate);
374
+ }
overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_cuda.cu ADDED
@@ -0,0 +1,699 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2023 Dan Fu, Hermann Kumbong
2
+
3
+ #include <torch/extension.h>
4
+
5
+ #include <vector>
6
+ #include <stdio.h>
7
+ #include <mma.h>
8
+ #include <cuda_fp16.h>
9
+ #include <cuda_bf16.h>
10
+ #include "shared.h"
11
+
12
+ using namespace nvcuda;
13
+
14
+ __global__ void butterfly_cuda_kernel_64(
15
+ const __half2 *__restrict__ x,
16
+ const __half2 *__restrict__ x_gate,
17
+ const complex_half_t *__restrict__ d_f,
18
+ const __half2 *__restrict__ twiddle_factors_real,
19
+ const __half2 *__restrict__ twiddle_factors_imag,
20
+ __half2 *__restrict__ out_real,
21
+ __half2 *__restrict__ out_imag,
22
+ uint B,
23
+ uint H,
24
+ int N)
25
+ {
26
+ const int offset = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * 16 * 64 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
27
+ const int tw_offset = blockIdx.x * 32 + threadIdx.x;
28
+ int idx;
29
+ int shared_offset;
30
+ const int B_Y = blockDim.y;
31
+ const int n = N / B_Y;
32
+
33
+
34
+ extern __shared__ half x_shared[];
35
+ half *d_f_real = &x_shared[N * N];
36
+ half *d_f_imag = &d_f_real[N * N];
37
+ half *twiddles_real_shared = &d_f_imag[N * N];
38
+ half *twiddles_imag_shared = &twiddles_real_shared[N * N];
39
+ half *out_real_shared = &twiddles_imag_shared[N * N];
40
+ half *out_imag_shared = &out_real_shared[N * N];
41
+
42
+ // #pragma unroll
43
+ for (int i = 0; i < n; i++)
44
+ {
45
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
46
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
47
+ reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
48
+ reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
49
+
50
+ // #pragma unroll
51
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x;
52
+ d_f_real[shared_offset] = d_f[shared_offset].real();
53
+ d_f_imag[shared_offset] = d_f[shared_offset].imag();
54
+
55
+ d_f_real[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].real();
56
+ d_f_imag[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].imag();
57
+ }
58
+
59
+ __half2 tmp_real, tmp_imag;
60
+
61
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[4];
62
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real[4];
63
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag[4];
64
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[4];
65
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag[4][4];
66
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[4];
67
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag[4];
68
+
69
+ __syncthreads();
70
+
71
+ for (int i = 0; i < 4; i++)
72
+ {
73
+ wmma::load_matrix_sync(a_frag_real[i], d_f_real + i * N * 16 + threadIdx.y * 16, N);
74
+ wmma::load_matrix_sync(a_frag_imag[i], d_f_imag + i * N * 16 + threadIdx.y * 16, N);
75
+ wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + threadIdx.y * N * 16 + i * 16, N);
76
+ wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + threadIdx.y * N * 16 + i * 16, N);
77
+ }
78
+
79
+ for (int t = 0; t < 16; t++)
80
+ {
81
+
82
+ for (int i = 0; i < n; i++)
83
+ {
84
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
85
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
86
+ if(x_gate != nullptr){
87
+ reinterpret_cast<__half2 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
88
+ }else{
89
+ reinterpret_cast<__half2 *>(x_shared)[shared_offset] = x[idx + offset];
90
+ }
91
+ }
92
+
93
+ __syncthreads();
94
+
95
+ for (int i = 0; i < 4; i++)
96
+ {
97
+ for (int j = 0; j < 4; j++)
98
+ {
99
+ wmma::load_matrix_sync(b_frag[i][j], x_shared + i * N * 16 + j * 16, N);
100
+ }
101
+ }
102
+
103
+ #pragma unroll
104
+ for (int j = 0; j < 4; j++)
105
+ {
106
+ wmma::fill_fragment(acc_frag_real[j], __float2half(0.0f));
107
+
108
+ for (int k = 0; k < 4; k++)
109
+ {
110
+ wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
111
+ }
112
+ }
113
+
114
+ #pragma unroll
115
+
116
+ for (int j = 0; j < 4; j++)
117
+ {
118
+ wmma::fill_fragment(acc_frag_imag[j], __float2half(0.0f));
119
+
120
+ for (int k = 0; k < 4; k++)
121
+ {
122
+ wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
123
+ }
124
+ }
125
+
126
+ #pragma unroll
127
+ for (int j = 0; j < 4; j++)
128
+ {
129
+ for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
130
+ {
131
+ tmp_real = reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k];
132
+ tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k];
133
+ reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]));
134
+ reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]));
135
+ }
136
+
137
+ wmma::store_matrix_sync(out_real_shared + threadIdx.y * N * 16 + j * 16, acc_frag_real[j], N, wmma::mem_row_major);
138
+ wmma::store_matrix_sync(out_imag_shared + threadIdx.y * N * 16 + j * 16, acc_frag_imag[j], N, wmma::mem_row_major);
139
+ }
140
+
141
+ __syncthreads();
142
+
143
+ #pragma unroll
144
+ for (int i = 0; i < n; i++)
145
+ {
146
+ idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
147
+ out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
148
+ out_imag[idx] = reinterpret_cast<__half2 *>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
149
+ }
150
+
151
+ __syncthreads();
152
+ }
153
+ }
154
+
155
+ __global__ void butterfly_cuda_kernel_32(
156
+ const __half2 *__restrict__ x,
157
+ const __half2 *__restrict__ x_gate,
158
+ const complex_half_t *__restrict__ d_f,
159
+ const __half2 *__restrict__ twiddle_factors_real,
160
+ const __half2 *__restrict__ twiddle_factors_imag,
161
+ __half2 *__restrict__ out_real,
162
+ __half2 *__restrict__ out_imag,
163
+ uint B,
164
+ uint H,
165
+ int N)
166
+ {
167
+ const int offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
168
+ const int tw_offset = blockIdx.x * 32 + threadIdx.x;
169
+ int idx;
170
+
171
+ int shared_offset;
172
+ const int B_Y = blockDim.y;
173
+ const int n = N / B_Y;
174
+
175
+
176
+ __shared__ half x_shared[32 * 64];
177
+ __shared__ half d_f_real[32 * 32];
178
+ __shared__ half d_f_imag[32 * 32];
179
+ __shared__ half twiddles_real_shared[32 * 64];
180
+ __shared__ half twiddles_imag_shared[32 * 64];
181
+ __shared__ half out_real_shared[32 * 64];
182
+ __shared__ half out_imag_shared[32 * 64];
183
+
184
+ // #pragma unroll
185
+ for (int i = 0; i < n; i++)
186
+ {
187
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
188
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
189
+ if(x_gate == nullptr){
190
+ reinterpret_cast<__half2 *>(x_shared)[shared_offset] = x[idx + offset];
191
+ }else{
192
+ reinterpret_cast<__half2 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
193
+ }
194
+ reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
195
+ reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
196
+
197
+ // #pragma unroll
198
+ d_f_real[shared_offset] = d_f[shared_offset].real();
199
+ d_f_imag[shared_offset] = d_f[shared_offset].imag();
200
+ }
201
+
202
+ __syncthreads();
203
+
204
+ if (threadIdx.y < N / 16)
205
+ {
206
+ __half2 tmp_real, tmp_imag;
207
+
208
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[2][2];
209
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real[2][2];
210
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag[2][2];
211
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[2][2];
212
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag[2][2];
213
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[2][2];
214
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag[2][2];
215
+
216
+ int t = threadIdx.y * 32;
217
+
218
+ for (int i = 0; i < 2; i++)
219
+ {
220
+ for (int j = 0; j < 2; j++)
221
+ {
222
+ wmma::load_matrix_sync(a_frag_real[i][j], d_f_real + j * N * 16 + i * 16, N);
223
+ wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag + j * N * 16 + i * 16, N);
224
+ wmma::load_matrix_sync(b_frag[i][j], x_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
225
+ wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
226
+ wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
227
+ }
228
+ }
229
+
230
+ #pragma unroll
231
+ for (int i = 0; i < 2; i++)
232
+ {
233
+ for (int j = 0; j < 2; j++)
234
+ {
235
+ wmma::fill_fragment(acc_frag_real[i][j], __float2half(0.0f));
236
+
237
+ for (int k = 0; k < 2; k++)
238
+ {
239
+ wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag[k][j], acc_frag_real[i][j]);
240
+ }
241
+ }
242
+ }
243
+
244
+ #pragma unroll
245
+ for (int i = 0; i < 2; i++)
246
+ {
247
+ for (int j = 0; j < 2; j++)
248
+ {
249
+ wmma::fill_fragment(acc_frag_imag[i][j], __float2half(0.0f));
250
+
251
+ for (int k = 0; k < 2; k++)
252
+ {
253
+ wmma::mma_sync(acc_frag_imag[i][j], a_frag_imag[i][k], b_frag[k][j], acc_frag_imag[i][j]);
254
+ }
255
+ }
256
+ }
257
+
258
+ #pragma unroll
259
+ for (int i = 0; i < 2; i++)
260
+ {
261
+ for (int j = 0; j < 2; j++)
262
+ {
263
+ for (int k = 0; k < acc_frag_real[i][j].num_elements / 2; k++)
264
+ {
265
+ tmp_real = reinterpret_cast<__half2 *>(acc_frag_real[i][j].x)[k];
266
+ tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag[i][j].x)[k];
267
+ reinterpret_cast<__half2 *>(acc_frag_real[i][j].x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real[i][j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag[i][j].x)[k]));
268
+ reinterpret_cast<__half2 *>(acc_frag_imag[i][j].x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag[i][j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real[i][j].x)[k]));
269
+ }
270
+
271
+ wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
272
+ wmma::store_matrix_sync(out_imag_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_imag[i][j], 2 * N, wmma::mem_row_major);
273
+ }
274
+ }
275
+ }
276
+
277
+ __syncthreads();
278
+
279
+ #pragma unroll
280
+ for (int i = 0; i < n; i++)
281
+ {
282
+ idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
283
+ out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
284
+ out_imag[idx] = reinterpret_cast<__half2 *>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
285
+ }
286
+ }
287
+
288
+ __global__ void butterfly_cuda_kernel_128(
289
+ const __half2 *__restrict__ x,
290
+ const __half2 *__restrict__ x_gate,
291
+ const complex_half_t *__restrict__ d_f,
292
+ const __half2 *__restrict__ twiddle_factors_real,
293
+ const __half2 *__restrict__ twiddle_factors_imag,
294
+ __half2 *__restrict__ out_real,
295
+ __half2 *__restrict__ out_imag,
296
+ uint B,
297
+ uint H,
298
+ int N)
299
+ {
300
+ const int offset = blockIdx.y * H * 128 * 32 * gridDim.x * 2 + blockIdx.z * 16 * 128 * 32 * gridDim.x * 2 + blockIdx.x * 64 + threadIdx.x;
301
+ const int tw_offset = blockIdx.x * 64 + threadIdx.x;
302
+ int idx;
303
+
304
+ int shared_offset;
305
+ const int B_Y = blockDim.y;
306
+ const int n = N / B_Y;
307
+
308
+
309
+ extern __shared__ half shared_real[];
310
+ half *shared_imag = &shared_real[128 * 128];
311
+
312
+
313
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[8];
314
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real[8];
315
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag[8];
316
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[8];
317
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag[8][8];
318
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[8];
319
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag[8];
320
+
321
+ for (int i = 0; i < n; i++)
322
+ {
323
+ for(int j=0; j< 4; j++){
324
+ shared_offset = (threadIdx.y + i * B_Y) * 128 + threadIdx.x + j * blockDim.x;
325
+ shared_real[shared_offset] = d_f[shared_offset].real();
326
+ shared_imag[shared_offset] = d_f[shared_offset].imag();
327
+ }
328
+ }
329
+
330
+ __syncthreads();
331
+
332
+
333
+ for (int i = 0; i < 8; i++){
334
+ wmma::load_matrix_sync(a_frag_real[i], shared_real + i * 128 * 16 + threadIdx.y * 16, 128);
335
+ wmma::load_matrix_sync(a_frag_imag[i], shared_imag + i * 128 * 16 + threadIdx.y * 16, 128);
336
+ }
337
+
338
+
339
+ __syncthreads();
340
+
341
+
342
+
343
+ for (int i = 0; i < n; i++)
344
+ {
345
+ for(int j=0; j< 2; j++){
346
+ idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x;
347
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
348
+ reinterpret_cast<__half2*>(shared_real)[shared_offset] = twiddle_factors_real[tw_offset + idx];
349
+ reinterpret_cast<__half2*>(shared_imag)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
350
+ }
351
+ }
352
+
353
+ __syncthreads();
354
+
355
+
356
+ for (int i = 0; i < 8; i++){
357
+ wmma::load_matrix_sync(tw_frag_real[i], shared_real + threadIdx.y * 128 * 16 + i * 16, 128);
358
+ wmma::load_matrix_sync(tw_frag_imag[i], shared_imag + threadIdx.y * 128 * 16 + i * 16, 128);
359
+ }
360
+
361
+ __syncthreads();
362
+
363
+
364
+ for(int t=0; t< 16; t++){
365
+ for (int i = 0; i < n; i++)
366
+ {
367
+ for(int j=0; j< 2; j++){
368
+ idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
369
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
370
+ if(x_gate != nullptr){
371
+ reinterpret_cast<__half2*>(shared_real)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
372
+ }else{
373
+ reinterpret_cast<__half2*>(shared_real)[shared_offset] = x[offset + idx];
374
+ }
375
+
376
+ }
377
+ }
378
+
379
+
380
+ __syncthreads();
381
+
382
+
383
+ for (int i = 0; i < 8; i++)
384
+ {
385
+ for (int j = 0; j < 8; j++)
386
+ {
387
+ wmma::load_matrix_sync(b_frag[i][j], shared_real + i * 128 * 16 + j * 16, 128);
388
+ }
389
+ }
390
+
391
+ __syncthreads();
392
+
393
+ #pragma unroll
394
+ for (int j = 0; j < 8; j++)
395
+ {
396
+ wmma::fill_fragment(acc_frag_real[j], __float2half(0.0f));
397
+
398
+ for (int k = 0; k < 8; k++)
399
+ {
400
+ wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
401
+ }
402
+ }
403
+
404
+ #pragma unroll
405
+
406
+ for (int j = 0; j < 8; j++)
407
+ {
408
+ wmma::fill_fragment(acc_frag_imag[j], __float2half(0.0f));
409
+
410
+ for (int k = 0; k < 8; k++)
411
+ {
412
+ wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
413
+ }
414
+ }
415
+
416
+ __half2 tmp_real, tmp_imag;
417
+ #pragma unroll
418
+ for (int j = 0; j < 8; j++)
419
+ {
420
+ for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
421
+ {
422
+ tmp_real = reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k];
423
+ tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k];
424
+ reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]));
425
+ reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]));
426
+ }
427
+
428
+ wmma::store_matrix_sync(shared_real + threadIdx.y * 128 * 16 + j * 16, acc_frag_real[j], 128, wmma::mem_row_major);
429
+ wmma::store_matrix_sync(shared_imag + threadIdx.y * 128 * 16 + j * 16, acc_frag_imag[j], 128, wmma::mem_row_major);
430
+ }
431
+
432
+ __syncthreads();
433
+
434
+ #pragma unroll
435
+ for (int i = 0; i < n; i++)
436
+ {
437
+ for(int j=0; j< 2; j++){
438
+ idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
439
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
440
+ out_real[offset + idx] = reinterpret_cast<__half2*>(shared_real)[shared_offset];
441
+ out_imag[offset + idx] = reinterpret_cast<__half2*>(shared_imag)[shared_offset];
442
+ }
443
+ }
444
+
445
+ __syncthreads();
446
+ }
447
+ }
448
+
449
+
450
+ __global__ void butterfly_cuda_kernel_16(
451
+ const __half2 *__restrict__ x,
452
+ const __half2 *__restrict__ x_gate,
453
+ const complex_half_t *__restrict__ d_f,
454
+ const __half2 *__restrict__ twiddle_factors_real,
455
+ const __half2 *__restrict__ twiddle_factors_imag,
456
+ __half2 *__restrict__ out_real,
457
+ __half2 *__restrict__ out_imag,
458
+ uint B,
459
+ uint H,
460
+ int N)
461
+ {
462
+ const int offset = blockIdx.y * H * 16 * 32 * gridDim.x + blockIdx.z * 16 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
463
+ const int tw_offset = blockIdx.x * 32 + threadIdx.x;
464
+ int idx;
465
+
466
+ int shared_offset;
467
+ const int B_Y = blockDim.y;
468
+ const int n = N / B_Y;
469
+
470
+
471
+ __shared__ half x_shared[16 * 64];
472
+ __shared__ half d_f_real[16 * 16];
473
+ __shared__ half d_f_imag[16 * 16];
474
+ __shared__ half twiddles_real_shared[16 * 64];
475
+ __shared__ half twiddles_imag_shared[16 * 64];
476
+ __shared__ half out_real_shared[16 * 64];
477
+ __shared__ half out_imag_shared[16 * 64];
478
+
479
+ // #pragma unroll
480
+ for (int i = 0; i < n; i++)
481
+ {
482
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
483
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
484
+
485
+ if(x_gate != NULL)
486
+ reinterpret_cast<__half2 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
487
+ else
488
+ reinterpret_cast<__half2 *>(x_shared)[shared_offset] = x[idx + offset];
489
+ reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
490
+ reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
491
+
492
+ // #pragma unroll
493
+
494
+ if(threadIdx.x < 16 ){
495
+ shared_offset = (threadIdx.y + i * B_Y) * 16 + threadIdx.x;
496
+ d_f_real[shared_offset] = d_f[shared_offset].real();
497
+ d_f_imag[shared_offset] = d_f[shared_offset].imag();
498
+ }
499
+ }
500
+
501
+ __syncthreads();
502
+
503
+ if (threadIdx.y < 4)
504
+ {
505
+ __half2 tmp_real, tmp_imag;
506
+
507
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real;
508
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real;
509
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag;
510
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag;
511
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
512
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real;
513
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag;
514
+
515
+ wmma::load_matrix_sync(a_frag_real, d_f_real, N);
516
+ wmma::load_matrix_sync(a_frag_imag, d_f_imag, N);
517
+ wmma::load_matrix_sync(b_frag, x_shared + threadIdx.y * 16, 64);
518
+ wmma::load_matrix_sync(tw_frag_real, twiddles_real_shared + threadIdx.y * 16, 64);
519
+ wmma::load_matrix_sync(tw_frag_imag, twiddles_imag_shared + threadIdx.y * 16, 64);
520
+
521
+
522
+ wmma::fill_fragment(acc_frag_real, __float2half(0.0f));
523
+
524
+
525
+ wmma::mma_sync(acc_frag_real, a_frag_real, b_frag, acc_frag_real);
526
+
527
+
528
+ wmma::fill_fragment(acc_frag_imag, __float2half(0.0f));
529
+
530
+
531
+ wmma::mma_sync(acc_frag_imag, a_frag_imag, b_frag, acc_frag_imag);
532
+
533
+
534
+
535
+ for (int k = 0; k < acc_frag_real.num_elements / 2; k++)
536
+ {
537
+ tmp_real = reinterpret_cast<__half2 *>(acc_frag_real.x)[k];
538
+ tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag.x)[k];
539
+ reinterpret_cast<__half2 *>(acc_frag_real.x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real.x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag.x)[k]));
540
+ reinterpret_cast<__half2 *>(acc_frag_imag.x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag.x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real.x)[k]));
541
+ }
542
+
543
+ wmma::store_matrix_sync(out_real_shared + threadIdx.y * 16, acc_frag_real, 64, wmma::mem_row_major);
544
+ wmma::store_matrix_sync(out_imag_shared + threadIdx.y * 16, acc_frag_imag, 64, wmma::mem_row_major);
545
+ }
546
+
547
+ __syncthreads();
548
+
549
+ #pragma unroll
550
+ for (int i = 0; i < n; i++)
551
+ {
552
+ idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
553
+ out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
554
+ out_imag[idx] = reinterpret_cast<__half2 *>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
555
+ }
556
+ }
557
+
558
+
559
+ std::vector<torch::Tensor> butterfly_cuda(
560
+ torch::Tensor x,
561
+ torch::Tensor d_f,
562
+ torch::Tensor twiddle_factors_real,
563
+ torch::Tensor twiddle_factors_imag,
564
+ std::optional<at::Tensor> x_gate = std::nullopt)
565
+ {
566
+
567
+ uint B = x.size(0);
568
+ uint H = x.size(1);
569
+ // uint m = x.size(1);
570
+
571
+ // const int TILE_SIZE = 16;
572
+ uint N = x.size(2);
573
+ uint M = x.size(3);
574
+ dim3 gridDim;
575
+ dim3 blockDim;
576
+
577
+ gridDim.y = B;
578
+ gridDim.z = H;
579
+
580
+ torch::Tensor out_real = torch::empty({B, H, N, M}, x.options());
581
+ torch::Tensor out_imag = torch::empty({B, H, N, M}, x.options());
582
+
583
+ //set blockDims
584
+ switch(N){
585
+ case 128:
586
+ blockDim.x = 32;
587
+ blockDim.y = 8;
588
+ break;
589
+ default:
590
+ blockDim.x = 32;
591
+ blockDim.y = 4;
592
+ break;
593
+ }
594
+
595
+ //set gridDim.x
596
+ switch(N){
597
+ case 128:
598
+ switch (M){
599
+ case 16384:
600
+ gridDim.x = 128;
601
+ break;
602
+ case 8192:
603
+ gridDim.x = 64;
604
+ break;
605
+ case 4096:
606
+ gridDim.x = 32;
607
+ break;
608
+ default:
609
+ gridDim.x = 256;
610
+ break;
611
+ }
612
+ break;
613
+ default:
614
+ switch (M){
615
+ case 16384:
616
+ gridDim.x = 256;
617
+ break;
618
+ case 8192:
619
+ gridDim.x = 128;
620
+ break;
621
+ case 4096:
622
+ gridDim.x = 64;
623
+ break;
624
+ default:
625
+ gridDim.x = 512;
626
+ break;
627
+ }
628
+ break;
629
+ }
630
+
631
+ switch (N)
632
+ {
633
+ case 16:
634
+ butterfly_cuda_kernel_16<<<gridDim, blockDim>>>(
635
+ static_cast<__half2 *>(x.data_ptr()),
636
+ x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
637
+ static_cast<complex_half_t *>(d_f.data_ptr()),
638
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
639
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
640
+ static_cast<__half2 *>(out_real.data_ptr()),
641
+ static_cast<__half2 *>(out_imag.data_ptr()),
642
+ B,
643
+ H,
644
+ N);
645
+ break;
646
+ case 32:
647
+ butterfly_cuda_kernel_32<<<gridDim, blockDim>>>(
648
+ static_cast<__half2 *>(x.data_ptr()),
649
+ x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
650
+ static_cast<complex_half_t *>(d_f.data_ptr()),
651
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
652
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
653
+ static_cast<__half2 *>(out_real.data_ptr()),
654
+ static_cast<__half2 *>(out_imag.data_ptr()),
655
+ B,
656
+ H,
657
+ N);
658
+ break;
659
+
660
+ case 64:
661
+ gridDim.z = H / 16;
662
+ cudaFuncSetAttribute(&butterfly_cuda_kernel_64, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
663
+
664
+ butterfly_cuda_kernel_64<<<gridDim, blockDim, 57344>>>(
665
+ static_cast<__half2 *>(x.data_ptr()),
666
+ x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
667
+ static_cast<complex_half_t *>(d_f.data_ptr()),
668
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
669
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
670
+ static_cast<__half2 *>(out_real.data_ptr()),
671
+ static_cast<__half2 *>(out_imag.data_ptr()),
672
+ B,
673
+ H,
674
+ N);
675
+ break;
676
+ case 128:
677
+ gridDim.z = H / 16;
678
+ cudaFuncSetAttribute(&butterfly_cuda_kernel_128, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
679
+
680
+ butterfly_cuda_kernel_128<<<gridDim, blockDim, 65536>>>(
681
+ static_cast<__half2 *>(x.data_ptr()),
682
+ x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
683
+ static_cast<complex_half_t *>(d_f.data_ptr()),
684
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
685
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
686
+ static_cast<__half2 *>(out_real.data_ptr()),
687
+ static_cast<__half2 *>(out_imag.data_ptr()),
688
+ B,
689
+ H,
690
+ N);
691
+ break;
692
+
693
+ default:
694
+ printf("Not yet implemented \n");
695
+ break;
696
+ }
697
+
698
+ return {out_real, out_imag};
699
+ }
overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_cuda_bf16.cu ADDED
@@ -0,0 +1,725 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2023 Dan Fu, Hermann Kumbong
2
+
3
+ #include <torch/extension.h>
4
+
5
+ #include <vector>
6
+ #include <stdio.h>
7
+ #include <mma.h>
8
+ #include <cuda_runtime.h>
9
+ #include <cuda_fp16.h>
10
+ #include <cuda_bf16.h>
11
+ #include "shared.h"
12
+
13
+ using namespace nvcuda;
14
+
15
+ __global__ void butterfly_cuda_kernel_64(
16
+ const __nv_bfloat162 *__restrict__ x,
17
+ const __nv_bfloat162 *__restrict__ x_gate,
18
+ const __nv_bfloat162 *__restrict__ d_f_real,
19
+ const __nv_bfloat162 *__restrict__ d_f_imag,
20
+ const __nv_bfloat162 *__restrict__ twiddle_factors_real,
21
+ const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
22
+ __nv_bfloat162 *__restrict__ out_real,
23
+ __nv_bfloat162 *__restrict__ out_imag,
24
+ uint B,
25
+ uint H,
26
+ int N)
27
+ {
28
+ const int offset = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * 16 * 64 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
29
+ const int tw_offset = blockIdx.x * 32 + threadIdx.x;
30
+ int idx;
31
+ int shared_offset;
32
+ const int B_Y = blockDim.y;
33
+ const int n = N / B_Y;
34
+
35
+
36
+ extern __shared__ __nv_bfloat16 x_shared[];
37
+ __nv_bfloat16 *d_f_real_shared = &x_shared[N * N];
38
+ __nv_bfloat16 *d_f_imag_shared = &d_f_real_shared[N * N];
39
+ __nv_bfloat16 *twiddles_real_shared = &d_f_imag_shared[N * N];
40
+ __nv_bfloat16 *twiddles_imag_shared = &twiddles_real_shared[N * N];
41
+ float *out_real_shared = reinterpret_cast<float*>(&twiddles_imag_shared[N * N]);
42
+ float *out_imag_shared = &out_real_shared[N * N];
43
+
44
+ // #pragma unroll
45
+ for (int i = 0; i < n; i++)
46
+ {
47
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
48
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
49
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
50
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
51
+
52
+ // #pragma unroll
53
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
54
+ reinterpret_cast<__nv_bfloat162 *>(d_f_real_shared)[shared_offset] = d_f_real[shared_offset];
55
+ reinterpret_cast<__nv_bfloat162 *>(d_f_imag_shared)[shared_offset] = d_f_imag[shared_offset];
56
+ }
57
+
58
+ float2 tmp_real, tmp_imag;
59
+
60
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[4];
61
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[4];
62
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[4];
63
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[4];
64
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag[4][4];
65
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[4];
66
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_imag[4];
67
+
68
+ __syncthreads();
69
+
70
+ for (int i = 0; i < 4; i++)
71
+ {
72
+ wmma::load_matrix_sync(a_frag_real[i], d_f_real_shared + i * N * 16 + threadIdx.y * 16, N);
73
+ wmma::load_matrix_sync(a_frag_imag[i], d_f_imag_shared + i * N * 16 + threadIdx.y * 16, N);
74
+ wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + threadIdx.y * N * 16 + i * 16, N);
75
+ wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + threadIdx.y * N * 16 + i * 16, N);
76
+ }
77
+
78
+ for (int t = 0; t < 16; t++)
79
+ {
80
+
81
+ for (int i = 0; i < n; i++)
82
+ {
83
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
84
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
85
+ if(x_gate != nullptr){
86
+ reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
87
+ }else{
88
+ reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = x[idx + offset];
89
+ }
90
+ }
91
+
92
+ __syncthreads();
93
+
94
+ for (int i = 0; i < 4; i++)
95
+ {
96
+ for (int j = 0; j < 4; j++)
97
+ {
98
+ wmma::load_matrix_sync(b_frag[i][j], x_shared + i * N * 16 + j * 16, N);
99
+ }
100
+ }
101
+
102
+ #pragma unroll
103
+ for (int j = 0; j < 4; j++)
104
+ {
105
+ wmma::fill_fragment(acc_frag_real[j], 0.0f);
106
+
107
+ for (int k = 0; k < 4; k++)
108
+ {
109
+ wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
110
+ }
111
+ }
112
+
113
+ #pragma unroll
114
+
115
+ for (int j = 0; j < 4; j++)
116
+ {
117
+ wmma::fill_fragment(acc_frag_imag[j], 0.0f);
118
+
119
+ for (int k = 0; k < 4; k++)
120
+ {
121
+ wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
122
+ }
123
+ }
124
+
125
+ #pragma unroll
126
+ for (int j = 0; j < 4; j++)
127
+ {
128
+ for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
129
+ {
130
+ tmp_real = reinterpret_cast<float2 *>(acc_frag_real[j].x)[k];
131
+ tmp_imag = reinterpret_cast<float2 *>(acc_frag_imag[j].x)[k];
132
+
133
+ reinterpret_cast<float2 *>(acc_frag_real[j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[j].x)[k]) - tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[j].x)[k]);
134
+ reinterpret_cast<float2 *>(acc_frag_imag[j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[j].x)[k]) + tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[j].x)[k]);
135
+ }
136
+
137
+ wmma::store_matrix_sync(out_real_shared + threadIdx.y * N * 16 + j * 16, acc_frag_real[j], N, wmma::mem_row_major);
138
+ wmma::store_matrix_sync(out_imag_shared + threadIdx.y * N * 16 + j * 16, acc_frag_imag[j], N, wmma::mem_row_major);
139
+ }
140
+
141
+ __syncthreads();
142
+
143
+ #pragma unroll
144
+ for (int i = 0; i < n; i++)
145
+ {
146
+ idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
147
+ out_real[idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
148
+ out_imag[idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
149
+ }
150
+
151
+ __syncthreads();
152
+ }
153
+ }
154
+
155
+ __global__ void butterfly_cuda_kernel_32(
156
+ const __nv_bfloat162 *__restrict__ x,
157
+ const __nv_bfloat162 *__restrict__ x_gate,
158
+ const __nv_bfloat16 *__restrict__ d_f_real,
159
+ const __nv_bfloat16 *__restrict__ d_f_imag,
160
+ const __nv_bfloat162 *__restrict__ twiddle_factors_real,
161
+ const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
162
+ __nv_bfloat162 *__restrict__ out_real,
163
+ __nv_bfloat162 *__restrict__ out_imag,
164
+ uint B,
165
+ uint H,
166
+ int N)
167
+ {
168
+ const int offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
169
+ const int tw_offset = blockIdx.x * 32 + threadIdx.x;
170
+ int idx;
171
+
172
+ int shared_offset;
173
+ const int B_Y = blockDim.y;
174
+ const int n = N / B_Y;
175
+
176
+
177
+ __shared__ __nv_bfloat16 x_shared[32 * 64];
178
+ __shared__ __nv_bfloat16 d_f_real_shared[32 * 32];
179
+ __shared__ __nv_bfloat16 d_f_imag_shared[32 * 32];
180
+ __shared__ __nv_bfloat16 twiddles_real_shared[32 * 64];
181
+ __shared__ __nv_bfloat16 twiddles_imag_shared[32 * 64];
182
+ __shared__ float out_real_shared[32 * 64];
183
+ __shared__ float out_imag_shared[32 * 64];
184
+
185
+ // #pragma unroll
186
+ for (int i = 0; i < n; i++)
187
+ {
188
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
189
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
190
+ if(x_gate != nullptr){
191
+ reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
192
+ }else{
193
+ reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = x[idx + offset];
194
+ }
195
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
196
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
197
+
198
+ // #pragma unroll
199
+ d_f_real_shared[shared_offset] = d_f_real[shared_offset];
200
+ d_f_imag_shared[shared_offset] = d_f_imag[shared_offset];
201
+ }
202
+
203
+ __syncthreads();
204
+
205
+ if (threadIdx.y < N / 16)
206
+ {
207
+ float2 tmp_real, tmp_imag;
208
+
209
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[2][2];
210
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[2][2];
211
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[2][2];
212
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[2][2];
213
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag[2][2];
214
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[2][2];
215
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_imag[2][2];
216
+
217
+ int t = threadIdx.y * 32;
218
+
219
+ for (int i = 0; i < 2; i++)
220
+ {
221
+ for (int j = 0; j < 2; j++)
222
+ {
223
+ wmma::load_matrix_sync(a_frag_real[i][j], d_f_real_shared + j * N * 16 + i * 16, N);
224
+ wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag_shared + j * N * 16 + i * 16, N);
225
+ wmma::load_matrix_sync(b_frag[i][j], x_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
226
+ wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
227
+ wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
228
+ }
229
+ }
230
+
231
+ #pragma unroll
232
+ for (int i = 0; i < 2; i++)
233
+ {
234
+ for (int j = 0; j < 2; j++)
235
+ {
236
+ wmma::fill_fragment(acc_frag_real[i][j], 0.0f);
237
+
238
+ for (int k = 0; k < 2; k++)
239
+ {
240
+ wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag[k][j], acc_frag_real[i][j]);
241
+ }
242
+ }
243
+ }
244
+
245
+ #pragma unroll
246
+ for (int i = 0; i < 2; i++)
247
+ {
248
+ for (int j = 0; j < 2; j++)
249
+ {
250
+ wmma::fill_fragment(acc_frag_imag[i][j], 0.0f);
251
+
252
+ for (int k = 0; k < 2; k++)
253
+ {
254
+ wmma::mma_sync(acc_frag_imag[i][j], a_frag_imag[i][k], b_frag[k][j], acc_frag_imag[i][j]);
255
+ }
256
+ }
257
+ }
258
+
259
+ #pragma unroll
260
+ for (int i = 0; i < 2; i++)
261
+ {
262
+ for (int j = 0; j < 2; j++)
263
+ {
264
+ for (int k = 0; k < acc_frag_real[i][j].num_elements / 2; k++)
265
+ {
266
+ tmp_real = reinterpret_cast<float2 *>(acc_frag_real[i][j].x)[k];
267
+ tmp_imag = reinterpret_cast<float2 *>(acc_frag_imag[i][j].x)[k];
268
+ reinterpret_cast<float2 *>(acc_frag_real[i][j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[i][j].x)[k]) - tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[i][j].x)[k]);
269
+ reinterpret_cast<float2 *>(acc_frag_imag[i][j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[i][j].x)[k]) + tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[i][j].x)[k]);
270
+ }
271
+ wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
272
+ wmma::store_matrix_sync(out_imag_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_imag[i][j], 2 * N, wmma::mem_row_major);
273
+ }
274
+ }
275
+ }
276
+
277
+ __syncthreads();
278
+
279
+ #pragma unroll
280
+ for (int i = 0; i < n; i++)
281
+ {
282
+ idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
283
+ out_real[idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
284
+ out_imag[idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
285
+ }
286
+ }
287
+
288
+ __global__ void butterfly_cuda_kernel_128(
289
+ const __nv_bfloat162 *__restrict__ x,
290
+ const __nv_bfloat162 *__restrict__ x_gate,
291
+ const __nv_bfloat162 *__restrict__ d_f_real,
292
+ const __nv_bfloat162 *__restrict__ d_f_imag,
293
+ const __nv_bfloat162 *__restrict__ twiddle_factors_real,
294
+ const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
295
+ __nv_bfloat162 *__restrict__ out_real,
296
+ __nv_bfloat162 *__restrict__ out_imag,
297
+ uint B,
298
+ uint H,
299
+ int N)
300
+ {
301
+ const int offset = blockIdx.y * H * 128 * 32 * 2 * gridDim.x + blockIdx.z * 16 * 128 * 32 * 2 * gridDim.x + blockIdx.x * 64 + threadIdx.x;
302
+ const int tw_offset = blockIdx.x * 64 + threadIdx.x;
303
+ int idx;
304
+
305
+ int shared_offset;
306
+ const int B_Y = blockDim.y;
307
+ const int n = N / B_Y;
308
+
309
+
310
+ extern __shared__ __nv_bfloat16 shared_real[];
311
+ __nv_bfloat16 *shared_imag = &shared_real[128 * 128];
312
+
313
+
314
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[8];
315
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[8];
316
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[8];
317
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[8];
318
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag[8][8];
319
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[8];
320
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_imag[8];
321
+
322
+ for (int i = 0; i < n; i++)
323
+ {
324
+ for(int j=0; j< 2; j++){
325
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
326
+ reinterpret_cast<__nv_bfloat162 *>(shared_real)[shared_offset] = d_f_real[shared_offset];
327
+ reinterpret_cast<__nv_bfloat162 *>(shared_imag)[shared_offset] = d_f_imag[shared_offset];
328
+ }
329
+ }
330
+
331
+ __syncthreads();
332
+
333
+
334
+ for (int i = 0; i < 8; i++){
335
+ wmma::load_matrix_sync(a_frag_real[i], shared_real + i * 128 * 16 + threadIdx.y * 16, 128);
336
+ wmma::load_matrix_sync(a_frag_imag[i], shared_imag + i * 128 * 16 + threadIdx.y * 16, 128);
337
+ }
338
+
339
+
340
+ __syncthreads();
341
+
342
+
343
+
344
+ for (int i = 0; i < n; i++)
345
+ {
346
+ for(int j=0; j< 2; j++){
347
+ idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x;
348
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
349
+ reinterpret_cast<__nv_bfloat162*>(shared_real)[shared_offset] = twiddle_factors_real[tw_offset + idx];
350
+ reinterpret_cast<__nv_bfloat162*>(shared_imag)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
351
+ }
352
+ }
353
+
354
+ __syncthreads();
355
+
356
+
357
+ for (int i = 0; i < 8; i++){
358
+ wmma::load_matrix_sync(tw_frag_real[i], shared_real + threadIdx.y * 128 * 16 + i * 16, 128);
359
+ wmma::load_matrix_sync(tw_frag_imag[i], shared_imag + threadIdx.y * 128 * 16 + i * 16, 128);
360
+ }
361
+
362
+ __syncthreads();
363
+
364
+
365
+ for(int t=0; t< 16; t++){
366
+ for (int i = 0; i < n; i++)
367
+ {
368
+ for(int j=0; j< 2; j++){
369
+ idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
370
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
371
+ if(x_gate != nullptr){
372
+ reinterpret_cast<__nv_bfloat162*>(shared_real)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
373
+ }else{
374
+ reinterpret_cast<__nv_bfloat162*>(shared_real)[shared_offset] = x[offset + idx];
375
+ }
376
+ }
377
+ }
378
+
379
+
380
+ __syncthreads();
381
+
382
+
383
+ for (int i = 0; i < 8; i++)
384
+ {
385
+ for (int j = 0; j < 8; j++)
386
+ {
387
+ wmma::load_matrix_sync(b_frag[i][j], shared_real + i * 128 * 16 + j * 16, 128);
388
+ }
389
+ }
390
+
391
+ __syncthreads();
392
+
393
+ #pragma unroll
394
+ for (int j = 0; j < 8; j++)
395
+ {
396
+ wmma::fill_fragment(acc_frag_real[j], 0.0f);
397
+
398
+ for (int k = 0; k < 8; k++)
399
+ {
400
+ wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
401
+ }
402
+ }
403
+
404
+ #pragma unroll
405
+
406
+ for (int j = 0; j < 8; j++)
407
+ {
408
+ wmma::fill_fragment(acc_frag_imag[j], 0.0f);
409
+
410
+ for (int k = 0; k < 8; k++)
411
+ {
412
+ wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
413
+ }
414
+ }
415
+
416
+ float2 tmp_real, tmp_imag;
417
+ #pragma unroll
418
+ for (int j = 0; j < 8; j++)
419
+ {
420
+ for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
421
+ {
422
+ tmp_real = reinterpret_cast<float2 *>(acc_frag_real[j].x)[k];
423
+ tmp_imag = reinterpret_cast<float2 *>(acc_frag_imag[j].x)[k];
424
+
425
+ reinterpret_cast<float2 *>(acc_frag_real[j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[j].x)[k]) - tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[j].x)[k]);
426
+ reinterpret_cast<float2 *>(acc_frag_imag[j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[j].x)[k]) + tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[j].x)[k]);
427
+ }
428
+ }
429
+
430
+ for (int j = 0; j < 8; j++)
431
+ {
432
+ wmma::store_matrix_sync(reinterpret_cast<float*>(shared_real) + threadIdx.y * 128 * 16 + j * 16, acc_frag_real[j], 128, wmma::mem_row_major);
433
+ }
434
+
435
+ __syncthreads();
436
+
437
+ #pragma unroll
438
+ for (int i = 0; i < n; i++)
439
+ {
440
+ for(int j=0; j< 2; j++){
441
+ idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
442
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
443
+ out_real[offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(shared_real)[shared_offset]);
444
+ }
445
+ }
446
+
447
+ __syncthreads();
448
+
449
+
450
+ for (int j = 0; j < 8; j++)
451
+ {
452
+ wmma::store_matrix_sync(reinterpret_cast<float*>(shared_real) + threadIdx.y * 128 * 16 + j * 16, acc_frag_imag[j], 128, wmma::mem_row_major);
453
+ }
454
+
455
+ __syncthreads();
456
+
457
+ #pragma unroll
458
+ for (int i = 0; i < n; i++)
459
+ {
460
+ for(int j=0; j< 2; j++){
461
+ idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
462
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
463
+ out_imag[offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(shared_real)[shared_offset]);
464
+ }
465
+ }
466
+ }
467
+ }
468
+
469
+
470
+ __global__ void butterfly_cuda_kernel_16(
471
+ const __nv_bfloat162 *__restrict__ x,
472
+ const __nv_bfloat162 *__restrict__ x_gate,
473
+ const __nv_bfloat16 *__restrict__ d_f_real,
474
+ const __nv_bfloat16 *__restrict__ d_f_imag,
475
+ const __nv_bfloat162 *__restrict__ twiddle_factors_real,
476
+ const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
477
+ __nv_bfloat162 *__restrict__ out_real,
478
+ __nv_bfloat162 *__restrict__ out_imag,
479
+ uint B,
480
+ uint H,
481
+ int N)
482
+ {
483
+ const int offset = blockIdx.y * H * 16 * 32 * gridDim.x + blockIdx.z * 16 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
484
+ const int tw_offset = blockIdx.x * 32 + threadIdx.x;
485
+ int idx;
486
+
487
+ int shared_offset;
488
+ const int B_Y = blockDim.y;
489
+ const int n = N / B_Y;
490
+
491
+
492
+ __shared__ __nv_bfloat16 x_shared[16 * 64];
493
+ __shared__ __nv_bfloat16 d_f_real_shared[16 * 16];
494
+ __shared__ __nv_bfloat16 d_f_imag_shared[16 * 16];
495
+ __shared__ __nv_bfloat16 twiddles_real_shared[16 * 64];
496
+ __shared__ __nv_bfloat16 twiddles_imag_shared[16 * 64];
497
+ __shared__ float out_real_shared[16 * 64];
498
+ __shared__ float out_imag_shared[16 * 64];
499
+
500
+ // #pragma unroll
501
+ for (int i = 0; i < n; i++)
502
+ {
503
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
504
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
505
+ if(x_gate != nullptr){
506
+ reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
507
+ }else{
508
+ reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = x[idx + offset];
509
+ }
510
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
511
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
512
+
513
+ // #pragma unroll
514
+ if(threadIdx.x < 16 ){
515
+ shared_offset = (threadIdx.y + i * B_Y) * 16 + threadIdx.x;
516
+ d_f_real_shared[shared_offset] = d_f_real[shared_offset];
517
+ d_f_imag_shared[shared_offset] = d_f_imag[shared_offset];
518
+ }
519
+ }
520
+
521
+ __syncthreads();
522
+
523
+ if (threadIdx.y < 4)
524
+ {
525
+ float2 tmp_real, tmp_imag;
526
+
527
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real;
528
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real;
529
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag;
530
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag;
531
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag;
532
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real;
533
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_imag;
534
+
535
+ wmma::load_matrix_sync(a_frag_real, d_f_real_shared, N);
536
+ wmma::load_matrix_sync(a_frag_imag, d_f_imag_shared, N);
537
+ wmma::load_matrix_sync(b_frag, x_shared + threadIdx.y * 16, 64);
538
+ wmma::load_matrix_sync(tw_frag_real, twiddles_real_shared + threadIdx.y * 16, 64);
539
+ wmma::load_matrix_sync(tw_frag_imag, twiddles_imag_shared + threadIdx.y * 16, 64);
540
+
541
+
542
+
543
+ wmma::fill_fragment(acc_frag_real, 0.0f);
544
+
545
+
546
+ wmma::mma_sync(acc_frag_real, a_frag_real, b_frag, acc_frag_real);
547
+
548
+
549
+
550
+ wmma::fill_fragment(acc_frag_imag, 0.0f);
551
+
552
+
553
+ wmma::mma_sync(acc_frag_imag, a_frag_imag, b_frag, acc_frag_imag);
554
+
555
+
556
+ #pragma unroll
557
+ for (int k = 0; k < acc_frag_real.num_elements / 2; k++)
558
+ {
559
+ tmp_real = reinterpret_cast<float2 *>(acc_frag_real.x)[k];
560
+ tmp_imag = reinterpret_cast<float2 *>(acc_frag_imag.x)[k];
561
+ reinterpret_cast<float2 *>(acc_frag_real.x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real.x)[k]) - tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag.x)[k]);
562
+ reinterpret_cast<float2 *>(acc_frag_imag.x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag.x)[k]) + tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real.x)[k]);
563
+ }
564
+ wmma::store_matrix_sync(out_real_shared + threadIdx.y * 16, acc_frag_real, 64, wmma::mem_row_major);
565
+ wmma::store_matrix_sync(out_imag_shared + threadIdx.y * 16, acc_frag_imag, 64, wmma::mem_row_major);
566
+
567
+ }
568
+ __syncthreads();
569
+
570
+ #pragma unroll
571
+ for (int i = 0; i < n; i++)
572
+ {
573
+ idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
574
+ out_real[idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
575
+ out_imag[idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
576
+ }
577
+ }
578
+
579
+ std::vector<torch::Tensor> butterfly_bf16_cuda(
580
+ torch::Tensor x,
581
+ torch::Tensor d_f_real,
582
+ torch::Tensor d_f_imag,
583
+ torch::Tensor twiddle_factors_real,
584
+ torch::Tensor twiddle_factors_imag,
585
+ std::optional<at::Tensor> x_gate = std::nullopt
586
+ )
587
+ {
588
+
589
+ uint B = x.size(0);
590
+ uint H = x.size(1);
591
+ // uint m = x.size(1);
592
+
593
+ // const int TILE_SIZE = 16;
594
+ uint N = x.size(2);
595
+ uint M = x.size(3);
596
+ dim3 gridDim;
597
+ dim3 blockDim;
598
+
599
+ gridDim.y = B;
600
+ gridDim.z = H;
601
+
602
+ torch::Tensor out_real = torch::empty({B, H, N, M}, x.options());
603
+ torch::Tensor out_imag = torch::empty({B, H, N, M}, x.options());
604
+
605
+ //set blockDims
606
+ switch(N){
607
+ case 128:
608
+ blockDim.x = 32;
609
+ blockDim.y = 8;
610
+ break;
611
+ default:
612
+ blockDim.x = 32;
613
+ blockDim.y = 4;
614
+ break;
615
+ }
616
+
617
+ //set gridDim.x
618
+ switch(N){
619
+ case 128:
620
+ switch (M){
621
+ case 16384:
622
+ gridDim.x = 128;
623
+ break;
624
+ case 8192:
625
+ gridDim.x = 64;
626
+ break;
627
+ case 4096:
628
+ gridDim.x = 32;
629
+ break;
630
+ default:
631
+ gridDim.x = 256;
632
+ break;
633
+ }
634
+ break;
635
+ default:
636
+ switch (M){
637
+ case 16384:
638
+ gridDim.x = 256;
639
+ break;
640
+ case 8192:
641
+ gridDim.x = 128;
642
+ break;
643
+ case 4096:
644
+ gridDim.x = 64;
645
+ break;
646
+ default:
647
+ gridDim.x = 512;
648
+ break;
649
+ }
650
+ break;
651
+ }
652
+
653
+ switch (N)
654
+ {
655
+ case 16:
656
+ butterfly_cuda_kernel_16<<<gridDim, blockDim>>>(
657
+ static_cast<__nv_bfloat162 *>(x.data_ptr()),
658
+ x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
659
+ static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
660
+ static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
661
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
662
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
663
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
664
+ static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
665
+ B,
666
+ H,
667
+ N);
668
+ break;
669
+ case 32:
670
+ butterfly_cuda_kernel_32<<<gridDim, blockDim>>>(
671
+ static_cast<__nv_bfloat162 *>(x.data_ptr()),
672
+ x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
673
+ static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
674
+ static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
675
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
676
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
677
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
678
+ static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
679
+ B,
680
+ H,
681
+ N);
682
+ break;
683
+
684
+ case 64:
685
+ gridDim.z = H / 16;
686
+ cudaFuncSetAttribute(&butterfly_cuda_kernel_64, cudaFuncAttributeMaxDynamicSharedMemorySize, 78000);
687
+
688
+ butterfly_cuda_kernel_64<<<gridDim, blockDim, 78000>>>(
689
+ static_cast<__nv_bfloat162 *>(x.data_ptr()),
690
+ x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
691
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
692
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
693
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
694
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
695
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
696
+ static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
697
+ B,
698
+ H,
699
+ N);
700
+ break;
701
+ case 128:
702
+ gridDim.z = H / 16;
703
+ cudaFuncSetAttribute(&butterfly_cuda_kernel_128, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
704
+
705
+ butterfly_cuda_kernel_128<<<gridDim, blockDim, 65536>>>(
706
+ static_cast<__nv_bfloat162 *>(x.data_ptr()),
707
+ x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
708
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
709
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
710
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
711
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
712
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
713
+ static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
714
+ B,
715
+ H,
716
+ N);
717
+ break;
718
+
719
+ default:
720
+ printf("Not yet implemented \n");
721
+ break;
722
+ }
723
+
724
+ return {out_real, out_imag};
725
+ }
overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_ifft_cuda.cu ADDED
@@ -0,0 +1,723 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2023 Dan Fu, Hermann Kumbong
2
+
3
+ #include <torch/extension.h>
4
+
5
+ #include <vector>
6
+ #include <stdio.h>
7
+ #include <mma.h>
8
+ #include <cuda_fp16.h>
9
+ #include <cuda_bf16.h>
10
+ #include "shared.h"
11
+
12
+ using namespace nvcuda;
13
+
14
+ __global__ void butterfly_ifft_cuda_kernel_64(
15
+ const __half2 *__restrict__ x_real,
16
+ const __half2 *__restrict__ x_imag,
17
+ const complex_half_t *__restrict__ d_f,
18
+ const __half2 *__restrict__ twiddle_factors_real,
19
+ const __half2 *__restrict__ twiddle_factors_imag,
20
+ __half2 *__restrict__ out_real,
21
+ __half2 *__restrict__ out_gate,
22
+ uint B,
23
+ uint H,
24
+ int N)
25
+ {
26
+ const int offset = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * 16 * 64 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
27
+ const int tw_offset = blockIdx.x * 32 + threadIdx.x;
28
+ int idx;
29
+ int shared_offset;
30
+ const int B_Y = blockDim.y;
31
+ const int n = N / B_Y;
32
+
33
+ extern __shared__ half x_real_shared[];
34
+ half *x_imag_shared = &x_real_shared[N * N];
35
+ half *d_f_real = &x_imag_shared[N * N];
36
+ half *d_f_imag = &d_f_real[N * N];
37
+ half *twiddles_real_shared = &d_f_imag[N * N];
38
+ half *twiddles_imag_shared = &twiddles_real_shared[N * N];
39
+ half *out_real_shared = &twiddles_imag_shared[N * N];
40
+
41
+ half tmp_real, tmp_imag;
42
+
43
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[4][4];
44
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[4][4];
45
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real[4];
46
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag[4];
47
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real[4];
48
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag[4];
49
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[4];
50
+
51
+ // #pragma unroll
52
+ for (int i = 0; i < n; i++)
53
+ {
54
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
55
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
56
+ reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
57
+ reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
58
+
59
+ // #pragma unroll
60
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x;
61
+ d_f_real[shared_offset] = d_f[shared_offset].real();
62
+ d_f_imag[shared_offset] = d_f[shared_offset].imag();
63
+
64
+ d_f_real[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].real();
65
+ d_f_imag[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].imag();
66
+ }
67
+
68
+ __syncthreads();
69
+
70
+ for (int i = 0; i < 4; i++)
71
+ {
72
+ #pragma unroll
73
+ for (int j = 0; j < 4; j++)
74
+ {
75
+ wmma::load_matrix_sync(a_frag_real[i][j], d_f_real + j * N * 16 + i * 16, N);
76
+ wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag + j * N * 16 + i * 16, N);
77
+ }
78
+ wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + i * N * 16 + threadIdx.y * 16, N);
79
+ wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + i * N * 16 + threadIdx.y * 16, N);
80
+ }
81
+
82
+ for (int t = 0; t < 16; t++)
83
+ {
84
+
85
+ for (int i = 0; i < n; i++)
86
+ {
87
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
88
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
89
+ reinterpret_cast<__half2 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
90
+ reinterpret_cast<__half2 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
91
+ }
92
+
93
+ __syncthreads();
94
+
95
+ for (int i = 0; i < 4; i++)
96
+ {
97
+ wmma::load_matrix_sync(b_frag_real[i], x_real_shared + i * N * 16 + threadIdx.y * 16, N);
98
+ wmma::load_matrix_sync(b_frag_imag[i], x_imag_shared + i * N * 16 + threadIdx.y * 16, N);
99
+ }
100
+
101
+ for (int j = 0; j < 4; j++)
102
+ {
103
+ for (int k = 0; k < tw_frag_real[j].num_elements; k++)
104
+ {
105
+ tmp_real = __hsub(__hmul(tw_frag_real[j].x[k], b_frag_real[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_imag[j].x[k]));
106
+ tmp_imag = __hadd(__hmul(tw_frag_real[j].x[k], b_frag_imag[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_real[j].x[k]));
107
+ b_frag_real[j].x[k] = tmp_real;
108
+ b_frag_imag[j].x[k] = tmp_imag;
109
+ }
110
+ }
111
+
112
+ for (int i = 0; i < 4; i++)
113
+ {
114
+ wmma::fill_fragment(acc_frag_real[i], __float2half(0.0f));
115
+
116
+ // bd
117
+ #pragma unroll
118
+ for (int k = 0; k < 4; k++)
119
+ {
120
+ wmma::mma_sync(acc_frag_real[i], a_frag_imag[i][k], b_frag_imag[k], acc_frag_real[i]);
121
+ }
122
+
123
+ for (int k = 0; k < acc_frag_real[i].num_elements; k++)
124
+ {
125
+ acc_frag_real[i].x[k] = __hneg(acc_frag_real[i].x[k]);
126
+ }
127
+ }
128
+
129
+ for (int i = 0; i < 4; i++)
130
+ {
131
+ // ac - bd
132
+ #pragma unroll
133
+ for (int k = 0; k < 4; k++)
134
+ {
135
+ wmma::mma_sync(acc_frag_real[i], a_frag_real[i][k], b_frag_real[k], acc_frag_real[i]);
136
+ }
137
+ }
138
+
139
+ #pragma unroll
140
+ for (int i = 0; i < 4; i++)
141
+ {
142
+ wmma::store_matrix_sync(out_real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
143
+ }
144
+
145
+ __syncthreads();
146
+
147
+ #pragma unroll
148
+ for (int i = 0; i < n; i++)
149
+ {
150
+ idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
151
+ if(out_gate != nullptr){
152
+ out_real[idx] = __hmul2(reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x], out_gate[idx]);
153
+ }
154
+ else{
155
+ out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
156
+ }
157
+ }
158
+
159
+ __syncthreads();
160
+ }
161
+ }
162
+
163
+ __global__ void butterfly_ifft_cuda_kernel_32(
164
+ const __half2 *__restrict__ x_real,
165
+ const __half2 *__restrict__ x_imag,
166
+ const complex_half_t *__restrict__ d_f,
167
+ const __half2 *__restrict__ twiddle_factors_real,
168
+ const __half2 *__restrict__ twiddle_factors_imag,
169
+ __half2 *__restrict__ out_real,
170
+ __half2 *__restrict__ out_gate,
171
+ uint B,
172
+ uint H,
173
+ int N)
174
+ {
175
+ const int offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
176
+ const int tw_offset = blockIdx.x * 32 + threadIdx.x;
177
+ int idx;
178
+ int shared_offset;
179
+ const int B_Y = blockDim.y;
180
+ const int n = N / B_Y;
181
+
182
+ __shared__ half x_real_shared[32 * 64];
183
+ __shared__ half x_imag_shared[32 * 64];
184
+ __shared__ half d_f_real[32 * 32];
185
+ __shared__ half d_f_imag[32 * 32];
186
+ __shared__ half twiddles_real_shared[32 * 64];
187
+ __shared__ half twiddles_imag_shared[32 * 64];
188
+ __shared__ half out_real_shared[32 * 64];
189
+
190
+ // #pragma unroll
191
+ for (int i = 0; i < n; i++)
192
+ {
193
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
194
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
195
+ reinterpret_cast<__half2 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
196
+ reinterpret_cast<__half2 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
197
+ reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
198
+ reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
199
+
200
+ // #pragma unroll
201
+ d_f_real[shared_offset] = d_f[shared_offset].real();
202
+ d_f_imag[shared_offset] = d_f[shared_offset].imag();
203
+ }
204
+
205
+ __syncthreads();
206
+
207
+ if (threadIdx.y < N / 16)
208
+ {
209
+ half tmp_real, tmp_imag;
210
+
211
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[2][2];
212
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[2][2];
213
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real[2][2];
214
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag[2][2];
215
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real[2][2];
216
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag[2][2];
217
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[2][2];
218
+
219
+ int t = threadIdx.y * 32;
220
+
221
+ for (int i = 0; i < 2; i++)
222
+ {
223
+ for (int j = 0; j < 2; j++)
224
+ {
225
+ wmma::load_matrix_sync(a_frag_real[i][j], d_f_real + j * N * 16 + i * 16, N);
226
+ wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag + j * N * 16 + i * 16, N);
227
+ wmma::load_matrix_sync(b_frag_real[i][j], x_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
228
+ wmma::load_matrix_sync(b_frag_imag[i][j], x_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
229
+ wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
230
+ wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
231
+ }
232
+ }
233
+
234
+ for (int i = 0; i < 2; i++)
235
+ {
236
+ for (int j = 0; j < 2; j++)
237
+ {
238
+ for (int k = 0; k < tw_frag_real[i][j].num_elements; k++)
239
+ {
240
+ tmp_real = __hsub(__hmul(tw_frag_real[i][j].x[k], b_frag_real[i][j].x[k]), __hmul(tw_frag_imag[i][j].x[k], b_frag_imag[i][j].x[k]));
241
+ tmp_imag = __hadd(__hmul(tw_frag_real[i][j].x[k], b_frag_imag[i][j].x[k]), __hmul(tw_frag_imag[i][j].x[k], b_frag_real[i][j].x[k]));
242
+ b_frag_real[i][j].x[k] = tmp_real;
243
+ b_frag_imag[i][j].x[k] = tmp_imag;
244
+ }
245
+ }
246
+ }
247
+
248
+ for (int i = 0; i < 2; i++)
249
+ {
250
+ for (int j = 0; j < 2; j++)
251
+ {
252
+ wmma::fill_fragment(acc_frag_real[i][j], __float2half(0.0f));
253
+
254
+ // bd
255
+ for (int k = 0; k < 2; k++)
256
+ {
257
+ wmma::mma_sync(acc_frag_real[i][j], a_frag_imag[i][k], b_frag_imag[k][j], acc_frag_real[i][j]);
258
+ }
259
+
260
+ for (int k = 0; k < acc_frag_real[i][j].num_elements; k++)
261
+ {
262
+ acc_frag_real[i][j].x[k] = __hneg(acc_frag_real[i][j].x[k]);
263
+ }
264
+ }
265
+ }
266
+
267
+ for (int i = 0; i < 2; i++)
268
+ {
269
+ for (int j = 0; j < 2; j++)
270
+ {
271
+ // ac - bd
272
+ for (int k = 0; k < 2; k++)
273
+ {
274
+ wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag_real[k][j], acc_frag_real[i][j]);
275
+ }
276
+ }
277
+ }
278
+
279
+ for (int i = 0; i < 2; i++)
280
+ {
281
+ for (int j = 0; j < 2; j++)
282
+ {
283
+ wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
284
+ }
285
+ }
286
+ }
287
+
288
+ __syncthreads();
289
+
290
+ #pragma unroll
291
+ for (int i = 0; i < n; i++)
292
+ {
293
+ idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
294
+ if(out_gate != nullptr){
295
+ out_real[idx] = __hmul2(reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x], out_gate[idx]);
296
+ }
297
+ else{
298
+ out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
299
+ }
300
+ }
301
+ }
302
+
303
+
304
+ __global__ void butterfly_ifft_cuda_kernel_128(
305
+ const __half2 *__restrict__ x_real,
306
+ const __half2 *__restrict__ x_imag,
307
+ const complex_half_t *__restrict__ d_f,
308
+ const __half2 *__restrict__ twiddle_factors_real,
309
+ const __half2 *__restrict__ twiddle_factors_imag,
310
+ __half2 *__restrict__ out_real,
311
+ __half2 *__restrict__ out_gate,
312
+ uint B,
313
+ uint H,
314
+ int N)
315
+ {
316
+ const int offset = blockIdx.y * H * 128 * 32 * 2 * gridDim.x + blockIdx.z * 16 * 128 * 32 * 2 * gridDim.x + blockIdx.x * 64 + threadIdx.x;
317
+ const int tw_offset = blockIdx.x * 64 + threadIdx.x;
318
+ int idx;
319
+ int shared_offset;
320
+
321
+ const int B_Y = 8;
322
+ const int n = 16;
323
+
324
+ extern __shared__ half real_shared[];
325
+ half *imag_shared = &real_shared[128 * 128];
326
+ half *real_shared_2 = &imag_shared[128 * 128];
327
+ half *imag_shared_2 = &real_shared_2[128 * 128];
328
+
329
+ __half2 tmp_real, tmp_imag;
330
+
331
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag[8][8];
332
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real[8];
333
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag[8];
334
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real[8];
335
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag[8];
336
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[8];
337
+
338
+ for (int i = 0; i < n; i++)
339
+ {
340
+ for(int j=0; j< 4; j++){
341
+ shared_offset = (threadIdx.y + i * B_Y) * 128 + threadIdx.x + j * blockDim.x;
342
+ real_shared_2[shared_offset] = d_f[shared_offset].real();
343
+ imag_shared_2[shared_offset] = d_f[shared_offset].imag();
344
+ }
345
+ }
346
+
347
+
348
+ __syncthreads();
349
+
350
+ for (int i = 0; i < n; i++)
351
+ {
352
+ for(int j=0; j< 2; j++){
353
+ idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x;
354
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
355
+ reinterpret_cast<__half2*>(real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
356
+ reinterpret_cast<__half2*>(imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
357
+ }
358
+ }
359
+
360
+ __syncthreads();
361
+
362
+
363
+ for (int i = 0; i < 8; i++){
364
+ wmma::load_matrix_sync(tw_frag_real[i], real_shared + i * 128 * 16 + threadIdx.y * 16, 128);
365
+ wmma::load_matrix_sync(tw_frag_imag[i], imag_shared + i * 128 * 16 + threadIdx.y * 16, 128);
366
+ }
367
+
368
+ __syncthreads();
369
+
370
+ for (int t = 0; t < 16; t++)
371
+ {
372
+
373
+ for (int i = 0; i < n; i++)
374
+ {
375
+ for(int j=0; j< 2; j++){
376
+ idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
377
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
378
+ reinterpret_cast<__half2*>(real_shared)[shared_offset] = x_real[offset + idx];
379
+ reinterpret_cast<__half2*>(imag_shared)[shared_offset] = x_imag[offset + idx];
380
+ }
381
+ }
382
+
383
+ __syncthreads();
384
+
385
+ for (int i = 0; i < 8; i++)
386
+ {
387
+ wmma::load_matrix_sync(b_frag_real[i], real_shared + i * N * 16 + threadIdx.y * 16, N);
388
+ wmma::load_matrix_sync(b_frag_imag[i], imag_shared + i * N * 16 + threadIdx.y * 16, N);
389
+ }
390
+
391
+
392
+ for (int j = 0; j < 8; j++)
393
+ {
394
+ for (int k = 0; k < tw_frag_real[j].num_elements/2; k++)
395
+ {
396
+ tmp_real = __hsub2(__hmul2(reinterpret_cast<__half2*>(tw_frag_real[j].x)[k], reinterpret_cast<__half2*>(b_frag_real[j].x)[k]),
397
+ __hmul2(reinterpret_cast<__half2*>(tw_frag_imag[j].x)[k], reinterpret_cast<__half2*>(b_frag_imag[j].x)[k]));
398
+ tmp_imag = __hadd2(__hmul2(reinterpret_cast<__half2*>(tw_frag_real[j].x)[k], reinterpret_cast<__half2*>(b_frag_imag[j].x)[k]),
399
+ __hmul2(reinterpret_cast<__half2*>(tw_frag_imag[j].x)[k], reinterpret_cast<__half2*>(b_frag_real[j].x)[k]));
400
+ reinterpret_cast<__half2*>(b_frag_real[j].x)[k] = tmp_real;
401
+ reinterpret_cast<__half2*>(b_frag_imag[j].x)[k] = tmp_imag;
402
+ }
403
+ }
404
+
405
+ for (int i = 0; i < 8; i++){
406
+ for (int j = 0; j < 8; j++){
407
+ wmma::load_matrix_sync(a_frag[i][j], imag_shared_2 + j * 128 * 16 + i * 16, 128);
408
+ }
409
+ }
410
+
411
+ __syncthreads();
412
+
413
+ for (int i = 0; i < 8; i++)
414
+ {
415
+ wmma::fill_fragment(acc_frag_real[i], __float2half(0.0f));
416
+
417
+ // bd
418
+ #pragma unroll
419
+ for (int k = 0; k < 8; k++)
420
+ {
421
+ wmma::mma_sync(acc_frag_real[i], a_frag[i][k], b_frag_imag[k], acc_frag_real[i]);
422
+ }
423
+
424
+ for (int k = 0; k < acc_frag_real[i].num_elements; k++)
425
+ {
426
+ acc_frag_real[i].x[k] = __hneg(acc_frag_real[i].x[k]);
427
+ }
428
+ }
429
+
430
+
431
+ for (int i = 0; i < 8; i++){
432
+ for (int j = 0; j < 8; j++){
433
+ wmma::load_matrix_sync(a_frag[i][j], real_shared_2 + j * 128 * 16 + i * 16, 128);
434
+ }
435
+ }
436
+
437
+ __syncthreads();
438
+
439
+ for (int i = 0; i < 8; i++)
440
+ {
441
+ // ac - bd
442
+ #pragma unroll
443
+ for (int k = 0; k < 8; k++)
444
+ {
445
+ wmma::mma_sync(acc_frag_real[i], a_frag[i][k], b_frag_real[k], acc_frag_real[i]);
446
+ }
447
+ }
448
+
449
+ #pragma unroll
450
+ for (int i = 0; i < 8; i++)
451
+ {
452
+ wmma::store_matrix_sync(real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
453
+ }
454
+
455
+ __syncthreads();
456
+
457
+ #pragma unroll
458
+ for (int i = 0; i < n; i++)
459
+ {
460
+ for(int j=0; j< 2; j++){
461
+ idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
462
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
463
+ if(out_gate != nullptr){
464
+ out_real[offset + idx] = __hmul2(reinterpret_cast<__half2*>(real_shared)[shared_offset], out_gate[offset + idx]);
465
+ }
466
+ else{
467
+ out_real[offset + idx] = reinterpret_cast<__half2*>(real_shared)[shared_offset];
468
+ }
469
+ }
470
+ }
471
+
472
+ __syncthreads();
473
+ }
474
+ }
475
+
476
+ __global__ void butterfly_ifft_cuda_kernel_16(
477
+ const __half2 *__restrict__ x_real,
478
+ const __half2 *__restrict__ x_imag,
479
+ const complex_half_t *__restrict__ d_f,
480
+ const __half2 *__restrict__ twiddle_factors_real,
481
+ const __half2 *__restrict__ twiddle_factors_imag,
482
+ __half2 *__restrict__ out_real,
483
+ __half2 *__restrict__ out_gate,
484
+ uint B,
485
+ uint H,
486
+ int N)
487
+ {
488
+ const int offset = blockIdx.y * H * 16 * 32 * gridDim.x + blockIdx.z * 16 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
489
+ const int tw_offset = blockIdx.x * 32 + threadIdx.x;
490
+ int idx;
491
+ int shared_offset;
492
+ const int B_Y = blockDim.y;
493
+ const int n = N / B_Y;
494
+
495
+ __shared__ half x_real_shared[16 * 64];
496
+ __shared__ half x_imag_shared[16 * 64];
497
+ __shared__ half d_f_real[16 * 16];
498
+ __shared__ half d_f_imag[16 * 16];
499
+ __shared__ half twiddles_real_shared[16 * 64];
500
+ __shared__ half twiddles_imag_shared[16 * 64];
501
+ __shared__ half out_real_shared[16 * 64];
502
+
503
+ // #pragma unroll
504
+ for (int i = 0; i < n; i++)
505
+ {
506
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
507
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
508
+ reinterpret_cast<__half2 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
509
+ reinterpret_cast<__half2 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
510
+ reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
511
+ reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
512
+
513
+ if(threadIdx.x < 16 ){
514
+ shared_offset = (threadIdx.y + i * B_Y) * 16 + threadIdx.x;
515
+ d_f_real[shared_offset] = d_f[shared_offset].real();
516
+ d_f_imag[shared_offset] = d_f[shared_offset].imag();
517
+ }
518
+ }
519
+
520
+ __syncthreads();
521
+
522
+ //check if it is better to have one warp do all the multiplication or split between warps
523
+ if (threadIdx.y < 4)
524
+ {
525
+ half tmp_real, tmp_imag;
526
+
527
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real;
528
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag;
529
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real;
530
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag;
531
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real;
532
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag;
533
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real;
534
+
535
+ wmma::load_matrix_sync(a_frag_real, d_f_real, N);
536
+ wmma::load_matrix_sync(a_frag_imag, d_f_imag, N);
537
+ wmma::load_matrix_sync(b_frag_real, x_real_shared + threadIdx.y * 16, 64);
538
+ wmma::load_matrix_sync(b_frag_imag, x_imag_shared + threadIdx.y * 16, 64);
539
+ wmma::load_matrix_sync(tw_frag_real, twiddles_real_shared + threadIdx.y * 16, 64);
540
+ wmma::load_matrix_sync(tw_frag_imag, twiddles_imag_shared + threadIdx.y * 16, 64);
541
+
542
+
543
+
544
+ for (int k = 0; k < tw_frag_real.num_elements; k++)
545
+ {
546
+ tmp_real = __hsub(__hmul(tw_frag_real.x[k], b_frag_real.x[k]), __hmul(tw_frag_imag.x[k], b_frag_imag.x[k]));
547
+ tmp_imag = __hadd(__hmul(tw_frag_real.x[k], b_frag_imag.x[k]), __hmul(tw_frag_imag.x[k], b_frag_real.x[k]));
548
+ b_frag_real.x[k] = tmp_real;
549
+ b_frag_imag.x[k] = tmp_imag;
550
+ }
551
+
552
+
553
+ wmma::fill_fragment(acc_frag_real, __float2half(0.0f));
554
+
555
+ wmma::mma_sync(acc_frag_real, a_frag_imag, b_frag_imag, acc_frag_real);
556
+
557
+ for(int k=0; k< acc_frag_real.num_elements; k++){
558
+ acc_frag_real.x[k] = __hneg(acc_frag_real.x[k]);
559
+ }
560
+
561
+
562
+ wmma::mma_sync(acc_frag_real, a_frag_real, b_frag_real, acc_frag_real);
563
+
564
+ wmma::store_matrix_sync(out_real_shared + threadIdx.y * 16, acc_frag_real, 64, wmma::mem_row_major);
565
+
566
+ }
567
+
568
+ __syncthreads();
569
+
570
+ #pragma unroll
571
+ for (int i = 0; i < n; i++)
572
+ {
573
+ idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
574
+ if(out_gate != nullptr){
575
+ out_real[idx] = __hmul2(reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x], out_gate[idx]);
576
+ }
577
+ else{
578
+ out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
579
+ }
580
+ }
581
+ }
582
+
583
+ torch::Tensor butterfly_ifft_cuda(
584
+ torch::Tensor x_real,
585
+ torch::Tensor x_imag,
586
+ torch::Tensor d_f,
587
+ torch::Tensor twiddle_factors_real,
588
+ torch::Tensor twiddle_factors_imag,
589
+ std::optional<at::Tensor> out_gate = std::nullopt)
590
+ {
591
+
592
+ uint B = x_real.size(0);
593
+ uint H = x_real.size(1);
594
+ // uint m = x.size(1);
595
+
596
+ // const int TILE_SIZE = 16;
597
+
598
+ dim3 gridDim;
599
+ dim3 blockDim;
600
+
601
+ uint N = x_real.size(2);
602
+ uint M = x_real.size(3);
603
+ gridDim.y = B;
604
+
605
+ blockDim.x = 32;
606
+ blockDim.y = 4;
607
+
608
+ torch::Tensor out = torch::empty({B, H, N, M}, x_real.options());
609
+ gridDim.z = H;
610
+
611
+ //set blockDims
612
+ switch(N){
613
+ case 128:
614
+ blockDim.x = 32;
615
+ blockDim.y = 8;
616
+ break;
617
+ default:
618
+ blockDim.x = 32;
619
+ blockDim.y = 4;
620
+ break;
621
+ }
622
+
623
+ //set gridDim.x
624
+ switch(N){
625
+ case 128:
626
+ switch (M){
627
+ case 16384:
628
+ gridDim.x = 128;
629
+ break;
630
+ case 8192:
631
+ gridDim.x = 64;
632
+ break;
633
+ case 4096:
634
+ gridDim.x = 32;
635
+ break;
636
+ default:
637
+ gridDim.x = 256;
638
+ break;
639
+ }
640
+ break;
641
+ default:
642
+ switch (M){
643
+ case 16384:
644
+ gridDim.x = 256;
645
+ break;
646
+ case 8192:
647
+ gridDim.x = 128;
648
+ break;
649
+ case 4096:
650
+ gridDim.x = 64;
651
+ break;
652
+ default:
653
+ gridDim.x = 512;
654
+ break;
655
+ }
656
+ break;
657
+ }
658
+
659
+ switch (N)
660
+ {
661
+ case 16:
662
+ butterfly_ifft_cuda_kernel_16<<<gridDim, blockDim>>>(
663
+ static_cast<__half2 *>(x_real.data_ptr()),
664
+ static_cast<__half2 *>(x_imag.data_ptr()),
665
+ static_cast<complex_half_t *>(d_f.data_ptr()),
666
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
667
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
668
+ static_cast<__half2 *>(out.data_ptr()),
669
+ out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
670
+ B,
671
+ H,
672
+ N);
673
+ break;
674
+ case 32:
675
+ butterfly_ifft_cuda_kernel_32<<<gridDim, blockDim>>>(
676
+ static_cast<__half2 *>(x_real.data_ptr()),
677
+ static_cast<__half2 *>(x_imag.data_ptr()),
678
+ static_cast<complex_half_t *>(d_f.data_ptr()),
679
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
680
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
681
+ static_cast<__half2 *>(out.data_ptr()),
682
+ out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
683
+ B,
684
+ H,
685
+ N);
686
+ break;
687
+ case 64:
688
+ gridDim.z = H / 16;
689
+ cudaFuncSetAttribute(&butterfly_ifft_cuda_kernel_64, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
690
+ butterfly_ifft_cuda_kernel_64<<<gridDim, blockDim, 8 * N * N * sizeof(half)>>>(
691
+ static_cast<__half2 *>(x_real.data_ptr()),
692
+ static_cast<__half2 *>(x_imag.data_ptr()),
693
+ static_cast<complex_half_t *>(d_f.data_ptr()),
694
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
695
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
696
+ static_cast<__half2 *>(out.data_ptr()),
697
+ out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
698
+ B,
699
+ H,
700
+ N);
701
+ break;
702
+
703
+ case 128:
704
+ gridDim.z = H / 16;
705
+ cudaFuncSetAttribute(&butterfly_ifft_cuda_kernel_128, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536*2);
706
+ butterfly_ifft_cuda_kernel_128<<<gridDim, blockDim, 65536*2>>>(
707
+ static_cast<__half2 *>(x_real.data_ptr()),
708
+ static_cast<__half2 *>(x_imag.data_ptr()),
709
+ static_cast<complex_half_t *>(d_f.data_ptr()),
710
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
711
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
712
+ static_cast<__half2 *>(out.data_ptr()),
713
+ out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
714
+ B,
715
+ H,
716
+ N);
717
+ break;
718
+ default:
719
+ printf("Not implemented\n");
720
+ }
721
+
722
+ return out;
723
+ }
overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_ifft_cuda_bf16.cu ADDED
@@ -0,0 +1,705 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2023 Dan Fu, Hermann Kumbong
2
+
3
+ #include <torch/extension.h>
4
+
5
+ #include <vector>
6
+ #include <stdio.h>
7
+ #include <mma.h>
8
+ #include <cuda_fp16.h>
9
+ #include <cuda_bf16.h>
10
+ #include <cuda_runtime.h>
11
+ #include "shared.h"
12
+
13
+ using namespace nvcuda;
14
+
15
+ __global__ void butterfly_ifft_bf16_cuda_kernel_64(
16
+ const __nv_bfloat162 *__restrict__ x_real,
17
+ const __nv_bfloat162 *__restrict__ x_imag,
18
+ const __nv_bfloat162 *__restrict__ d_f_real,
19
+ const __nv_bfloat162 *__restrict__ d_f_imag,
20
+ const __nv_bfloat162 *__restrict__ twiddle_factors_real,
21
+ const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
22
+ __nv_bfloat162 *__restrict__ out_real,
23
+ __nv_bfloat162 *__restrict__ out_gate,
24
+ uint B,
25
+ uint H,
26
+ int N)
27
+ {
28
+ const int offset = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * 16 * 64 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
29
+ const int tw_offset = blockIdx.x * 32 + threadIdx.x;
30
+ int idx;
31
+ int shared_offset;
32
+ const int B_Y = blockDim.y;
33
+ const int n = N / B_Y;
34
+
35
+ extern __shared__ __nv_bfloat16 x_real_shared[];
36
+ __nv_bfloat16 *x_imag_shared = &x_real_shared[N * N];
37
+ __nv_bfloat16 *d_f_real_shared = &x_imag_shared[N * N];
38
+ __nv_bfloat16 *d_f_imag_shared = &d_f_real_shared[N * N];
39
+ __nv_bfloat16 *twiddles_real_shared = &d_f_imag_shared[N * N];
40
+ __nv_bfloat16 *twiddles_imag_shared = &twiddles_real_shared[N * N];
41
+ float *out_real_shared = reinterpret_cast<float*>(&twiddles_imag_shared[N * N]);
42
+
43
+ __nv_bfloat16 tmp_real, tmp_imag;
44
+
45
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[4][4];
46
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[4][4];
47
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[4];
48
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[4];
49
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_real[4];
50
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_imag[4];
51
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[4];
52
+
53
+ // #pragma unroll
54
+ for (int i = 0; i < n; i++)
55
+ {
56
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
57
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
58
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
59
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
60
+
61
+ // #pragma unroll
62
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
63
+ reinterpret_cast<__nv_bfloat162 *>(d_f_real_shared)[shared_offset] = d_f_real[shared_offset];
64
+ reinterpret_cast<__nv_bfloat162 *>(d_f_imag_shared)[shared_offset] = d_f_imag[shared_offset];
65
+ }
66
+
67
+ __syncthreads();
68
+
69
+ for (int i = 0; i < 4; i++)
70
+ {
71
+ #pragma unroll
72
+ for (int j = 0; j < 4; j++)
73
+ {
74
+ wmma::load_matrix_sync(a_frag_real[i][j], d_f_real_shared + j * N * 16 + i * 16, N);
75
+ wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag_shared + j * N * 16 + i * 16, N);
76
+ }
77
+ wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + i * N * 16 + threadIdx.y * 16, N);
78
+ wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + i * N * 16 + threadIdx.y * 16, N);
79
+ }
80
+
81
+ for (int t = 0; t < 16; t++)
82
+ {
83
+
84
+ for (int i = 0; i < n; i++)
85
+ {
86
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
87
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
88
+ reinterpret_cast<__nv_bfloat162 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
89
+ reinterpret_cast<__nv_bfloat162 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
90
+ }
91
+
92
+ __syncthreads();
93
+
94
+ for (int i = 0; i < 4; i++)
95
+ {
96
+ wmma::load_matrix_sync(b_frag_real[i], x_real_shared + i * N * 16 + threadIdx.y * 16, N);
97
+ wmma::load_matrix_sync(b_frag_imag[i], x_imag_shared + i * N * 16 + threadIdx.y * 16, N);
98
+ }
99
+
100
+ for (int j = 0; j < 4; j++)
101
+ {
102
+ for (int k = 0; k < tw_frag_real[j].num_elements; k++)
103
+ {
104
+ tmp_real = __hsub(__hmul(tw_frag_real[j].x[k], b_frag_real[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_imag[j].x[k]));
105
+ tmp_imag = __hadd(__hmul(tw_frag_real[j].x[k], b_frag_imag[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_real[j].x[k]));
106
+ b_frag_real[j].x[k] = tmp_real;
107
+ b_frag_imag[j].x[k] = tmp_imag;
108
+ }
109
+ }
110
+
111
+ for (int i = 0; i < 4; i++)
112
+ {
113
+ wmma::fill_fragment(acc_frag_real[i], 0.0f);
114
+
115
+ // bd
116
+ #pragma unroll
117
+ for (int k = 0; k < 4; k++)
118
+ {
119
+ wmma::mma_sync(acc_frag_real[i], a_frag_imag[i][k], b_frag_imag[k], acc_frag_real[i]);
120
+ }
121
+
122
+ for (int k = 0; k < acc_frag_real[i].num_elements; k++)
123
+ {
124
+ acc_frag_real[i].x[k] = - acc_frag_real[i].x[k];
125
+ }
126
+ }
127
+
128
+ for (int i = 0; i < 4; i++)
129
+ {
130
+ // ac - bd
131
+ #pragma unroll
132
+ for (int k = 0; k < 4; k++)
133
+ {
134
+ wmma::mma_sync(acc_frag_real[i], a_frag_real[i][k], b_frag_real[k], acc_frag_real[i]);
135
+ }
136
+ }
137
+
138
+ #pragma unroll
139
+ for (int i = 0; i < 4; i++)
140
+ {
141
+ wmma::store_matrix_sync(out_real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
142
+ }
143
+
144
+ __syncthreads();
145
+
146
+ #pragma unroll
147
+ for (int i = 0; i < n; i++)
148
+ {
149
+ idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
150
+ if(out_gate != nullptr){
151
+ out_real[idx] = __hmul2(__float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]), out_gate[idx]); ;
152
+ }else{
153
+ out_real[idx] = __float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
154
+ }
155
+ }
156
+
157
+ __syncthreads();
158
+ }
159
+ }
160
+
161
+ __global__ void butterfly_ifft_bf16_cuda_kernel_32(
162
+ const __nv_bfloat162 *__restrict__ x_real,
163
+ const __nv_bfloat162 *__restrict__ x_imag,
164
+ const __nv_bfloat16 *__restrict__ d_f_real,
165
+ const __nv_bfloat16 *__restrict__ d_f_imag,
166
+ const __nv_bfloat162 *__restrict__ twiddle_factors_real,
167
+ const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
168
+ __nv_bfloat162 *__restrict__ out_real,
169
+ __nv_bfloat162 *__restrict__ out_gate,
170
+ uint B,
171
+ uint H,
172
+ int N)
173
+ {
174
+ const int offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
175
+ const int tw_offset = blockIdx.x * 32 + threadIdx.x;
176
+ int idx;
177
+ int shared_offset;
178
+ const int B_Y = blockDim.y;
179
+ const int n = N / B_Y;
180
+
181
+ __shared__ __nv_bfloat16 x_real_shared[32 * 64];
182
+ __shared__ __nv_bfloat16 x_imag_shared[32 * 64];
183
+ __shared__ __nv_bfloat16 twiddles_real_shared[32 * 64];
184
+ __shared__ __nv_bfloat16 twiddles_imag_shared[32 * 64];
185
+ __shared__ float out_real_shared[32 * 64];
186
+
187
+ // #pragma unroll
188
+ for (int i = 0; i < n; i++)
189
+ {
190
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
191
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
192
+ reinterpret_cast<__nv_bfloat162 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
193
+ reinterpret_cast<__nv_bfloat162 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
194
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
195
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
196
+ }
197
+
198
+ __syncthreads();
199
+
200
+ if (threadIdx.y < N / 16)
201
+ {
202
+ __nv_bfloat16 tmp_real, tmp_imag;
203
+
204
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[2][2];
205
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[2][2];
206
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[2][2];
207
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[2][2];
208
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_real[2][2];
209
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_imag[2][2];
210
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[2][2];
211
+
212
+ int t = threadIdx.y * 32;
213
+
214
+ for (int i = 0; i < 2; i++)
215
+ {
216
+ for (int j = 0; j < 2; j++)
217
+ {
218
+ wmma::load_matrix_sync(a_frag_real[i][j], d_f_real + j * N * 16 + i * 16, N);
219
+ wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag + j * N * 16 + i * 16, N);
220
+ wmma::load_matrix_sync(b_frag_real[i][j], x_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
221
+ wmma::load_matrix_sync(b_frag_imag[i][j], x_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
222
+ wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
223
+ wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
224
+ }
225
+ }
226
+
227
+ for (int i = 0; i < 2; i++)
228
+ {
229
+ for (int j = 0; j < 2; j++)
230
+ {
231
+ for (int k = 0; k < tw_frag_real[i][j].num_elements; k++)
232
+ {
233
+ tmp_real = __hsub(__hmul(tw_frag_real[i][j].x[k], b_frag_real[i][j].x[k]), __hmul(tw_frag_imag[i][j].x[k], b_frag_imag[i][j].x[k]));
234
+ tmp_imag = __hadd(__hmul(tw_frag_real[i][j].x[k], b_frag_imag[i][j].x[k]), __hmul(tw_frag_imag[i][j].x[k], b_frag_real[i][j].x[k]));
235
+ b_frag_real[i][j].x[k] = tmp_real;
236
+ b_frag_imag[i][j].x[k] = tmp_imag;
237
+ }
238
+ }
239
+ }
240
+
241
+ for (int i = 0; i < 2; i++)
242
+ {
243
+ for (int j = 0; j < 2; j++)
244
+ {
245
+ wmma::fill_fragment(acc_frag_real[i][j], 0.0f);
246
+
247
+ // bd
248
+ for (int k = 0; k < 2; k++)
249
+ {
250
+ wmma::mma_sync(acc_frag_real[i][j], a_frag_imag[i][k], b_frag_imag[k][j], acc_frag_real[i][j]);
251
+ }
252
+
253
+ for (int k = 0; k < acc_frag_real[i][j].num_elements; k++)
254
+ {
255
+ acc_frag_real[i][j].x[k] = - acc_frag_real[i][j].x[k];
256
+ }
257
+ }
258
+ }
259
+
260
+ for (int i = 0; i < 2; i++)
261
+ {
262
+ for (int j = 0; j < 2; j++)
263
+ {
264
+ // ac - bd
265
+ for (int k = 0; k < 2; k++)
266
+ {
267
+ wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag_real[k][j], acc_frag_real[i][j]);
268
+ }
269
+ }
270
+ }
271
+
272
+ for (int i = 0; i < 2; i++)
273
+ {
274
+ for (int j = 0; j < 2; j++)
275
+ {
276
+ wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
277
+ }
278
+ }
279
+ }
280
+
281
+ __syncthreads();
282
+
283
+ #pragma unroll
284
+ for (int i = 0; i < n; i++)
285
+ {
286
+ idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
287
+ if(out_gate != nullptr){
288
+ out_real[idx] = __hmul2(__float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]), out_gate[idx]);
289
+ }else{
290
+ out_real[idx] = __float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
291
+ }
292
+ }
293
+ }
294
+
295
+
296
+ __global__ void butterfly_ifft_bf16_cuda_kernel_128(
297
+ const __nv_bfloat162 *__restrict__ x_real,
298
+ const __nv_bfloat162 *__restrict__ x_imag,
299
+ const __nv_bfloat162 *__restrict__ d_f_real,
300
+ const __nv_bfloat162 *__restrict__ d_f_imag,
301
+ const __nv_bfloat162 *__restrict__ twiddle_factors_real,
302
+ const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
303
+ __nv_bfloat162 *__restrict__ out_real,
304
+ __nv_bfloat162 *__restrict__ out_gate,
305
+ uint B,
306
+ uint H,
307
+ int N)
308
+ {
309
+ const int offset = blockIdx.y * H * 128 * 32 * 2 * gridDim.x + blockIdx.z * 16 * 128 * 32 * 2 * gridDim.x + blockIdx.x * 64 + threadIdx.x;
310
+ const int tw_offset = blockIdx.x * 64 + threadIdx.x;
311
+ int idx;
312
+ int shared_offset;
313
+ const int B_Y = blockDim.y;
314
+ const int n = N / B_Y;
315
+
316
+ extern __shared__ __nv_bfloat16 real_shared[];
317
+ __nv_bfloat16 *imag_shared = &real_shared[128 * 128];
318
+ __nv_bfloat16 *real_shared_2 = &imag_shared[128 * 128];
319
+ __nv_bfloat16 *imag_shared_2 = &real_shared_2[128 * 128];
320
+
321
+ __nv_bfloat16 tmp_real, tmp_imag;
322
+
323
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag[8][8];
324
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[8];
325
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[8];
326
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_real[8];
327
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_imag[8];
328
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[8];
329
+
330
+ for (int i = 0; i < n; i++)
331
+ {
332
+ for(int j=0; j< 2; j++){
333
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
334
+ reinterpret_cast<__nv_bfloat162*>(real_shared_2)[shared_offset] = d_f_real[shared_offset];
335
+ reinterpret_cast<__nv_bfloat162*>(imag_shared_2)[shared_offset] = d_f_imag[shared_offset];
336
+ }
337
+ }
338
+
339
+ for (int i = 0; i < n; i++)
340
+ {
341
+ for(int j=0; j< 2; j++){
342
+ idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x;
343
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
344
+ reinterpret_cast<__nv_bfloat162*>(real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
345
+ reinterpret_cast<__nv_bfloat162*>(imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
346
+ }
347
+ }
348
+
349
+ __syncthreads();
350
+
351
+
352
+ for (int i = 0; i < 8; i++){
353
+ wmma::load_matrix_sync(tw_frag_real[i], real_shared + i * 128 * 16 + threadIdx.y * 16, 128);
354
+ wmma::load_matrix_sync(tw_frag_imag[i], imag_shared + i * 128 * 16 + threadIdx.y * 16, 128);
355
+ }
356
+
357
+ __syncthreads();
358
+
359
+ for (int t = 0; t < 16; t++)
360
+ {
361
+ for (int i = 0; i < 8; i++){
362
+ for (int j = 0; j < 8; j++){
363
+ wmma::load_matrix_sync(a_frag[i][j], imag_shared_2 + j * 128 * 16 + i * 16, 128);
364
+ }
365
+ }
366
+
367
+ for (int i = 0; i < n; i++)
368
+ {
369
+ for(int j=0; j< 2; j++){
370
+ idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
371
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
372
+ reinterpret_cast<__nv_bfloat162*>(real_shared)[shared_offset] = x_real[offset + idx];
373
+ reinterpret_cast<__nv_bfloat162*>(imag_shared)[shared_offset] = x_imag[offset + idx];
374
+ }
375
+ }
376
+
377
+ __syncthreads();
378
+
379
+ for (int i = 0; i < 8; i++)
380
+ {
381
+ wmma::load_matrix_sync(b_frag_real[i], real_shared + i * N * 16 + threadIdx.y * 16, N);
382
+ wmma::load_matrix_sync(b_frag_imag[i], imag_shared + i * N * 16 + threadIdx.y * 16, N);
383
+ }
384
+
385
+
386
+ for (int j = 0; j < 8; j++)
387
+ {
388
+ for (int k = 0; k < tw_frag_real[j].num_elements; k++)
389
+ {
390
+ tmp_real = __hsub(__hmul(tw_frag_real[j].x[k], b_frag_real[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_imag[j].x[k]));
391
+ tmp_imag = __hadd(__hmul(tw_frag_real[j].x[k], b_frag_imag[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_real[j].x[k]));
392
+ b_frag_real[j].x[k] = tmp_real;
393
+ b_frag_imag[j].x[k] = tmp_imag;
394
+ }
395
+ }
396
+
397
+ for (int i = 0; i < 8; i++)
398
+ {
399
+ wmma::fill_fragment(acc_frag_real[i], 0.0f);
400
+
401
+ // bd
402
+ #pragma unroll
403
+ for (int k = 0; k < 8; k++)
404
+ {
405
+ wmma::mma_sync(acc_frag_real[i], a_frag[i][k], b_frag_imag[k], acc_frag_real[i]);
406
+ }
407
+
408
+ for (int k = 0; k < acc_frag_real[i].num_elements; k++)
409
+ {
410
+ acc_frag_real[i].x[k] = - acc_frag_real[i].x[k];
411
+ }
412
+ }
413
+
414
+ for (int i = 0; i < 8; i++){
415
+ for (int j = 0; j < 8; j++){
416
+ wmma::load_matrix_sync(a_frag[i][j], real_shared_2 + j * 128 * 16 + i * 16, 128);
417
+ }
418
+ }
419
+
420
+ for (int i = 0; i < 8; i++)
421
+ {
422
+ // ac - bd
423
+ #pragma unroll
424
+ for (int k = 0; k < 8; k++)
425
+ {
426
+ wmma::mma_sync(acc_frag_real[i], a_frag[i][k], b_frag_real[k], acc_frag_real[i]);
427
+ }
428
+ }
429
+
430
+ #pragma unroll
431
+ for (int i = 0; i < 8; i++)
432
+ {
433
+ //wmma::store_matrix_sync(real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
434
+ wmma::store_matrix_sync(reinterpret_cast<float*>(real_shared) + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
435
+ }
436
+
437
+ __syncthreads();
438
+
439
+ #pragma unroll
440
+ for (int i = 0; i < n; i++)
441
+ {
442
+ for(int j=0; j< 2; j++){
443
+ idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
444
+ shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
445
+ if(out_gate != nullptr){
446
+ out_real[offset + idx] = __hmul2(__float22bfloat162_rn(reinterpret_cast<float2*>(real_shared)[shared_offset]), out_gate[offset + idx]);
447
+ }else{
448
+ out_real[offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(real_shared)[shared_offset]);
449
+ }
450
+ }
451
+ }
452
+
453
+ __syncthreads();
454
+ }
455
+ }
456
+
457
+ __global__ void butterfly_ifft_bf16_cuda_kernel_16(
458
+ const __nv_bfloat162 *__restrict__ x_real,
459
+ const __nv_bfloat162 *__restrict__ x_imag,
460
+ const __nv_bfloat16 *__restrict__ d_f_real,
461
+ const __nv_bfloat16 *__restrict__ d_f_imag,
462
+ const __nv_bfloat162 *__restrict__ twiddle_factors_real,
463
+ const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
464
+ __nv_bfloat162 *__restrict__ out_real,
465
+ __nv_bfloat162 *__restrict__ out_gate,
466
+ uint B,
467
+ uint H,
468
+ int N)
469
+ {
470
+ const int offset = blockIdx.y * H * 16 * 32 * gridDim.x + blockIdx.z * 16 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
471
+ const int tw_offset = blockIdx.x * 32 + threadIdx.x;
472
+ int idx;
473
+ int shared_offset;
474
+ const int B_Y = blockDim.y;
475
+ const int n = N / B_Y;
476
+
477
+ __shared__ __nv_bfloat16 x_real_shared[16 * 64];
478
+ __shared__ __nv_bfloat16 x_imag_shared[16 * 64];
479
+ __shared__ __nv_bfloat16 twiddles_real_shared[16 * 64];
480
+ __shared__ __nv_bfloat16 twiddles_imag_shared[16 * 64];
481
+ __shared__ float out_real_shared[16 * 64];
482
+
483
+ // #pragma unroll
484
+ for (int i = 0; i < n; i++)
485
+ {
486
+ idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
487
+ shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
488
+ reinterpret_cast<__nv_bfloat162 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
489
+ reinterpret_cast<__nv_bfloat162 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
490
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
491
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
492
+ }
493
+
494
+ __syncthreads();
495
+
496
+ if (threadIdx.y < 4)
497
+ {
498
+ __nv_bfloat16 tmp_real, tmp_imag;
499
+
500
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real;
501
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag;
502
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real;
503
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag;
504
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_real;
505
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_imag;
506
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real;
507
+
508
+ wmma::load_matrix_sync(a_frag_real, d_f_real, N);
509
+ wmma::load_matrix_sync(a_frag_imag, d_f_imag, N);
510
+ wmma::load_matrix_sync(b_frag_real, x_real_shared + threadIdx.y * 16, 64);
511
+ wmma::load_matrix_sync(b_frag_imag, x_imag_shared + threadIdx.y * 16, 64);
512
+ wmma::load_matrix_sync(tw_frag_real, twiddles_real_shared + threadIdx.y * 16, 64);
513
+ wmma::load_matrix_sync(tw_frag_imag, twiddles_imag_shared + threadIdx.y * 16, 64);
514
+
515
+
516
+ for (int k = 0; k < tw_frag_real.num_elements; k++)
517
+ {
518
+ tmp_real = __hsub(__hmul(tw_frag_real.x[k], b_frag_real.x[k]), __hmul(tw_frag_imag.x[k], b_frag_imag.x[k]));
519
+ tmp_imag = __hadd(__hmul(tw_frag_real.x[k], b_frag_imag.x[k]), __hmul(tw_frag_imag.x[k], b_frag_real.x[k]));
520
+ b_frag_real.x[k] = tmp_real;
521
+ b_frag_imag.x[k] = tmp_imag;
522
+ }
523
+
524
+
525
+
526
+ wmma::fill_fragment(acc_frag_real, 0.0f);
527
+
528
+ wmma::mma_sync(acc_frag_real, a_frag_imag, b_frag_imag, acc_frag_real);
529
+
530
+ for(int k=0; k< acc_frag_real.num_elements; k++){
531
+ acc_frag_real.x[k] = - acc_frag_real.x[k];
532
+ }
533
+
534
+ wmma::mma_sync(acc_frag_real, a_frag_real, b_frag_real, acc_frag_real);
535
+
536
+ wmma::store_matrix_sync(out_real_shared + threadIdx.y * 16, acc_frag_real, 64, wmma::mem_row_major);
537
+
538
+ }
539
+
540
+ __syncthreads();
541
+
542
+ #pragma unroll
543
+ for (int i = 0; i < n; i++)
544
+ {
545
+ idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
546
+ if(out_gate != nullptr){
547
+ out_real[idx] = __hmul2(__float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]), out_gate[idx]);
548
+ }else{
549
+ out_real[idx] = __float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
550
+ }
551
+ }
552
+ }
553
+
554
+
555
+ torch::Tensor butterfly_ifft_bf16_cuda(
556
+ torch::Tensor x_real,
557
+ torch::Tensor x_imag,
558
+ torch::Tensor d_f_real,
559
+ torch::Tensor d_f_imag,
560
+ torch::Tensor twiddle_factors_real,
561
+ torch::Tensor twiddle_factors_imag,
562
+ std::optional<at::Tensor> out_gate = std::nullopt
563
+ )
564
+ {
565
+
566
+ uint B = x_real.size(0);
567
+ uint H = x_real.size(1);
568
+ // uint m = x.size(1);
569
+
570
+ // const int TILE_SIZE = 16;
571
+
572
+ dim3 gridDim;
573
+ dim3 blockDim;
574
+
575
+ uint N = x_real.size(2);
576
+ uint M = x_real.size(3);
577
+ gridDim.y = B;
578
+
579
+ blockDim.x = 32;
580
+ blockDim.y = 4;
581
+
582
+ torch::Tensor out = torch::empty({B, H, N, M}, x_real.options());
583
+
584
+
585
+ //set blockDims
586
+ switch(N){
587
+ case 128:
588
+ blockDim.x = 32;
589
+ blockDim.y = 8;
590
+ break;
591
+ default:
592
+ blockDim.x = 32;
593
+ blockDim.y = 4;
594
+ break;
595
+ }
596
+
597
+ //set gridDim.x
598
+ switch(N){
599
+ case 128:
600
+ switch (M){
601
+ case 16384:
602
+ gridDim.x = 128;
603
+ break;
604
+ case 8192:
605
+ gridDim.x = 64;
606
+ break;
607
+ case 4096:
608
+ gridDim.x = 32;
609
+ break;
610
+ default:
611
+ gridDim.x = 256;
612
+ break;
613
+ }
614
+ break;
615
+ default:
616
+ switch (M){
617
+ case 16384:
618
+ gridDim.x = 256;
619
+ break;
620
+ case 8192:
621
+ gridDim.x = 128;
622
+ break;
623
+ case 4096:
624
+ gridDim.x = 64;
625
+ break;
626
+ default:
627
+ gridDim.x = 512;
628
+ break;
629
+ }
630
+ break;
631
+ }
632
+
633
+
634
+ switch (N)
635
+ {
636
+ case 16:
637
+ gridDim.z = H;
638
+ butterfly_ifft_bf16_cuda_kernel_16<<<gridDim, blockDim>>>(
639
+ static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
640
+ static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
641
+ static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
642
+ static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
643
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
644
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
645
+ static_cast<__nv_bfloat162 *>(out.data_ptr()),
646
+ out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
647
+ B,
648
+ H,
649
+ N);
650
+ break;
651
+
652
+ case 32:
653
+ gridDim.z = H;
654
+ butterfly_ifft_bf16_cuda_kernel_32<<<gridDim, blockDim>>>(
655
+ static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
656
+ static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
657
+ static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
658
+ static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
659
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
660
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
661
+ static_cast<__nv_bfloat162 *>(out.data_ptr()),
662
+ out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
663
+ B,
664
+ H,
665
+ N);
666
+ break;
667
+ case 64:
668
+ gridDim.z = H / 16;
669
+ cudaFuncSetAttribute(&butterfly_ifft_bf16_cuda_kernel_64, cudaFuncAttributeMaxDynamicSharedMemorySize, 78000);
670
+ butterfly_ifft_bf16_cuda_kernel_64<<<gridDim, blockDim, 78000>>>(
671
+ static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
672
+ static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
673
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
674
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
675
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
676
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
677
+ static_cast<__nv_bfloat162 *>(out.data_ptr()),
678
+ out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
679
+ B,
680
+ H,
681
+ N);
682
+ break;
683
+
684
+ case 128:
685
+ gridDim.z = H / 16;
686
+ cudaFuncSetAttribute(&butterfly_ifft_bf16_cuda_kernel_128, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
687
+ butterfly_ifft_bf16_cuda_kernel_128<<<gridDim, blockDim, 65536 * 2>>>(
688
+ static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
689
+ static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
690
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
691
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
692
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
693
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
694
+ static_cast<__nv_bfloat162 *>(out.data_ptr()),
695
+ out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
696
+ B,
697
+ H,
698
+ N);
699
+ break;
700
+ default:
701
+ printf("Not implemented\n");
702
+ }
703
+
704
+ return out;
705
+ }
overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_padded_cuda.cu ADDED
@@ -0,0 +1,871 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2023 Dan Fu, Hermann Kumbong
2
+
3
+ #include <torch/extension.h>
4
+
5
+ #include <vector>
6
+ #include <stdio.h>
7
+ #include <mma.h>
8
+ #include <cmath>
9
+ #include <cuda_fp16.h>
10
+ #include <cuda_bf16.h>
11
+ #include "shared.h"
12
+
13
+ using namespace nvcuda;
14
+
15
+ template <int K>
16
+ __global__ void butterfly_padded_cuda_kernel_64(
17
+ const __half2 *__restrict__ x,
18
+ const __half2 *__restrict__ x_gate,
19
+ const complex_half_t *__restrict__ d_f,
20
+ const __half2 *__restrict__ twiddle_factors_real,
21
+ const __half2 *__restrict__ twiddle_factors_imag,
22
+ __half2 *__restrict__ out_real,
23
+ __half2 *__restrict__ out_imag,
24
+ uint B,
25
+ uint H,
26
+ int M)
27
+ {
28
+ const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
29
+ const int offset = blockIdx.y * H * M/2 + blockIdx.z * 16 * M/2;
30
+ const int out_offset = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * 16 * 64 * 32 * gridDim.x;
31
+ int idx;
32
+ int t_offset;
33
+ int out_t_offset;
34
+ int shared_offset;
35
+ const int N = 64;
36
+
37
+ extern __shared__ half x_shared[];
38
+ half *d_f_real = &x_shared[K * 16 * N];
39
+ half *d_f_imag = &d_f_real[N * N];
40
+ half *twiddles_real_shared = &d_f_imag[N * N];
41
+ half *twiddles_imag_shared = &twiddles_real_shared[N * N];
42
+ half *out_real_shared = &twiddles_imag_shared[N * N];
43
+ half *out_imag_shared = &out_real_shared[N * N];
44
+
45
+ // #pragma unroll
46
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
47
+ {
48
+ idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
49
+ shared_offset = i * 32 + threadIdx.x;
50
+ reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[idx];
51
+ reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[idx];
52
+
53
+ // #pragma unroll
54
+ shared_offset = i * 64 + threadIdx.x;
55
+ d_f_real[shared_offset] = d_f[shared_offset].real();
56
+ d_f_imag[shared_offset] = d_f[shared_offset].imag();
57
+
58
+ d_f_real[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].real();
59
+ d_f_imag[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].imag();
60
+ }
61
+
62
+ __half2 tmp_real, tmp_imag;
63
+
64
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[4];
65
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real[4];
66
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag[4];
67
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[4];
68
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag[K][4];
69
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[4];
70
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag[4];
71
+
72
+ __syncthreads();
73
+
74
+ for (int i = 0; i < 4; i++)
75
+ {
76
+ wmma::load_matrix_sync(a_frag_real[i], d_f_real + i * N * 16 + threadIdx.y * 16, N);
77
+ wmma::load_matrix_sync(a_frag_imag[i], d_f_imag + i * N * 16 + threadIdx.y * 16, N);
78
+ wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + threadIdx.y * N * 16 + i * 16, N);
79
+ wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + threadIdx.y * N * 16 + i * 16, N);
80
+ }
81
+
82
+ for (int t = 0; t < 16; t++)
83
+ {
84
+ t_offset = t * M/2;
85
+ out_t_offset = t * 64 * 32 * gridDim.x;
86
+
87
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
88
+ {
89
+ if(i < K * 16){
90
+ idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
91
+ shared_offset = i * 32 + threadIdx.x;
92
+ if(x_gate != nullptr){
93
+ reinterpret_cast<__half2 *>(x_shared)[shared_offset] = idx < max_idx ? __hmul2(x[idx + offset + t_offset], x_gate[idx + offset + t_offset]) : __floats2half2_rn(0.0f, 0.0f);
94
+ }
95
+ else{
96
+ reinterpret_cast<__half2 *>(x_shared)[shared_offset] = idx < max_idx ? x[idx + offset + t_offset] : __floats2half2_rn(0.0f, 0.0f);
97
+ }
98
+ }
99
+ }
100
+
101
+ __syncthreads();
102
+
103
+ for (int i = 0; i < K; i++)
104
+ {
105
+ for (int j = 0; j < 4; j++)
106
+ {
107
+ wmma::load_matrix_sync(b_frag[i][j], x_shared + i * N * 16 + j * 16, N);
108
+ }
109
+ }
110
+
111
+ #pragma unroll
112
+ for (int j = 0; j < 4; j++)
113
+ {
114
+ wmma::fill_fragment(acc_frag_real[j], __float2half(0.0f));
115
+
116
+ for (int k = 0; k < K; k++)
117
+ {
118
+ wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
119
+ }
120
+ }
121
+
122
+ #pragma unroll
123
+
124
+ for (int j = 0; j < 4; j++)
125
+ {
126
+ wmma::fill_fragment(acc_frag_imag[j], __float2half(0.0f));
127
+
128
+ for (int k = 0; k < K; k++)
129
+ {
130
+ wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
131
+ }
132
+ }
133
+
134
+ #pragma unroll
135
+ for (int j = 0; j < 4; j++)
136
+ {
137
+ for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
138
+ {
139
+ tmp_real = reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k];
140
+ tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k];
141
+ reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]));
142
+ reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]));
143
+ }
144
+ }
145
+
146
+ for (int j = 0; j < 4; j++)
147
+ {
148
+ wmma::store_matrix_sync(out_real_shared + threadIdx.y * N * 16 + j * 16, acc_frag_real[j], N, wmma::mem_row_major);
149
+ wmma::store_matrix_sync(out_imag_shared + threadIdx.y * N * 16 + j * 16, acc_frag_imag[j], N, wmma::mem_row_major);
150
+ }
151
+
152
+ __syncthreads();
153
+
154
+ #pragma unroll
155
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
156
+ {
157
+ idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
158
+ shared_offset = i * 32 + threadIdx.x;
159
+
160
+ out_real[out_offset + out_t_offset + idx] = reinterpret_cast<__half2 *>(out_real_shared)[shared_offset];
161
+ out_imag[out_offset + out_t_offset + idx] = reinterpret_cast<__half2 *>(out_imag_shared)[shared_offset];
162
+ }
163
+
164
+ __syncthreads();
165
+
166
+ }
167
+ }
168
+
169
+
170
+ template <int K>
171
+ __global__ void butterfly_padded_cuda_kernel_128(
172
+ const __half2 *__restrict__ x,
173
+ const __half2 *__restrict__ x_gate,
174
+ const complex_half_t *__restrict__ d_f,
175
+ const __half2 *__restrict__ twiddle_factors_real,
176
+ const __half2 *__restrict__ twiddle_factors_imag,
177
+ __half2 *__restrict__ out_real,
178
+ __half2 *__restrict__ out_imag,
179
+ uint B,
180
+ uint H,
181
+ int M)
182
+ {
183
+ const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
184
+ const int offset = blockIdx.y * H * M/2 + blockIdx.z * 16 * M/2;
185
+ const int out_offset = blockIdx.y * H * 128 * 32 * 2 * gridDim.x + blockIdx.z * 16 * 128 * 32 * 2 * gridDim.x;
186
+ const int N = 128;
187
+ int idx;
188
+ int t_offset;
189
+ int out_t_offset;
190
+ int shared_offset;
191
+
192
+ extern __shared__ half shared_real[];
193
+ half *shared_imag = &shared_real[128 * 128];
194
+
195
+
196
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[8];
197
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real[8];
198
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag[8];
199
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[8];
200
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag[K][8];
201
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[8];
202
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag[8];
203
+
204
+ for (int i = threadIdx.y ; i < N; i+=blockDim.y)
205
+ {
206
+ for(int j=0; j< 4; j++){
207
+ shared_offset = i * 128 + threadIdx.x + j * blockDim.x;
208
+ shared_real[shared_offset] = d_f[shared_offset].real();
209
+ shared_imag[shared_offset] = d_f[shared_offset].imag();
210
+ }
211
+ }
212
+
213
+ __syncthreads();
214
+
215
+
216
+ for (int i = 0; i < 8; i++){
217
+ wmma::load_matrix_sync(a_frag_real[i], shared_real + i * 128 * 16 + threadIdx.y * 16, 128);
218
+ wmma::load_matrix_sync(a_frag_imag[i], shared_imag + i * 128 * 16 + threadIdx.y * 16, 128);
219
+ }
220
+
221
+
222
+ __syncthreads();
223
+
224
+
225
+
226
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
227
+ {
228
+ for(int j=0; j< 2; j++){
229
+ idx = i * 32 * 2 * gridDim.x + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
230
+ shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
231
+ reinterpret_cast<__half2*>(shared_real)[shared_offset] = twiddle_factors_real[idx];
232
+ reinterpret_cast<__half2*>(shared_imag)[shared_offset] = twiddle_factors_imag[idx];
233
+ }
234
+ }
235
+
236
+ __syncthreads();
237
+
238
+
239
+ for (int i = 0; i < 8; i++){
240
+ wmma::load_matrix_sync(tw_frag_real[i], shared_real + threadIdx.y * 128 * 16 + i * 16, 128);
241
+ wmma::load_matrix_sync(tw_frag_imag[i], shared_imag + threadIdx.y * 128 * 16 + i * 16, 128);
242
+ }
243
+
244
+ __syncthreads();
245
+
246
+
247
+ for(int t=0; t< 16; t++){
248
+ t_offset = t * M/2;
249
+ out_t_offset = t * 128 * 32 * 2 * gridDim.x;
250
+
251
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
252
+ {
253
+ if(i < K * 16){
254
+ for(int j=0; j< 2; j++){
255
+ idx = i * 32 * 2 * gridDim.x + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
256
+ shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
257
+ if(x_gate != nullptr){
258
+ reinterpret_cast<__half2*>(shared_real)[shared_offset] = idx < max_idx ? __hmul2(x[idx + offset + t_offset], x_gate[idx + offset + t_offset]) : __floats2half2_rn(0.0f, 0.0f);
259
+ }
260
+ else{
261
+ reinterpret_cast<__half2*>(shared_real)[shared_offset] = idx < max_idx ? x[idx + offset + t_offset] : __floats2half2_rn(0.0f, 0.0f);
262
+ }
263
+ }
264
+ }
265
+ }
266
+
267
+
268
+ __syncthreads();
269
+
270
+
271
+ for (int i = 0; i < K; i++)
272
+ {
273
+ for (int j = 0; j < 8; j++)
274
+ {
275
+ wmma::load_matrix_sync(b_frag[i][j], shared_real + i * 128 * 16 + j * 16, 128);
276
+ }
277
+ }
278
+
279
+ __syncthreads();
280
+
281
+ #pragma unroll
282
+ for (int j = 0; j < 8; j++)
283
+ {
284
+ wmma::fill_fragment(acc_frag_real[j], __float2half(0.0f));
285
+
286
+ for (int k = 0; k < K; k++)
287
+ {
288
+ wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
289
+ }
290
+ }
291
+
292
+ #pragma unroll
293
+
294
+ for (int j = 0; j < 8; j++)
295
+ {
296
+ wmma::fill_fragment(acc_frag_imag[j], __float2half(0.0f));
297
+
298
+ for (int k = 0; k < K; k++)
299
+ {
300
+ wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
301
+ }
302
+ }
303
+
304
+ __half2 tmp_real, tmp_imag;
305
+ #pragma unroll
306
+ for (int j = 0; j < 8; j++)
307
+ {
308
+ for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
309
+ {
310
+ tmp_real = reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k];
311
+ tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k];
312
+ reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]));
313
+ reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]));
314
+ }
315
+ }
316
+
317
+ for (int j = 0; j < 8; j++)
318
+ {
319
+ wmma::store_matrix_sync(shared_real + threadIdx.y * 128 * 16 + j * 16, acc_frag_real[j], 128, wmma::mem_row_major);
320
+ wmma::store_matrix_sync(shared_imag + threadIdx.y * 128 * 16 + j * 16, acc_frag_imag[j], 128, wmma::mem_row_major);
321
+ }
322
+
323
+ __syncthreads();
324
+
325
+ #pragma unroll
326
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
327
+ {
328
+ for(int j=0; j< 2; j++){
329
+ idx = i * 32 * 2 * gridDim.x + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
330
+ shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
331
+
332
+ out_real[idx + out_offset + out_t_offset] = reinterpret_cast<__half2*>(shared_real)[shared_offset];
333
+ out_imag[idx + out_offset + out_t_offset] = reinterpret_cast<__half2*>(shared_imag)[shared_offset];
334
+
335
+ }
336
+ }
337
+
338
+ __syncthreads();
339
+ }
340
+ }
341
+
342
+ template <int K>
343
+ __global__ void butterfly_padded_cuda_kernel_32(
344
+ const __half2 *__restrict__ x,
345
+ const __half2 *__restrict__ x_gate,
346
+ const complex_half_t *__restrict__ d_f,
347
+ const __half2 *__restrict__ twiddle_factors_real,
348
+ const __half2 *__restrict__ twiddle_factors_imag,
349
+ __half2 *__restrict__ out_real,
350
+ __half2 *__restrict__ out_imag,
351
+ uint B,
352
+ uint H,
353
+ int M)
354
+ {
355
+ const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
356
+ const int N = 32;
357
+ __shared__ half x_shared[K * 16 * 64];
358
+ __shared__ half d_f_real[32 * 32];
359
+ __shared__ half d_f_imag[32 * 32];
360
+ __shared__ half twiddles_real_shared[32 * 64];
361
+ __shared__ half twiddles_imag_shared[32 * 64];
362
+ __shared__ half out_real_shared[32 * 64];
363
+ __shared__ half out_imag_shared[32 * 64];
364
+
365
+ const int offset = blockIdx.y * H * M / 2 + blockIdx.z * M / 2;
366
+ const int out_offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x;
367
+
368
+
369
+ for(int i = threadIdx.y; i<32; i+=blockDim.y){
370
+ int idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
371
+ int shared_offset = i * 32 + threadIdx.x;
372
+
373
+ if(i < K * 16){
374
+ if(x_gate != nullptr){
375
+ reinterpret_cast<__half2*>(x_shared)[shared_offset] = idx < max_idx ? __hmul2(x[offset + idx], x_gate[offset + idx]) : __floats2half2_rn(0.0f, 0.0f);
376
+ }
377
+ else{
378
+ reinterpret_cast<__half2*>(x_shared)[shared_offset] = idx < max_idx ? x[offset + idx] : __floats2half2_rn(0.0f, 0.0f);
379
+ }
380
+ }
381
+ reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[idx];
382
+ reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[idx];
383
+
384
+ // #pragma unroll
385
+ d_f_real[shared_offset] = d_f[shared_offset].real();
386
+ d_f_imag[shared_offset] = d_f[shared_offset].imag();
387
+ }
388
+
389
+
390
+ __syncthreads();
391
+
392
+
393
+ if (threadIdx.y < N / 16)
394
+ {
395
+ __half2 tmp_real, tmp_imag;
396
+
397
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[2][2];
398
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real[2][2];
399
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag[2][2];
400
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[2][2];
401
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag[K][2];
402
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[2][2];
403
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag[2][2];
404
+
405
+ int t = threadIdx.y * 32;
406
+
407
+ for (int i = 0; i < 2; i++)
408
+ {
409
+ for (int j = 0; j < 2; j++)
410
+ {
411
+ wmma::load_matrix_sync(a_frag_real[i][j], d_f_real + j * N * 16 + i * 16, N);
412
+ wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag + j * N * 16 + i * 16, N);
413
+ if(i<K){
414
+ wmma::load_matrix_sync(b_frag[i][j], x_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
415
+ }
416
+ wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
417
+ wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
418
+ }
419
+ }
420
+
421
+ #pragma unroll
422
+ for (int i = 0; i < 2; i++)
423
+ {
424
+ for (int j = 0; j < 2; j++)
425
+ {
426
+ wmma::fill_fragment(acc_frag_real[i][j], __float2half(0.0f));
427
+
428
+ for (int k = 0; k < K; k++)
429
+ {
430
+ wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag[k][j], acc_frag_real[i][j]);
431
+ }
432
+ }
433
+ }
434
+
435
+ #pragma unroll
436
+ for (int i = 0; i < 2; i++)
437
+ {
438
+ for (int j = 0; j < 2; j++)
439
+ {
440
+ wmma::fill_fragment(acc_frag_imag[i][j], __float2half(0.0f));
441
+
442
+ for (int k = 0; k < K; k++)
443
+ {
444
+ wmma::mma_sync(acc_frag_imag[i][j], a_frag_imag[i][k], b_frag[k][j], acc_frag_imag[i][j]);
445
+ }
446
+ }
447
+ }
448
+
449
+ #pragma unroll
450
+ for (int i = 0; i < 2; i++)
451
+ {
452
+ for (int j = 0; j < 2; j++)
453
+ {
454
+ for (int k = 0; k < acc_frag_real[i][j].num_elements / 2; k++)
455
+ {
456
+ tmp_real = reinterpret_cast<__half2 *>(acc_frag_real[i][j].x)[k];
457
+ tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag[i][j].x)[k];
458
+ reinterpret_cast<__half2 *>(acc_frag_real[i][j].x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real[i][j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag[i][j].x)[k]));
459
+ reinterpret_cast<__half2 *>(acc_frag_imag[i][j].x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag[i][j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real[i][j].x)[k]));
460
+ }
461
+ }
462
+ }
463
+
464
+ for (int i = 0; i < 2; i++)
465
+ {
466
+ for (int j = 0; j < 2; j++)
467
+ {
468
+ wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
469
+ wmma::store_matrix_sync(out_imag_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_imag[i][j], 2 * N, wmma::mem_row_major);
470
+ }
471
+ }
472
+ }
473
+
474
+ __syncthreads();
475
+
476
+ // int idx = offset + threadIdx.y * 32 + blockIdx.x * 32 + threadIdx.x;
477
+ for(int i = threadIdx.y; i<32; i+=blockDim.y){
478
+ int idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
479
+ out_real[out_offset + idx] = reinterpret_cast<__half2*>(out_real_shared)[i * 32 + threadIdx.x];
480
+ out_imag[out_offset + idx] = reinterpret_cast<__half2*>(out_imag_shared)[i * 32 + threadIdx.x];
481
+ }
482
+ }
483
+
484
+
485
+ __global__ void butterfly_padded_cuda_kernel_16(
486
+ const __half2 *__restrict__ x,
487
+ const __half2 *__restrict__ x_gate,
488
+ const complex_half_t *__restrict__ d_f,
489
+ const __half2 *__restrict__ twiddle_factors_real,
490
+ const __half2 *__restrict__ twiddle_factors_imag,
491
+ __half2 *__restrict__ out_real,
492
+ __half2 *__restrict__ out_imag,
493
+ uint B,
494
+ uint H,
495
+ int M)
496
+ {
497
+ const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
498
+ const int N = 16;
499
+ const int offset = blockIdx.y * H * M / 2 + blockIdx.z * M / 2;
500
+ const int out_offset = blockIdx.y * H * N * blockDim.x * gridDim.x + blockIdx.z * N * blockDim.x * gridDim.x;
501
+
502
+
503
+
504
+ __shared__ half x_shared[N * 64];
505
+ __shared__ half d_f_real[N * N];
506
+ __shared__ half d_f_imag[N * N];
507
+ __shared__ half twiddles_real_shared[N * 64];
508
+ __shared__ half twiddles_imag_shared[N * 64];
509
+ __shared__ half out_real_shared[N * 64];
510
+ __shared__ half out_imag_shared[N * 64];
511
+
512
+ // #pragma unroll
513
+ for(int i = threadIdx.y; i<N; i+=blockDim.y){
514
+ int idx = i * blockDim.x * gridDim.x + blockIdx.x * blockDim.x + threadIdx.x;
515
+ int shared_offset = i * blockDim.x + threadIdx.x;
516
+
517
+ if(x_gate != NULL){
518
+ reinterpret_cast<__half2 *>(x_shared)[shared_offset] = idx < max_idx ? __hmul2(x[idx + offset], x_gate[idx + offset]) : __floats2half2_rn(0.0f, 0.0f);
519
+ }
520
+ else{
521
+ reinterpret_cast<__half2 *>(x_shared)[shared_offset] = idx < max_idx ? x[idx + offset] : __floats2half2_rn(0.0f, 0.0f);
522
+ }
523
+ reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[idx];
524
+ reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[idx];
525
+
526
+ // #pragma unroll
527
+
528
+ if(threadIdx.x < 16 ){
529
+ shared_offset = i * 16 + threadIdx.x;
530
+ d_f_real[shared_offset] = d_f[shared_offset].real();
531
+ d_f_imag[shared_offset] = d_f[shared_offset].imag();
532
+ }
533
+ }
534
+
535
+ __syncthreads();
536
+
537
+ if (threadIdx.y < 4)
538
+ {
539
+ __half2 tmp_real, tmp_imag;
540
+
541
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real;
542
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real;
543
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag;
544
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag;
545
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
546
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real;
547
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag;
548
+
549
+ wmma::load_matrix_sync(a_frag_real, d_f_real, N);
550
+ wmma::load_matrix_sync(a_frag_imag, d_f_imag, N);
551
+ wmma::load_matrix_sync(b_frag, x_shared + threadIdx.y * 16, 64);
552
+ wmma::load_matrix_sync(tw_frag_real, twiddles_real_shared + threadIdx.y * 16, 64);
553
+ wmma::load_matrix_sync(tw_frag_imag, twiddles_imag_shared + threadIdx.y * 16, 64);
554
+
555
+
556
+ wmma::fill_fragment(acc_frag_real, __float2half(0.0f));
557
+
558
+
559
+ wmma::mma_sync(acc_frag_real, a_frag_real, b_frag, acc_frag_real);
560
+
561
+
562
+ wmma::fill_fragment(acc_frag_imag, __float2half(0.0f));
563
+
564
+
565
+ wmma::mma_sync(acc_frag_imag, a_frag_imag, b_frag, acc_frag_imag);
566
+
567
+
568
+
569
+ for (int k = 0; k < acc_frag_real.num_elements / 2; k++)
570
+ {
571
+ tmp_real = reinterpret_cast<__half2 *>(acc_frag_real.x)[k];
572
+ tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag.x)[k];
573
+ reinterpret_cast<__half2 *>(acc_frag_real.x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real.x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag.x)[k]));
574
+ reinterpret_cast<__half2 *>(acc_frag_imag.x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag.x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real.x)[k]));
575
+ }
576
+
577
+ wmma::store_matrix_sync(out_real_shared + threadIdx.y * 16, acc_frag_real, 64, wmma::mem_row_major);
578
+ wmma::store_matrix_sync(out_imag_shared + threadIdx.y * 16, acc_frag_imag, 64, wmma::mem_row_major);
579
+ }
580
+
581
+ __syncthreads();
582
+
583
+ #pragma unroll
584
+ for (int i = threadIdx.y; i<N; i+=blockDim.y)
585
+ {
586
+ int idx = i * blockDim.x * gridDim.x + blockIdx.x * blockDim.x + threadIdx.x;
587
+ out_real[out_offset + idx] = reinterpret_cast<__half2 *>(out_real_shared)[i * 32 + threadIdx.x];
588
+ out_imag[out_offset + idx] = reinterpret_cast<__half2 *>(out_imag_shared)[i * 32 + threadIdx.x];
589
+ }
590
+ }
591
+
592
+ std::vector<torch::Tensor> butterfly_padded_cuda(
593
+ torch::Tensor x,
594
+ torch::Tensor d_f,
595
+ torch::Tensor twiddle_factors_real,
596
+ torch::Tensor twiddle_factors_imag,
597
+ int M,
598
+ std::optional<at::Tensor> x_gate = std::nullopt
599
+ )
600
+ {
601
+
602
+ uint B = x.size(0);
603
+ uint H = x.size(1);
604
+ uint N = x.size(2);
605
+
606
+ uint d_f_size = d_f.size(1);
607
+
608
+ //need to make sure that N is less that the M to which we are padding
609
+ assert(N <= d_f_size * M);
610
+ // printf("B: %d, H: %d, N: %d\n", B, H, N);
611
+ dim3 gridDim;
612
+ dim3 blockDim;
613
+
614
+ gridDim.y = B;
615
+ gridDim.z = H;
616
+
617
+ blockDim.x = 32;
618
+ blockDim.y = 4;
619
+
620
+ torch::Tensor out_real = torch::empty({B, H, d_f_size * M}, x.options());
621
+ torch::Tensor out_imag = torch::empty({B, H, d_f_size * M}, x.options());
622
+
623
+ gridDim.x = 512 / (32 * 1024/ M);
624
+
625
+ const int K = ceil(N / (1.0 * 16 * M));
626
+
627
+
628
+ switch(d_f_size){
629
+ case 16:
630
+ butterfly_padded_cuda_kernel_16<<<gridDim, blockDim>>>(
631
+ static_cast<__half2 *>(x.data_ptr()),
632
+ x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
633
+ static_cast<complex_half_t *>(d_f.data_ptr()),
634
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
635
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
636
+ static_cast<__half2 *>(out_real.data_ptr()),
637
+ static_cast<__half2 *>(out_imag.data_ptr()),
638
+ B,
639
+ H,
640
+ N);
641
+ break;
642
+ case 32:
643
+ switch (K)
644
+ {
645
+ case 1:
646
+ butterfly_padded_cuda_kernel_32<1><<<gridDim, blockDim>>>(
647
+ static_cast<__half2 *>(x.data_ptr()),
648
+ x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
649
+ static_cast<complex_half_t *>(d_f.data_ptr()),
650
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
651
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
652
+ static_cast<__half2 *>(out_real.data_ptr()),
653
+ static_cast<__half2 *>(out_imag.data_ptr()),
654
+ B,
655
+ H,
656
+ N);
657
+ break;
658
+ case 2:
659
+ butterfly_padded_cuda_kernel_32<2><<<gridDim, blockDim>>>(
660
+ static_cast<__half2 *>(x.data_ptr()),
661
+ x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
662
+ static_cast<complex_half_t *>(d_f.data_ptr()),
663
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
664
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
665
+ static_cast<__half2 *>(out_real.data_ptr()),
666
+ static_cast<__half2 *>(out_imag.data_ptr()),
667
+ B,
668
+ H,
669
+ N);
670
+ break;
671
+ default:
672
+ printf("Invalid K, df size 32: %d\n", K);
673
+ }
674
+ break;
675
+ case 64:
676
+ gridDim.z = H / 16;
677
+
678
+ switch (K)
679
+ {
680
+ case 1:
681
+ cudaFuncSetAttribute(&butterfly_padded_cuda_kernel_64<1>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
682
+ butterfly_padded_cuda_kernel_64<1><<<gridDim, blockDim, 65536>>>(
683
+ static_cast<__half2 *>(x.data_ptr()),
684
+ x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
685
+ static_cast<complex_half_t *>(d_f.data_ptr()),
686
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
687
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
688
+ static_cast<__half2 *>(out_real.data_ptr()),
689
+ static_cast<__half2 *>(out_imag.data_ptr()),
690
+ B,
691
+ H,
692
+ N);
693
+ break;
694
+
695
+ case 2:
696
+ cudaFuncSetAttribute(&butterfly_padded_cuda_kernel_64<2>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
697
+ butterfly_padded_cuda_kernel_64<2><<<gridDim, blockDim, 65536>>>(
698
+ static_cast<__half2 *>(x.data_ptr()),
699
+ x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
700
+ static_cast<complex_half_t *>(d_f.data_ptr()),
701
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
702
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
703
+ static_cast<__half2 *>(out_real.data_ptr()),
704
+ static_cast<__half2 *>(out_imag.data_ptr()),
705
+ B,
706
+ H,
707
+ N);
708
+ break;
709
+
710
+ case 3:
711
+ cudaFuncSetAttribute(&butterfly_padded_cuda_kernel_64<3>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
712
+ butterfly_padded_cuda_kernel_64<3><<<gridDim, blockDim, 65536>>>(
713
+ static_cast<__half2 *>(x.data_ptr()),
714
+ x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
715
+ static_cast<complex_half_t *>(d_f.data_ptr()),
716
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
717
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
718
+ static_cast<__half2 *>(out_real.data_ptr()),
719
+ static_cast<__half2 *>(out_imag.data_ptr()),
720
+ B,
721
+ H,
722
+ N);
723
+ break;
724
+
725
+ case 4:
726
+ cudaFuncSetAttribute(&butterfly_padded_cuda_kernel_64<4>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
727
+ butterfly_padded_cuda_kernel_64<4><<<gridDim, blockDim, 65536>>>(
728
+ static_cast<__half2 *>(x.data_ptr()),
729
+ x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
730
+ static_cast<complex_half_t *>(d_f.data_ptr()),
731
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
732
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
733
+ static_cast<__half2 *>(out_real.data_ptr()),
734
+ static_cast<__half2 *>(out_imag.data_ptr()),
735
+ B,
736
+ H,
737
+ N);
738
+ break;
739
+
740
+ default:
741
+ printf("Invalid K, df size 64: %d\n", K);
742
+ }
743
+ break;
744
+ case 128:
745
+ blockDim.x = 32;
746
+ blockDim.y = 8;
747
+ gridDim.x = 256 / (32 * 1024/ M);
748
+ gridDim.z = H / 16;
749
+
750
+ switch(K){
751
+ case 1:
752
+ cudaFuncSetAttribute(&butterfly_padded_cuda_kernel_128<1>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
753
+ butterfly_padded_cuda_kernel_128<1><<<gridDim, blockDim, 65536>>>(
754
+ static_cast<__half2 *>(x.data_ptr()),
755
+ x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
756
+ static_cast<complex_half_t *>(d_f.data_ptr()),
757
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
758
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
759
+ static_cast<__half2 *>(out_real.data_ptr()),
760
+ static_cast<__half2 *>(out_imag.data_ptr()),
761
+ B,
762
+ H,
763
+ N);
764
+ break;
765
+ case 2:
766
+ cudaFuncSetAttribute(&butterfly_padded_cuda_kernel_128<2>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
767
+ butterfly_padded_cuda_kernel_128<2><<<gridDim, blockDim, 65536>>>(
768
+ static_cast<__half2 *>(x.data_ptr()),
769
+ x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
770
+ static_cast<complex_half_t *>(d_f.data_ptr()),
771
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
772
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
773
+ static_cast<__half2 *>(out_real.data_ptr()),
774
+ static_cast<__half2 *>(out_imag.data_ptr()),
775
+ B,
776
+ H,
777
+ N);
778
+ break;
779
+ case 3:
780
+ cudaFuncSetAttribute(&butterfly_padded_cuda_kernel_128<3>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
781
+ butterfly_padded_cuda_kernel_128<3><<<gridDim, blockDim, 65536>>>(
782
+ static_cast<__half2 *>(x.data_ptr()),
783
+ x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
784
+ static_cast<complex_half_t *>(d_f.data_ptr()),
785
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
786
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
787
+ static_cast<__half2 *>(out_real.data_ptr()),
788
+ static_cast<__half2 *>(out_imag.data_ptr()),
789
+ B,
790
+ H,
791
+ N);
792
+ break;
793
+ case 4:
794
+ cudaFuncSetAttribute(&butterfly_padded_cuda_kernel_128<4>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
795
+ butterfly_padded_cuda_kernel_128<4><<<gridDim, blockDim, 65536>>>(
796
+ static_cast<__half2 *>(x.data_ptr()),
797
+ x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
798
+ static_cast<complex_half_t *>(d_f.data_ptr()),
799
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
800
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
801
+ static_cast<__half2 *>(out_real.data_ptr()),
802
+ static_cast<__half2 *>(out_imag.data_ptr()),
803
+ B,
804
+ H,
805
+ N);
806
+ break;
807
+ case 5:
808
+ cudaFuncSetAttribute(&butterfly_padded_cuda_kernel_128<5>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
809
+ butterfly_padded_cuda_kernel_128<5><<<gridDim, blockDim, 65536>>>(
810
+ static_cast<__half2 *>(x.data_ptr()),
811
+ x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
812
+ static_cast<complex_half_t *>(d_f.data_ptr()),
813
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
814
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
815
+ static_cast<__half2 *>(out_real.data_ptr()),
816
+ static_cast<__half2 *>(out_imag.data_ptr()),
817
+ B,
818
+ H,
819
+ N);
820
+ break;
821
+ case 6:
822
+ cudaFuncSetAttribute(&butterfly_padded_cuda_kernel_128<6>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
823
+ butterfly_padded_cuda_kernel_128<6><<<gridDim, blockDim, 65536>>>(
824
+ static_cast<__half2 *>(x.data_ptr()),
825
+ x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
826
+ static_cast<complex_half_t *>(d_f.data_ptr()),
827
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
828
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
829
+ static_cast<__half2 *>(out_real.data_ptr()),
830
+ static_cast<__half2 *>(out_imag.data_ptr()),
831
+ B,
832
+ H,
833
+ N);
834
+ break;
835
+ case 7:
836
+ cudaFuncSetAttribute(&butterfly_padded_cuda_kernel_128<7>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
837
+ butterfly_padded_cuda_kernel_128<7><<<gridDim, blockDim, 65536>>>(
838
+ static_cast<__half2 *>(x.data_ptr()),
839
+ x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
840
+ static_cast<complex_half_t *>(d_f.data_ptr()),
841
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
842
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
843
+ static_cast<__half2 *>(out_real.data_ptr()),
844
+ static_cast<__half2 *>(out_imag.data_ptr()),
845
+ B,
846
+ H,
847
+ N);
848
+ break;
849
+ case 8:
850
+ cudaFuncSetAttribute(&butterfly_padded_cuda_kernel_128<8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
851
+ butterfly_padded_cuda_kernel_128<8><<<gridDim, blockDim, 65536>>>(
852
+ static_cast<__half2 *>(x.data_ptr()),
853
+ x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
854
+ static_cast<complex_half_t *>(d_f.data_ptr()),
855
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
856
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
857
+ static_cast<__half2 *>(out_real.data_ptr()),
858
+ static_cast<__half2 *>(out_imag.data_ptr()),
859
+ B,
860
+ H,
861
+ N);
862
+ break;
863
+ default:
864
+ printf("Invalid K, df size 128: %d\n", K);
865
+ }
866
+ break;
867
+ default:
868
+ printf("Invalid d_f size: %d\n", d_f_size);
869
+ }
870
+ return {out_real, out_imag};
871
+ }
overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_padded_cuda_bf16.cu ADDED
@@ -0,0 +1,897 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2023 Dan Fu, Hermann Kumbong
2
+
3
+ #include <torch/extension.h>
4
+
5
+ #include <vector>
6
+ #include <stdio.h>
7
+ #include <mma.h>
8
+ #include <cuda_runtime.h>
9
+ #include <cuda_fp16.h>
10
+ #include <cuda_bf16.h>
11
+ #include "shared.h"
12
+
13
+ using namespace nvcuda;
14
+
15
+
16
+ template <int K>
17
+ __global__ void butterfly_cuda_kernel_64(
18
+ const __nv_bfloat162 *__restrict__ x,
19
+ const __nv_bfloat162 *__restrict__ x_gate,
20
+ const __nv_bfloat162 *__restrict__ d_f_real,
21
+ const __nv_bfloat162 *__restrict__ d_f_imag,
22
+ const __nv_bfloat162 *__restrict__ twiddle_factors_real,
23
+ const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
24
+ __nv_bfloat162 *__restrict__ out_real,
25
+ __nv_bfloat162 *__restrict__ out_imag,
26
+ uint B,
27
+ uint H,
28
+ int M)
29
+ {
30
+ const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
31
+ const int offset = blockIdx.y * H * M/2 + blockIdx.z * 16 * M/2;
32
+ const int out_offset = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * 16 * 64 * 32 * gridDim.x;
33
+ int idx;
34
+ int t_offset;
35
+ int out_t_offset;
36
+ int shared_offset;
37
+ const int N = 64;
38
+
39
+
40
+ extern __shared__ __nv_bfloat16 x_shared[];
41
+ __nv_bfloat16 *d_f_real_shared = &x_shared[K * 16 * N];
42
+ __nv_bfloat16 *d_f_imag_shared = &d_f_real_shared[N * N];
43
+ __nv_bfloat16 *twiddles_real_shared = &d_f_imag_shared[N * N];
44
+ __nv_bfloat16 *twiddles_imag_shared = &twiddles_real_shared[N * N];
45
+ float *out_real_shared = reinterpret_cast<float*>(&twiddles_imag_shared[N * N]);
46
+ float *out_imag_shared = &out_real_shared[N * N];
47
+
48
+ // #pragma unroll
49
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
50
+ {
51
+ idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
52
+ shared_offset = i * 32 + threadIdx.x;
53
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[idx];
54
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[idx];
55
+
56
+ // #pragma unroll
57
+ shared_offset = i * 32 + threadIdx.x;
58
+ reinterpret_cast<__nv_bfloat162 *>(d_f_real_shared)[shared_offset] = d_f_real[shared_offset];
59
+ reinterpret_cast<__nv_bfloat162 *>(d_f_imag_shared)[shared_offset] = d_f_imag[shared_offset];
60
+ }
61
+
62
+ float2 tmp_real, tmp_imag;
63
+
64
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[4];
65
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[4];
66
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[4];
67
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[4];
68
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag[4][4];
69
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[4];
70
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_imag[4];
71
+
72
+ __syncthreads();
73
+
74
+ for (int i = 0; i < 4; i++)
75
+ {
76
+ wmma::load_matrix_sync(a_frag_real[i], d_f_real_shared + i * N * 16 + threadIdx.y * 16, N);
77
+ wmma::load_matrix_sync(a_frag_imag[i], d_f_imag_shared + i * N * 16 + threadIdx.y * 16, N);
78
+ wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + threadIdx.y * N * 16 + i * 16, N);
79
+ wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + threadIdx.y * N * 16 + i * 16, N);
80
+ }
81
+
82
+ for (int t = 0; t < 16; t++)
83
+ {
84
+ t_offset = t * M/2;
85
+ out_t_offset = t * 64 * 32 * gridDim.x;
86
+
87
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
88
+ {
89
+ if(i < K * 16){
90
+ idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
91
+ shared_offset = i * 32 + threadIdx.x;
92
+ if(x_gate != nullptr){
93
+ reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = idx < max_idx ? __hmul2(x[idx + offset + t_offset], x_gate[idx + offset + t_offset]) : __floats2bfloat162_rn(0.0f, 0.0f);
94
+ }else{
95
+ reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = idx < max_idx ? x[idx + offset + t_offset] : __floats2bfloat162_rn(0.0f, 0.0f);
96
+ }
97
+ }
98
+ }
99
+
100
+ __syncthreads();
101
+
102
+ for (int i = 0; i < K; i++)
103
+ {
104
+ for (int j = 0; j < 4; j++)
105
+ {
106
+ wmma::load_matrix_sync(b_frag[i][j], x_shared + i * N * 16 + j * 16, N);
107
+ }
108
+ }
109
+
110
+ #pragma unroll
111
+ for (int j = 0; j < 4; j++)
112
+ {
113
+ wmma::fill_fragment(acc_frag_real[j], 0.0f);
114
+
115
+ for (int k = 0; k < K; k++)
116
+ {
117
+ wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
118
+ }
119
+ }
120
+
121
+ #pragma unroll
122
+
123
+ for (int j = 0; j < 4; j++)
124
+ {
125
+ wmma::fill_fragment(acc_frag_imag[j], 0.0f);
126
+
127
+ for (int k = 0; k < K; k++)
128
+ {
129
+ wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
130
+ }
131
+ }
132
+
133
+ #pragma unroll
134
+ for (int j = 0; j < 4; j++)
135
+ {
136
+ for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
137
+ {
138
+ tmp_real = reinterpret_cast<float2 *>(acc_frag_real[j].x)[k];
139
+ tmp_imag = reinterpret_cast<float2 *>(acc_frag_imag[j].x)[k];
140
+
141
+ reinterpret_cast<float2 *>(acc_frag_real[j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[j].x)[k]) - tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[j].x)[k]);
142
+ reinterpret_cast<float2 *>(acc_frag_imag[j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[j].x)[k]) + tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[j].x)[k]);
143
+ }
144
+
145
+ wmma::store_matrix_sync(out_real_shared + threadIdx.y * N * 16 + j * 16, acc_frag_real[j], N, wmma::mem_row_major);
146
+ wmma::store_matrix_sync(out_imag_shared + threadIdx.y * N * 16 + j * 16, acc_frag_imag[j], N, wmma::mem_row_major);
147
+ }
148
+
149
+ __syncthreads();
150
+
151
+ #pragma unroll
152
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
153
+ {
154
+ idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
155
+ shared_offset = i * 32 + threadIdx.x;
156
+ out_real[out_offset + out_t_offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_real_shared)[shared_offset]);
157
+ out_imag[out_offset + out_t_offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_imag_shared)[shared_offset]);
158
+ }
159
+
160
+ __syncthreads();
161
+ }
162
+ }
163
+
164
+ template <int K>
165
+ __global__ void butterfly_cuda_kernel_32(
166
+ const __nv_bfloat162 *__restrict__ x,
167
+ const __nv_bfloat162 *__restrict__ x_gate,
168
+ const __nv_bfloat16 *__restrict__ d_f_real,
169
+ const __nv_bfloat16 *__restrict__ d_f_imag,
170
+ const __nv_bfloat162 *__restrict__ twiddle_factors_real,
171
+ const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
172
+ __nv_bfloat162 *__restrict__ out_real,
173
+ __nv_bfloat162 *__restrict__ out_imag,
174
+ uint B,
175
+ uint H,
176
+ int M)
177
+ {
178
+ const int N = 32;
179
+ const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
180
+
181
+ const int offset = blockIdx.y * H * M / 2 + blockIdx.z * M / 2;
182
+ const int out_offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x;
183
+
184
+
185
+ __shared__ __nv_bfloat16 x_shared[K * 16 * 64];
186
+ __shared__ __nv_bfloat16 d_f_real_shared[32 * 32];
187
+ __shared__ __nv_bfloat16 d_f_imag_shared[32 * 32];
188
+ __shared__ __nv_bfloat16 twiddles_real_shared[32 * 64];
189
+ __shared__ __nv_bfloat16 twiddles_imag_shared[32 * 64];
190
+ __shared__ float out_real_shared[32 * 64];
191
+ __shared__ float out_imag_shared[32 * 64];
192
+
193
+ // #pragma unroll
194
+ for (int i = threadIdx.y; i<32; i+=blockDim.y)
195
+ {
196
+ int idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
197
+ int shared_offset = i * 32 + threadIdx.x;
198
+
199
+ if(i < K * 16){
200
+ if(x_gate != nullptr){
201
+ reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = idx < max_idx ? __hmul2(x[idx + offset], x_gate[idx + offset]) : __floats2bfloat162_rn(0.0f, 0.0f);
202
+ }else{
203
+ reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = idx < max_idx ? x[idx + offset] : __floats2bfloat162_rn(0.0f, 0.0f);
204
+ }
205
+ }
206
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[idx];
207
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[idx];
208
+
209
+ // #pragma unroll
210
+ d_f_real_shared[shared_offset] = d_f_real[shared_offset];
211
+ d_f_imag_shared[shared_offset] = d_f_imag[shared_offset];
212
+ }
213
+
214
+ __syncthreads();
215
+
216
+ if (threadIdx.y < N / 16)
217
+ {
218
+ float2 tmp_real, tmp_imag;
219
+
220
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[2][2];
221
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[2][2];
222
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[2][2];
223
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[2][2];
224
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag[K][2];
225
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[2][2];
226
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_imag[2][2];
227
+
228
+ int t = threadIdx.y * 32;
229
+
230
+ for (int i = 0; i < 2; i++)
231
+ {
232
+ for (int j = 0; j < 2; j++)
233
+ {
234
+ wmma::load_matrix_sync(a_frag_real[i][j], d_f_real_shared + j * N * 16 + i * 16, N);
235
+ wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag_shared + j * N * 16 + i * 16, N);
236
+ if(i < K){
237
+ wmma::load_matrix_sync(b_frag[i][j], x_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
238
+ }
239
+ wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
240
+ wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
241
+ }
242
+ }
243
+
244
+ #pragma unroll
245
+ for (int i = 0; i < 2; i++)
246
+ {
247
+ for (int j = 0; j < 2; j++)
248
+ {
249
+ wmma::fill_fragment(acc_frag_real[i][j], 0.0f);
250
+
251
+ for (int k = 0; k < K; k++)
252
+ {
253
+ wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag[k][j], acc_frag_real[i][j]);
254
+ }
255
+ }
256
+ }
257
+
258
+ #pragma unroll
259
+ for (int i = 0; i < 2; i++)
260
+ {
261
+ for (int j = 0; j < 2; j++)
262
+ {
263
+ wmma::fill_fragment(acc_frag_imag[i][j], 0.0f);
264
+
265
+ for (int k = 0; k < K; k++)
266
+ {
267
+ wmma::mma_sync(acc_frag_imag[i][j], a_frag_imag[i][k], b_frag[k][j], acc_frag_imag[i][j]);
268
+ }
269
+ }
270
+ }
271
+
272
+ #pragma unroll
273
+ for (int i = 0; i < 2; i++)
274
+ {
275
+ for (int j = 0; j < 2; j++)
276
+ {
277
+ for (int k = 0; k < acc_frag_real[i][j].num_elements / 2; k++)
278
+ {
279
+ tmp_real = reinterpret_cast<float2 *>(acc_frag_real[i][j].x)[k];
280
+ tmp_imag = reinterpret_cast<float2 *>(acc_frag_imag[i][j].x)[k];
281
+ reinterpret_cast<float2 *>(acc_frag_real[i][j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[i][j].x)[k]) - tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[i][j].x)[k]);
282
+ reinterpret_cast<float2 *>(acc_frag_imag[i][j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[i][j].x)[k]) + tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[i][j].x)[k]);
283
+ }
284
+ wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
285
+ wmma::store_matrix_sync(out_imag_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_imag[i][j], 2 * N, wmma::mem_row_major);
286
+ }
287
+ }
288
+ }
289
+
290
+ __syncthreads();
291
+
292
+ #pragma unroll
293
+ for (int i = threadIdx.y; i<32; i+=blockDim.y)
294
+ {
295
+ int idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
296
+ out_real[out_offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_real_shared)[i * 32 + threadIdx.x]);
297
+ out_imag[out_offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_imag_shared)[i * 32 + threadIdx.x]);
298
+ }
299
+ }
300
+
301
+ template <int K>
302
+ __global__ void butterfly_cuda_kernel_128(
303
+ const __nv_bfloat162 *__restrict__ x,
304
+ const __nv_bfloat162 *__restrict__ x_gate,
305
+ const __nv_bfloat162 *__restrict__ d_f_real,
306
+ const __nv_bfloat162 *__restrict__ d_f_imag,
307
+ const __nv_bfloat162 *__restrict__ twiddle_factors_real,
308
+ const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
309
+ __nv_bfloat162 *__restrict__ out_real,
310
+ __nv_bfloat162 *__restrict__ out_imag,
311
+ uint B,
312
+ uint H,
313
+ int M)
314
+ {
315
+ const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
316
+ const int offset = blockIdx.y * H * M/2 + blockIdx.z * 16 * M/2;
317
+ const int out_offset = blockIdx.y * H * 128 * 32 * 2 * gridDim.x + blockIdx.z * 16 * 128 * 32 * 2 * gridDim.x;
318
+ const int N = 128;
319
+ int idx;
320
+ int t_offset;
321
+ int out_t_offset;
322
+ int shared_offset;
323
+
324
+ extern __shared__ __nv_bfloat16 shared_real[];
325
+ __nv_bfloat16 *shared_imag = &shared_real[128 * 128];
326
+
327
+
328
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[8];
329
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[8];
330
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[8];
331
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[8];
332
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag[K][8];
333
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[8];
334
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_imag[8];
335
+
336
+ for (int i = threadIdx.y ; i < N; i+=blockDim.y)
337
+ {
338
+ for(int j=0; j< 2; j++){
339
+ shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
340
+ reinterpret_cast<__nv_bfloat162 *>(shared_real)[shared_offset] = d_f_real[shared_offset];
341
+ reinterpret_cast<__nv_bfloat162 *>(shared_imag)[shared_offset] = d_f_imag[shared_offset];
342
+ }
343
+ }
344
+
345
+ __syncthreads();
346
+
347
+
348
+ for (int i = 0; i < 8; i++){
349
+ wmma::load_matrix_sync(a_frag_real[i], shared_real + i * 128 * 16 + threadIdx.y * 16, 128);
350
+ wmma::load_matrix_sync(a_frag_imag[i], shared_imag + i * 128 * 16 + threadIdx.y * 16, 128);
351
+ }
352
+
353
+
354
+ __syncthreads();
355
+
356
+
357
+
358
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
359
+ {
360
+ for(int j=0; j< 2; j++){
361
+ idx = i * 32 * 2 * gridDim.x + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
362
+ shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
363
+ reinterpret_cast<__nv_bfloat162*>(shared_real)[shared_offset] = twiddle_factors_real[idx];
364
+ reinterpret_cast<__nv_bfloat162*>(shared_imag)[shared_offset] = twiddle_factors_imag[idx];
365
+ }
366
+ }
367
+
368
+ __syncthreads();
369
+
370
+
371
+ for (int i = 0; i < 8; i++){
372
+ wmma::load_matrix_sync(tw_frag_real[i], shared_real + threadIdx.y * 128 * 16 + i * 16, 128);
373
+ wmma::load_matrix_sync(tw_frag_imag[i], shared_imag + threadIdx.y * 128 * 16 + i * 16, 128);
374
+ }
375
+
376
+ __syncthreads();
377
+
378
+
379
+ for(int t=0; t< 16; t++){
380
+ t_offset = t * M/2;
381
+ out_t_offset = t * 128 * 32 * 2 * gridDim.x;
382
+
383
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
384
+ {
385
+ if(i < K * 16){
386
+ for(int j=0; j< 2; j++){
387
+ idx = i * 32 * 2 * gridDim.x + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
388
+ shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
389
+ if(x_gate != nullptr){
390
+ reinterpret_cast<__nv_bfloat162*>(shared_real)[shared_offset] = idx < max_idx ? __hmul2(x[idx + offset + t_offset], x_gate[idx + offset + t_offset]) : __floats2bfloat162_rn(0.0f, 0.0f);
391
+ }else{
392
+ reinterpret_cast<__nv_bfloat162*>(shared_real)[shared_offset] = idx < max_idx ? x[idx + offset + t_offset] : __floats2bfloat162_rn(0.0f, 0.0f);
393
+ }
394
+ }
395
+ }
396
+ }
397
+
398
+
399
+ __syncthreads();
400
+
401
+
402
+ for (int i = 0; i < K; i++)
403
+ {
404
+ for (int j = 0; j < 8; j++)
405
+ {
406
+ wmma::load_matrix_sync(b_frag[i][j], shared_real + i * 128 * 16 + j * 16, 128);
407
+ }
408
+ }
409
+
410
+ __syncthreads();
411
+
412
+ #pragma unroll
413
+ for (int j = 0; j < 8; j++)
414
+ {
415
+ wmma::fill_fragment(acc_frag_real[j], 0.0f);
416
+
417
+ for (int k = 0; k < K; k++)
418
+ {
419
+ wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
420
+ }
421
+ }
422
+
423
+ #pragma unroll
424
+
425
+ for (int j = 0; j < 8; j++)
426
+ {
427
+ wmma::fill_fragment(acc_frag_imag[j], 0.0f);
428
+
429
+ for (int k = 0; k < K; k++)
430
+ {
431
+ wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
432
+ }
433
+ }
434
+
435
+ float2 tmp_real, tmp_imag;
436
+ #pragma unroll
437
+ for (int j = 0; j < 8; j++)
438
+ {
439
+ for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
440
+ {
441
+ tmp_real = reinterpret_cast<float2 *>(acc_frag_real[j].x)[k];
442
+ tmp_imag = reinterpret_cast<float2 *>(acc_frag_imag[j].x)[k];
443
+
444
+ reinterpret_cast<float2 *>(acc_frag_real[j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[j].x)[k]) - tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[j].x)[k]);
445
+ reinterpret_cast<float2 *>(acc_frag_imag[j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[j].x)[k]) + tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[j].x)[k]);
446
+ }
447
+ }
448
+
449
+ for (int j = 0; j < 8; j++)
450
+ {
451
+ wmma::store_matrix_sync(reinterpret_cast<float*>(shared_real) + threadIdx.y * 128 * 16 + j * 16, acc_frag_real[j], 128, wmma::mem_row_major);
452
+ }
453
+
454
+ __syncthreads();
455
+
456
+ #pragma unroll
457
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
458
+ {
459
+ for(int j=0; j< 2; j++){
460
+ idx = i * 32 * 2 * gridDim.x + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
461
+ shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
462
+ out_real[idx + out_offset + out_t_offset] = __float22bfloat162_rn(reinterpret_cast<float2*>(shared_real)[shared_offset]);
463
+ }
464
+ }
465
+
466
+ __syncthreads();
467
+
468
+
469
+ for (int j = 0; j < 8; j++)
470
+ {
471
+ wmma::store_matrix_sync(reinterpret_cast<float*>(shared_real) + threadIdx.y * 128 * 16 + j * 16, acc_frag_imag[j], 128, wmma::mem_row_major);
472
+ }
473
+
474
+ __syncthreads();
475
+
476
+ #pragma unroll
477
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
478
+ {
479
+ for(int j=0; j< 2; j++){
480
+ idx = i * 32 * 2 * gridDim.x + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
481
+ shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
482
+ out_imag[idx + out_offset + out_t_offset] = __float22bfloat162_rn(reinterpret_cast<float2*>(shared_real)[shared_offset]);
483
+ }
484
+ }
485
+ }
486
+ }
487
+
488
+ template<int K>
489
+ __global__ void butterfly_cuda_kernel_16(
490
+ const __nv_bfloat162 *__restrict__ x,
491
+ const __nv_bfloat162 *__restrict__ x_gate,
492
+ const __nv_bfloat16 *__restrict__ d_f_real,
493
+ const __nv_bfloat16 *__restrict__ d_f_imag,
494
+ const __nv_bfloat162 *__restrict__ twiddle_factors_real,
495
+ const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
496
+ __nv_bfloat162 *__restrict__ out_real,
497
+ __nv_bfloat162 *__restrict__ out_imag,
498
+ uint B,
499
+ uint H,
500
+ int M)
501
+ {
502
+ const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
503
+ const int N = 16;
504
+ const int offset = blockIdx.y * H * M / 2 + blockIdx.z * M / 2;
505
+ const int out_offset = blockIdx.y * H * N * blockDim.x * gridDim.x + blockIdx.z * N * blockDim.x * gridDim.x;
506
+
507
+
508
+
509
+ __shared__ __nv_bfloat16 x_shared[N * 64];
510
+ __shared__ __nv_bfloat16 d_f_real_shared[N * N];
511
+ __shared__ __nv_bfloat16 d_f_imag_shared[N * N];
512
+ __shared__ __nv_bfloat16 twiddles_real_shared[N * 64];
513
+ __shared__ __nv_bfloat16 twiddles_imag_shared[N * 64];
514
+ __shared__ float out_real_shared[N * 64];
515
+ __shared__ float out_imag_shared[N * 64];
516
+
517
+ // #pragma unroll
518
+ for (int i = threadIdx.y; i < N; i++)
519
+ {
520
+ int idx = i * blockDim.x * gridDim.x + blockIdx.x * blockDim.x + threadIdx.x;
521
+ int shared_offset = i * blockDim.x + threadIdx.x;
522
+
523
+ if(x_gate != nullptr){
524
+ reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = idx < max_idx ? __hmul2(x[idx + offset], x_gate[idx + offset]) : __floats2bfloat162_rn(0.0f, 0.0f);
525
+ }else{
526
+ reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = idx < max_idx ? x[idx + offset] : __floats2bfloat162_rn(0.0f, 0.0f);
527
+ }
528
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[idx];
529
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[idx];
530
+
531
+ // #pragma unroll
532
+ if(threadIdx.x < 16 ){
533
+ shared_offset = i * 16 + threadIdx.x;
534
+ d_f_real_shared[shared_offset] = d_f_real[shared_offset];
535
+ d_f_imag_shared[shared_offset] = d_f_imag[shared_offset];
536
+ }
537
+ }
538
+
539
+ __syncthreads();
540
+
541
+ if (threadIdx.y < 4)
542
+ {
543
+ float2 tmp_real, tmp_imag;
544
+
545
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real;
546
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real;
547
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag;
548
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag;
549
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag;
550
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real;
551
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_imag;
552
+
553
+
554
+ wmma::load_matrix_sync(a_frag_real, d_f_real_shared, N);
555
+ wmma::load_matrix_sync(a_frag_imag, d_f_imag_shared, N);
556
+ wmma::load_matrix_sync(b_frag, x_shared + threadIdx.y * 16, 64);
557
+ wmma::load_matrix_sync(tw_frag_real, twiddles_real_shared + threadIdx.y * 16, 64);
558
+ wmma::load_matrix_sync(tw_frag_imag, twiddles_imag_shared + threadIdx.y * 16, 64);
559
+
560
+
561
+
562
+ wmma::fill_fragment(acc_frag_real, 0.0f);
563
+
564
+
565
+ wmma::mma_sync(acc_frag_real, a_frag_real, b_frag, acc_frag_real);
566
+
567
+
568
+
569
+ wmma::fill_fragment(acc_frag_imag, 0.0f);
570
+
571
+
572
+ wmma::mma_sync(acc_frag_imag, a_frag_imag, b_frag, acc_frag_imag);
573
+
574
+
575
+ #pragma unroll
576
+ for (int k = 0; k < acc_frag_real.num_elements / 2; k++)
577
+ {
578
+ tmp_real = reinterpret_cast<float2 *>(acc_frag_real.x)[k];
579
+ tmp_imag = reinterpret_cast<float2 *>(acc_frag_imag.x)[k];
580
+ reinterpret_cast<float2 *>(acc_frag_real.x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real.x)[k]) - tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag.x)[k]);
581
+ reinterpret_cast<float2 *>(acc_frag_imag.x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag.x)[k]) + tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real.x)[k]);
582
+ }
583
+ wmma::store_matrix_sync(out_real_shared + threadIdx.y * 16, acc_frag_real, 64, wmma::mem_row_major);
584
+ wmma::store_matrix_sync(out_imag_shared + threadIdx.y * 16, acc_frag_imag, 64, wmma::mem_row_major);
585
+
586
+ }
587
+ __syncthreads();
588
+
589
+ #pragma unroll
590
+ for (int i = threadIdx.y; i < N; i++)
591
+ {
592
+ int idx = i * blockDim.x * gridDim.x + blockIdx.x * blockDim.x + threadIdx.x;;
593
+ out_real[out_offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_real_shared)[i * 32 + threadIdx.x]);
594
+ out_imag[out_offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_imag_shared)[i * 32 + threadIdx.x]);
595
+ }
596
+ }
597
+
598
+ std::vector<torch::Tensor> butterfly_padded_bf16_cuda(
599
+ torch::Tensor x,
600
+ torch::Tensor d_f_real,
601
+ torch::Tensor d_f_imag,
602
+ torch::Tensor twiddle_factors_real,
603
+ torch::Tensor twiddle_factors_imag,
604
+ int M,
605
+ std::optional<at::Tensor> x_gate = std::nullopt
606
+ )
607
+ {
608
+
609
+ uint B = x.size(0);
610
+ uint H = x.size(1);
611
+
612
+ uint d_f_size = d_f_real.size(1);
613
+
614
+ uint N = x.size(2);
615
+
616
+ //need to make sure that N is less that the M to which we are padding
617
+ assert(N <= d_f_size * M);
618
+
619
+ dim3 gridDim;
620
+ dim3 blockDim;
621
+
622
+ gridDim.y = B;
623
+ gridDim.z = H;
624
+
625
+ blockDim.x = 32;
626
+ blockDim.y = 4;
627
+
628
+ torch::Tensor out_real = torch::empty({B, H, d_f_size * M}, x.options());
629
+ torch::Tensor out_imag = torch::empty({B, H, d_f_size * M}, x.options());
630
+
631
+ gridDim.x = 512 / (32 * 1024/ M);
632
+
633
+ const int K = ceil(N / (1.0 * 16 * M));
634
+
635
+ switch (d_f_size)
636
+ {
637
+ case 16:
638
+ butterfly_cuda_kernel_16<1><<<gridDim, blockDim>>>(
639
+ static_cast<__nv_bfloat162 *>(x.data_ptr()),
640
+ x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
641
+ static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
642
+ static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
643
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
644
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
645
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
646
+ static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
647
+ B,
648
+ H,
649
+ N);
650
+ break;
651
+ case 32:
652
+ switch(K){
653
+ case 1:
654
+ butterfly_cuda_kernel_32<1><<<gridDim, blockDim>>>(
655
+ static_cast<__nv_bfloat162 *>(x.data_ptr()),
656
+ x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
657
+ static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
658
+ static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
659
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
660
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
661
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
662
+ static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
663
+ B,
664
+ H,
665
+ N);
666
+ break;
667
+ case 2:
668
+ butterfly_cuda_kernel_32<2><<<gridDim, blockDim>>>(
669
+ static_cast<__nv_bfloat162 *>(x.data_ptr()),
670
+ x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
671
+ static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
672
+ static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
673
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
674
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
675
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
676
+ static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
677
+ B,
678
+ H,
679
+ N);
680
+ break;
681
+ default:
682
+ printf("Invalid K, df size 32: %d\n", K);
683
+ }
684
+ break;
685
+ case 64:
686
+ gridDim.z = H / 16;
687
+
688
+ switch(K){
689
+ case 1:
690
+ cudaFuncSetAttribute(&butterfly_cuda_kernel_64<1>, cudaFuncAttributeMaxDynamicSharedMemorySize, 78000);
691
+ butterfly_cuda_kernel_64<1><<<gridDim, blockDim, 78000>>>(
692
+ static_cast<__nv_bfloat162 *>(x.data_ptr()),
693
+ x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
694
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
695
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
696
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
697
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
698
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
699
+ static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
700
+ B,
701
+ H,
702
+ N);
703
+ break;
704
+ case 2:
705
+ cudaFuncSetAttribute(&butterfly_cuda_kernel_64<2>, cudaFuncAttributeMaxDynamicSharedMemorySize, 78000);
706
+ butterfly_cuda_kernel_64<2><<<gridDim, blockDim, 78000>>>(
707
+ static_cast<__nv_bfloat162 *>(x.data_ptr()),
708
+ x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
709
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
710
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
711
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
712
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
713
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
714
+ static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
715
+ B,
716
+ H,
717
+ N);
718
+ break;
719
+ case 3:
720
+ cudaFuncSetAttribute(&butterfly_cuda_kernel_64<3>, cudaFuncAttributeMaxDynamicSharedMemorySize, 78000);
721
+ butterfly_cuda_kernel_64<3><<<gridDim, blockDim, 78000>>>(
722
+ static_cast<__nv_bfloat162 *>(x.data_ptr()),
723
+ x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
724
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
725
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
726
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
727
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
728
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
729
+ static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
730
+ B,
731
+ H,
732
+ N);
733
+ break;
734
+ case 4:
735
+ cudaFuncSetAttribute(&butterfly_cuda_kernel_64<4>, cudaFuncAttributeMaxDynamicSharedMemorySize, 78000);
736
+ butterfly_cuda_kernel_64<4><<<gridDim, blockDim, 78000>>>(
737
+ static_cast<__nv_bfloat162 *>(x.data_ptr()),
738
+ x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
739
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
740
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
741
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
742
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
743
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
744
+ static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
745
+ B,
746
+ H,
747
+ N);
748
+ break;
749
+ default:
750
+ printf("Invalid K, df size 64: %d\n", K);
751
+ }
752
+ break;
753
+ case 128:
754
+ blockDim.x = 32;
755
+ blockDim.y = 8;
756
+ gridDim.x = 256 / (32 * 1024/ M);
757
+ gridDim.z = H / 16;
758
+ switch(K){
759
+ case 1:
760
+ cudaFuncSetAttribute(&butterfly_cuda_kernel_128<1>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
761
+ butterfly_cuda_kernel_128<1><<<gridDim, blockDim, 65536>>>(
762
+ static_cast<__nv_bfloat162 *>(x.data_ptr()),
763
+ x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
764
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
765
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
766
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
767
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
768
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
769
+ static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
770
+ B,
771
+ H,
772
+ N);
773
+ break;
774
+ case 2:
775
+ cudaFuncSetAttribute(&butterfly_cuda_kernel_128<2>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
776
+ butterfly_cuda_kernel_128<2><<<gridDim, blockDim, 65536>>>(
777
+ static_cast<__nv_bfloat162 *>(x.data_ptr()),
778
+ x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
779
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
780
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
781
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
782
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
783
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
784
+ static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
785
+ B,
786
+ H,
787
+ N);
788
+ break;
789
+ case 3:
790
+ cudaFuncSetAttribute(&butterfly_cuda_kernel_128<3>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
791
+
792
+ butterfly_cuda_kernel_128<3><<<gridDim, blockDim, 65536>>>(
793
+ static_cast<__nv_bfloat162 *>(x.data_ptr()),
794
+ x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
795
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
796
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
797
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
798
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
799
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
800
+ static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
801
+ B,
802
+ H,
803
+ N);
804
+ break;
805
+ case 4:
806
+ cudaFuncSetAttribute(&butterfly_cuda_kernel_128<4>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
807
+
808
+ butterfly_cuda_kernel_128<4><<<gridDim, blockDim, 65536>>>(
809
+ static_cast<__nv_bfloat162 *>(x.data_ptr()),
810
+ x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
811
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
812
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
813
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
814
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
815
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
816
+ static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
817
+ B,
818
+ H,
819
+ N);
820
+ break;
821
+ case 5:
822
+ cudaFuncSetAttribute(&butterfly_cuda_kernel_128<5>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
823
+
824
+ butterfly_cuda_kernel_128<5><<<gridDim, blockDim, 65536>>>(
825
+ static_cast<__nv_bfloat162 *>(x.data_ptr()),
826
+ x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
827
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
828
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
829
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
830
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
831
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
832
+ static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
833
+ B,
834
+ H,
835
+ N);
836
+ break;
837
+ case 6:
838
+ cudaFuncSetAttribute(&butterfly_cuda_kernel_128<6>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
839
+
840
+ butterfly_cuda_kernel_128<6><<<gridDim, blockDim, 65536>>>(
841
+ static_cast<__nv_bfloat162 *>(x.data_ptr()),
842
+ x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
843
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
844
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
845
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
846
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
847
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
848
+ static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
849
+ B,
850
+ H,
851
+ N);
852
+ break;
853
+ case 7:
854
+ cudaFuncSetAttribute(&butterfly_cuda_kernel_128<7>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
855
+
856
+ butterfly_cuda_kernel_128<7><<<gridDim, blockDim, 65536>>>(
857
+ static_cast<__nv_bfloat162 *>(x.data_ptr()),
858
+ x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
859
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
860
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
861
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
862
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
863
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
864
+ static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
865
+ B,
866
+ H,
867
+ N);
868
+ break;
869
+ case 8:
870
+ cudaFuncSetAttribute(&butterfly_cuda_kernel_128<8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
871
+
872
+ butterfly_cuda_kernel_128<8><<<gridDim, blockDim, 65536>>>(
873
+ static_cast<__nv_bfloat162 *>(x.data_ptr()),
874
+ x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
875
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
876
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
877
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
878
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
879
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
880
+ static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
881
+ B,
882
+ H,
883
+ N);
884
+ break;
885
+ default:
886
+ printf("Invalid K, df size 128: %d\n", K);
887
+
888
+ }
889
+ break;
890
+
891
+ default:
892
+ printf("Not yet implemented \n");
893
+ break;
894
+ }
895
+
896
+ return {out_real, out_imag};
897
+ }
overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_padded_ifft_cuda.cu ADDED
@@ -0,0 +1,905 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2023 Dan Fu, Hermann Kumbong
2
+
3
+ #include <torch/extension.h>
4
+
5
+ #include <vector>
6
+ #include <stdio.h>
7
+ #include <mma.h>
8
+ #include <cuda_fp16.h>
9
+ #include <cuda_bf16.h>
10
+ #include "shared.h"
11
+
12
+ using namespace nvcuda;
13
+
14
+ template <int TILE_H, int K>
15
+ __global__ void butterfly_ifft_padded_cuda_kernel_64(
16
+ const __half2 *__restrict__ x_real,
17
+ const __half2 *__restrict__ x_imag,
18
+ const complex_half_t *__restrict__ d_f,
19
+ const __half2 *__restrict__ twiddle_factors_real,
20
+ const __half2 *__restrict__ twiddle_factors_imag,
21
+ __half2 *__restrict__ out_real,
22
+ __half2 *__restrict__ out_gate,
23
+ uint B,
24
+ uint H,
25
+ int M)
26
+ {
27
+ const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
28
+ const int out_offset = blockIdx.y * H * M/2 + blockIdx.z * TILE_H * M/2;
29
+ const int in_offset = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * TILE_H * 64 * 32 * gridDim.x;
30
+ int idx;
31
+ int t_offset;
32
+ int out_t_offset;
33
+ int shared_offset;
34
+ const int N = 64;
35
+
36
+ extern __shared__ half x_real_shared[];
37
+ half *x_imag_shared = &x_real_shared[N * N];
38
+ half *d_f_real = &x_imag_shared[N * N];
39
+ half *d_f_imag = &d_f_real[N * N];
40
+ half *twiddles_real_shared = &d_f_imag[N * N];
41
+ half *twiddles_imag_shared = &twiddles_real_shared[N * N];
42
+ half *out_real_shared = &twiddles_imag_shared[N * N];
43
+
44
+ half tmp_real, tmp_imag;
45
+
46
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[K][4];
47
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[K][4];
48
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real[4];
49
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag[4];
50
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real[4];
51
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag[4];
52
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[K];
53
+
54
+ // #pragma unroll
55
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
56
+ {
57
+ idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
58
+ shared_offset = i * 32 + threadIdx.x;
59
+ reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[idx];
60
+ reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[idx];
61
+
62
+ // #pragma unroll
63
+ shared_offset = i * 64 + threadIdx.x;
64
+ d_f_real[shared_offset] = d_f[shared_offset].real();
65
+ d_f_imag[shared_offset] = d_f[shared_offset].imag();
66
+
67
+ d_f_real[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].real();
68
+ d_f_imag[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].imag();
69
+ }
70
+
71
+ __syncthreads();
72
+
73
+ for (int i = 0; i < 4; i++)
74
+ {
75
+ if(i < K){
76
+ #pragma unroll
77
+ for (int j = 0; j < 4; j++)
78
+ {
79
+ wmma::load_matrix_sync(a_frag_real[i][j], d_f_real + j * N * 16 + i * 16, N);
80
+ wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag + j * N * 16 + i * 16, N);
81
+ }
82
+ }
83
+ wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + i * N * 16 + threadIdx.y * 16, N);
84
+ wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + i * N * 16 + threadIdx.y * 16, N);
85
+ }
86
+
87
+ for (int t = 0; t < TILE_H; t++)
88
+ {
89
+
90
+ out_t_offset = t * M/2;
91
+ t_offset = t * 64 * 32 * gridDim.x;
92
+
93
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
94
+ {
95
+ idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
96
+ shared_offset = i * 32 + threadIdx.x;
97
+ reinterpret_cast<__half2 *>(x_real_shared)[shared_offset] = x_real[idx + in_offset + t_offset];
98
+ reinterpret_cast<__half2 *>(x_imag_shared)[shared_offset] = x_imag[idx + in_offset + t_offset];
99
+ }
100
+
101
+ __syncthreads();
102
+
103
+ for (int i = 0; i < 4; i++)
104
+ {
105
+ wmma::load_matrix_sync(b_frag_real[i], x_real_shared + i * N * 16 + threadIdx.y * 16, N);
106
+ wmma::load_matrix_sync(b_frag_imag[i], x_imag_shared + i * N * 16 + threadIdx.y * 16, N);
107
+ }
108
+
109
+ for (int j = 0; j < 4; j++)
110
+ {
111
+ for (int k = 0; k < tw_frag_real[j].num_elements; k++)
112
+ {
113
+ tmp_real = __hsub(__hmul(tw_frag_real[j].x[k], b_frag_real[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_imag[j].x[k]));
114
+ tmp_imag = __hadd(__hmul(tw_frag_real[j].x[k], b_frag_imag[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_real[j].x[k]));
115
+ b_frag_real[j].x[k] = tmp_real;
116
+ b_frag_imag[j].x[k] = tmp_imag;
117
+ }
118
+ }
119
+
120
+ for (int i = 0; i < K; i++)
121
+ {
122
+ wmma::fill_fragment(acc_frag_real[i], __float2half(0.0f));
123
+
124
+ // bd
125
+ #pragma unroll
126
+ for (int k = 0; k < 4; k++)
127
+ {
128
+ wmma::mma_sync(acc_frag_real[i], a_frag_imag[i][k], b_frag_imag[k], acc_frag_real[i]);
129
+ }
130
+
131
+ for (int k = 0; k < acc_frag_real[i].num_elements; k++)
132
+ {
133
+ acc_frag_real[i].x[k] = __hneg(acc_frag_real[i].x[k]);
134
+ }
135
+ }
136
+
137
+ for (int i = 0; i < K; i++)
138
+ {
139
+ // ac - bd
140
+ #pragma unroll
141
+ for (int k = 0; k < 4; k++)
142
+ {
143
+ wmma::mma_sync(acc_frag_real[i], a_frag_real[i][k], b_frag_real[k], acc_frag_real[i]);
144
+ }
145
+ }
146
+
147
+ #pragma unroll
148
+ for (int i = 0; i < K; i++)
149
+ {
150
+ wmma::store_matrix_sync(out_real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
151
+ }
152
+
153
+ __syncthreads();
154
+
155
+ #pragma unroll
156
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
157
+ {
158
+ idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
159
+ shared_offset = i * 32 + threadIdx.x;
160
+
161
+ if(idx < max_idx){
162
+ if(out_gate != nullptr)
163
+ out_real[out_offset + out_t_offset + idx] = __hmul2(reinterpret_cast<__half2 *>(out_real_shared)[shared_offset], out_gate[out_offset + out_t_offset + idx]);
164
+ else
165
+ out_real[out_offset + out_t_offset + idx] = reinterpret_cast<__half2 *>(out_real_shared)[shared_offset];
166
+ }
167
+ }
168
+
169
+ __syncthreads();
170
+ }
171
+ }
172
+
173
+
174
+ template <int K>
175
+ __global__ void butterfly_ifft_padded_cuda_kernel_32(
176
+ const __half2 *__restrict__ x_real,
177
+ const __half2 *__restrict__ x_imag,
178
+ const complex_half_t *__restrict__ d_f,
179
+ const __half2 *__restrict__ twiddle_factors_real,
180
+ const __half2 *__restrict__ twiddle_factors_imag,
181
+ __half2 *__restrict__ out_real,
182
+ __half2 *__restrict__ out_gate,
183
+ uint B,
184
+ uint H,
185
+ int M)
186
+ {
187
+ const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
188
+ const int N = 32;
189
+ int idx;
190
+ int shared_offset;
191
+
192
+ const int out_offset = blockIdx.y * H * M / 2 + blockIdx.z * M / 2;
193
+ const int in_offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x;
194
+
195
+
196
+ __shared__ half x_real_shared[32 * 64];
197
+ __shared__ half x_imag_shared[32 * 64];
198
+ __shared__ half d_f_real[32 * 32];
199
+ __shared__ half d_f_imag[32 * 32];
200
+ __shared__ half twiddles_real_shared[32 * 64];
201
+ __shared__ half twiddles_imag_shared[32 * 64];
202
+ __shared__ half out_real_shared[32 * 64];
203
+
204
+ // #pragma unroll
205
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
206
+ {
207
+ idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
208
+ int shared_offset = i * 32 + threadIdx.x;
209
+
210
+ reinterpret_cast<__half2 *>(x_real_shared)[shared_offset] = x_real[in_offset + idx];
211
+ reinterpret_cast<__half2 *>(x_imag_shared)[shared_offset] = x_imag[in_offset + idx];
212
+ reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[idx];
213
+ reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[idx];
214
+
215
+ // #pragma unroll
216
+ shared_offset = i * 32 + threadIdx.x;
217
+ d_f_real[shared_offset] = d_f[shared_offset].real();
218
+ d_f_imag[shared_offset] = d_f[shared_offset].imag();
219
+ }
220
+
221
+ __syncthreads();
222
+
223
+ if (threadIdx.y < N/16)
224
+ {
225
+ half tmp_real, tmp_imag;
226
+
227
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[K][2];
228
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[K][2];
229
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real[2][2];
230
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag[2][2];
231
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real[2][2];
232
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag[2][2];
233
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[K][2];
234
+
235
+ int t = threadIdx.y * 32;
236
+
237
+ for (int i = 0; i < 2; i++)
238
+ {
239
+ for (int j = 0; j < 2; j++)
240
+ {
241
+ if(i < K){
242
+ wmma::load_matrix_sync(a_frag_real[i][j], d_f_real + j * N * 16 + i * 16, N);
243
+ wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag + j * N * 16 + i * 16, N);
244
+ }
245
+ wmma::load_matrix_sync(b_frag_real[i][j], x_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
246
+ wmma::load_matrix_sync(b_frag_imag[i][j], x_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
247
+ wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
248
+ wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
249
+ }
250
+ }
251
+
252
+ for (int i = 0; i < 2; i++)
253
+ {
254
+ for (int j = 0; j < 2; j++)
255
+ {
256
+ for (int k = 0; k < tw_frag_real[i][j].num_elements; k++)
257
+ {
258
+ tmp_real = __hsub(__hmul(tw_frag_real[i][j].x[k], b_frag_real[i][j].x[k]), __hmul(tw_frag_imag[i][j].x[k], b_frag_imag[i][j].x[k]));
259
+ tmp_imag = __hadd(__hmul(tw_frag_real[i][j].x[k], b_frag_imag[i][j].x[k]), __hmul(tw_frag_imag[i][j].x[k], b_frag_real[i][j].x[k]));
260
+ b_frag_real[i][j].x[k] = tmp_real;
261
+ b_frag_imag[i][j].x[k] = tmp_imag;
262
+ }
263
+ }
264
+ }
265
+
266
+ for (int i = 0; i < K; i++)
267
+ {
268
+ for (int j = 0; j < 2; j++)
269
+ {
270
+ wmma::fill_fragment(acc_frag_real[i][j], __float2half(0.0f));
271
+
272
+ // bd
273
+ for (int k = 0; k < 2; k++)
274
+ {
275
+ wmma::mma_sync(acc_frag_real[i][j], a_frag_imag[i][k], b_frag_imag[k][j], acc_frag_real[i][j]);
276
+ }
277
+
278
+ for (int k = 0; k < acc_frag_real[i][j].num_elements; k++)
279
+ {
280
+ acc_frag_real[i][j].x[k] = __hneg(acc_frag_real[i][j].x[k]);
281
+ }
282
+ }
283
+ }
284
+
285
+ for (int i = 0; i < K; i++)
286
+ {
287
+ for (int j = 0; j < 2; j++)
288
+ {
289
+ // ac - bd
290
+ for (int k = 0; k < 2; k++)
291
+ {
292
+ wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag_real[k][j], acc_frag_real[i][j]);
293
+ }
294
+ }
295
+ }
296
+
297
+ for (int i = 0; i < K; i++)
298
+ {
299
+ for (int j = 0; j < 2; j++)
300
+ {
301
+ wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
302
+ }
303
+ }
304
+ }
305
+
306
+ __syncthreads();
307
+
308
+ #pragma unroll
309
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
310
+ {
311
+ idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
312
+ shared_offset = i * 32 + threadIdx.x;
313
+
314
+ if(idx < max_idx){
315
+ if(out_gate != nullptr){
316
+ out_real[idx + out_offset] = __hmul2(reinterpret_cast<__half2 *>(out_real_shared)[shared_offset], out_gate[idx + out_offset]);
317
+ }else{
318
+ out_real[idx + out_offset] = reinterpret_cast<__half2 *>(out_real_shared)[shared_offset];
319
+ }
320
+ }
321
+
322
+ }
323
+ }
324
+
325
+
326
+ template <int TILE_H, int K>
327
+ __global__ void butterfly_ifft_padded_cuda_kernel_128(
328
+ const __half2 *__restrict__ x_real,
329
+ const __half2 *__restrict__ x_imag,
330
+ const complex_half_t *__restrict__ d_f,
331
+ const __half2 *__restrict__ twiddle_factors_real,
332
+ const __half2 *__restrict__ twiddle_factors_imag,
333
+ __half2 *__restrict__ out_real,
334
+ __half2 *__restrict__ out_gate,
335
+ uint B,
336
+ uint H,
337
+ int M)
338
+ {
339
+ const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
340
+ const int out_offset = blockIdx.y * H * M/2 + blockIdx.z * TILE_H * M/2;
341
+ const int in_offset = blockIdx.y * H * 128 * 32 * 2 * gridDim.x + blockIdx.z * TILE_H * 128 * 32 * 2 * gridDim.x;
342
+ const int N = 128;
343
+ int idx;
344
+ int t_offset;
345
+ int out_t_offset;
346
+ int shared_offset;
347
+
348
+
349
+ extern __shared__ half real_shared[];
350
+ half *imag_shared = &real_shared[128 * 128];
351
+ half *real_shared_2 = &imag_shared[128 * 128];
352
+ half *imag_shared_2 = &real_shared_2[128 * 128];
353
+
354
+ half tmp_real, tmp_imag;
355
+
356
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag[K][8];
357
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real[8];
358
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag[8];
359
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real[8];
360
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag[8];
361
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[K];
362
+
363
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
364
+ {
365
+ for(int j=0; j< 4; j++){
366
+ shared_offset = i * 128 + threadIdx.x + j * blockDim.x;
367
+ real_shared_2[shared_offset] = d_f[shared_offset].real();
368
+ imag_shared_2[shared_offset] = d_f[shared_offset].imag();
369
+ }
370
+ }
371
+
372
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
373
+ {
374
+ for(int j=0; j< 2; j++){
375
+ idx = i * 32 * 2 * gridDim.x + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
376
+ shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
377
+ reinterpret_cast<__half2*>(real_shared)[shared_offset] = twiddle_factors_real[idx];
378
+ reinterpret_cast<__half2*>(imag_shared)[shared_offset] = twiddle_factors_imag[idx];
379
+ }
380
+ }
381
+
382
+ __syncthreads();
383
+
384
+
385
+ for (int i = 0; i < 8; i++){
386
+ wmma::load_matrix_sync(tw_frag_real[i], real_shared + i * 128 * 16 + threadIdx.y * 16, 128);
387
+ wmma::load_matrix_sync(tw_frag_imag[i], imag_shared + i * 128 * 16 + threadIdx.y * 16, 128);
388
+ }
389
+
390
+ __syncthreads();
391
+
392
+ for (int t = 0; t < TILE_H; t++)
393
+ {
394
+
395
+ out_t_offset = t * M/2;
396
+ t_offset = t * 128 * 32 * 2 * gridDim.x;
397
+
398
+ for (int i = 0; i < K; i++){
399
+ for (int j = 0; j < 8; j++){
400
+ wmma::load_matrix_sync(a_frag[i][j], imag_shared_2 + j * 128 * 16 + i * 16, 128);
401
+ }
402
+ }
403
+
404
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
405
+ {
406
+ for(int j=0; j< 2; j++){
407
+ idx = i * 32 * 2 * gridDim.x + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
408
+ shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
409
+ reinterpret_cast<__half2*>(real_shared)[shared_offset] = x_real[idx + in_offset + t_offset];
410
+ reinterpret_cast<__half2*>(imag_shared)[shared_offset] = x_imag[idx + in_offset + t_offset];
411
+ }
412
+ }
413
+
414
+ __syncthreads();
415
+
416
+ for (int i = 0; i < 8; i++)
417
+ {
418
+ wmma::load_matrix_sync(b_frag_real[i], real_shared + i * N * 16 + threadIdx.y * 16, N);
419
+ wmma::load_matrix_sync(b_frag_imag[i], imag_shared + i * N * 16 + threadIdx.y * 16, N);
420
+ }
421
+
422
+
423
+ for (int j = 0; j < 8; j++)
424
+ {
425
+ for (int k = 0; k < tw_frag_real[j].num_elements; k++)
426
+ {
427
+ tmp_real = __hsub(__hmul(tw_frag_real[j].x[k], b_frag_real[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_imag[j].x[k]));
428
+ tmp_imag = __hadd(__hmul(tw_frag_real[j].x[k], b_frag_imag[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_real[j].x[k]));
429
+ b_frag_real[j].x[k] = tmp_real;
430
+ b_frag_imag[j].x[k] = tmp_imag;
431
+ }
432
+ }
433
+
434
+ for (int i = 0; i < K; i++)
435
+ {
436
+ wmma::fill_fragment(acc_frag_real[i], __float2half(0.0f));
437
+
438
+ // bd
439
+ #pragma unroll
440
+ for (int k = 0; k < 8; k++)
441
+ {
442
+ wmma::mma_sync(acc_frag_real[i], a_frag[i][k], b_frag_imag[k], acc_frag_real[i]);
443
+ }
444
+
445
+ for (int k = 0; k < acc_frag_real[i].num_elements; k++)
446
+ {
447
+ acc_frag_real[i].x[k] = __hneg(acc_frag_real[i].x[k]);
448
+ }
449
+ }
450
+
451
+ for (int i = 0; i < K; i++){
452
+ for (int j = 0; j < 8; j++){
453
+ wmma::load_matrix_sync(a_frag[i][j], real_shared_2 + j * 128 * 16 + i * 16, 128);
454
+ }
455
+ }
456
+
457
+ for (int i = 0; i < K; i++)
458
+ {
459
+ // ac - bd
460
+ #pragma unroll
461
+ for (int k = 0; k < 8; k++)
462
+ {
463
+ wmma::mma_sync(acc_frag_real[i], a_frag[i][k], b_frag_real[k], acc_frag_real[i]);
464
+ }
465
+ }
466
+
467
+ #pragma unroll
468
+ for (int i = 0; i < K; i++)
469
+ {
470
+ //wmma::store_matrix_sync(real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
471
+ wmma::store_matrix_sync(real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
472
+ }
473
+
474
+ __syncthreads();
475
+
476
+ #pragma unroll
477
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
478
+ {
479
+ for(int j=0; j< 2; j++){
480
+ idx = i * 32 * 2 * gridDim.x + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
481
+ shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
482
+ if(idx < max_idx){
483
+ if(out_gate != nullptr){
484
+ out_real[idx + out_offset + out_t_offset] = __hmul2(reinterpret_cast<__half2*>(real_shared)[shared_offset], out_gate[idx + out_offset + out_t_offset]);
485
+ }else{
486
+ out_real[idx + out_offset + out_t_offset] = reinterpret_cast<__half2*>(real_shared)[shared_offset];
487
+ }
488
+ }
489
+ }
490
+ }
491
+
492
+ __syncthreads();
493
+ }
494
+ }
495
+
496
+
497
+ __global__ void butterfly_ifft_padded_cuda_kernel_16(
498
+ const __half2 *__restrict__ x_real,
499
+ const __half2 *__restrict__ x_imag,
500
+ const complex_half_t *__restrict__ d_f,
501
+ const __half2 *__restrict__ twiddle_factors_real,
502
+ const __half2 *__restrict__ twiddle_factors_imag,
503
+ __half2 *__restrict__ out_real,
504
+ __half2 *__restrict__ out_gate,
505
+ uint B,
506
+ uint H,
507
+ int M)
508
+ {
509
+ const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
510
+ const int N = 16;
511
+ const int out_offset = blockIdx.y * H * M / 2 + blockIdx.z * M / 2;
512
+ const int offset = blockIdx.y * H * N * blockDim.x * gridDim.x + blockIdx.z * N * blockDim.x * gridDim.x;
513
+
514
+ __shared__ half x_real_shared[N * 64];
515
+ __shared__ half x_imag_shared[N * 64];
516
+ __shared__ half d_f_real[N * N];
517
+ __shared__ half d_f_imag[N * N];
518
+ __shared__ half twiddles_real_shared[N * 64];
519
+ __shared__ half twiddles_imag_shared[N * 64];
520
+ __shared__ half out_real_shared[N * 64];
521
+
522
+ // #pragma unroll
523
+ for (int i = threadIdx.y; i < N; i++)
524
+ {
525
+ int idx = i * blockDim.x * gridDim.x + blockIdx.x * blockDim.x + threadIdx.x;
526
+ int shared_offset = i * blockDim.x + threadIdx.x;
527
+ reinterpret_cast<__half2 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
528
+ reinterpret_cast<__half2 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
529
+ reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[idx];
530
+ reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[idx];
531
+
532
+ if(threadIdx.x < 16 ){
533
+ shared_offset = i * 16 + threadIdx.x;
534
+ d_f_real[shared_offset] = d_f[shared_offset].real();
535
+ d_f_imag[shared_offset] = d_f[shared_offset].imag();
536
+ }
537
+ }
538
+
539
+ __syncthreads();
540
+
541
+ //check if it is better to have one warp do all the multiplication or split between warps
542
+ if (threadIdx.y < 4)
543
+ {
544
+ half tmp_real, tmp_imag;
545
+
546
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real;
547
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag;
548
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real;
549
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag;
550
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real;
551
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag;
552
+ wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real;
553
+
554
+ wmma::load_matrix_sync(a_frag_real, d_f_real, N);
555
+ wmma::load_matrix_sync(a_frag_imag, d_f_imag, N);
556
+ wmma::load_matrix_sync(b_frag_real, x_real_shared + threadIdx.y * 16, 64);
557
+ wmma::load_matrix_sync(b_frag_imag, x_imag_shared + threadIdx.y * 16, 64);
558
+ wmma::load_matrix_sync(tw_frag_real, twiddles_real_shared + threadIdx.y * 16, 64);
559
+ wmma::load_matrix_sync(tw_frag_imag, twiddles_imag_shared + threadIdx.y * 16, 64);
560
+
561
+
562
+
563
+ for (int k = 0; k < tw_frag_real.num_elements; k++)
564
+ {
565
+ tmp_real = __hsub(__hmul(tw_frag_real.x[k], b_frag_real.x[k]), __hmul(tw_frag_imag.x[k], b_frag_imag.x[k]));
566
+ tmp_imag = __hadd(__hmul(tw_frag_real.x[k], b_frag_imag.x[k]), __hmul(tw_frag_imag.x[k], b_frag_real.x[k]));
567
+ b_frag_real.x[k] = tmp_real;
568
+ b_frag_imag.x[k] = tmp_imag;
569
+ }
570
+
571
+
572
+ wmma::fill_fragment(acc_frag_real, __float2half(0.0f));
573
+
574
+ wmma::mma_sync(acc_frag_real, a_frag_imag, b_frag_imag, acc_frag_real);
575
+
576
+ for(int k=0; k< acc_frag_real.num_elements; k++){
577
+ acc_frag_real.x[k] = __hneg(acc_frag_real.x[k]);
578
+ }
579
+
580
+
581
+ wmma::mma_sync(acc_frag_real, a_frag_real, b_frag_real, acc_frag_real);
582
+
583
+ wmma::store_matrix_sync(out_real_shared + threadIdx.y * 16, acc_frag_real, 64, wmma::mem_row_major);
584
+
585
+ }
586
+
587
+ __syncthreads();
588
+
589
+ #pragma unroll
590
+ for (int i = threadIdx.y; i < N; i++)
591
+ {
592
+ int idx = i * blockDim.x * gridDim.x + blockIdx.x * blockDim.x + threadIdx.x;
593
+ if(idx < max_idx){
594
+ if(out_gate != nullptr){
595
+ out_real[out_offset + idx] = __hmul2(reinterpret_cast<__half2 *>(out_real_shared)[i * 32 + threadIdx.x], out_gate[out_offset + idx]);
596
+ }
597
+ else{
598
+ out_real[out_offset + idx] = reinterpret_cast<__half2 *>(out_real_shared)[i * 32 + threadIdx.x];
599
+ }
600
+ }
601
+ }
602
+ }
603
+
604
+ torch::Tensor butterfly_ifft_padded_cuda(
605
+ torch::Tensor x_real,
606
+ torch::Tensor x_imag,
607
+ torch::Tensor d_f,
608
+ torch::Tensor twiddle_factors_real,
609
+ torch::Tensor twiddle_factors_imag,
610
+ int fft_size,
611
+ std::optional<at::Tensor> out_gate = std::nullopt
612
+ )
613
+ {
614
+
615
+ uint B = x_real.size(0);
616
+ uint H = x_real.size(1);
617
+ uint N_M = x_real.size(2);
618
+ const int d_f_size = d_f.size(0);
619
+ // const int TILE_SIZE = 16;
620
+
621
+ dim3 gridDim;
622
+ dim3 blockDim;
623
+
624
+ // uint N = x_real.size(2);
625
+ gridDim.y = B;
626
+
627
+ blockDim.x = 32;
628
+ blockDim.y = 4;
629
+ gridDim.x = 512 / (32 * 1024/ (N_M / d_f_size));
630
+ gridDim.z = H;
631
+
632
+ const int TILE_H = 16;
633
+ torch::Tensor out_real = torch::empty({B, H, fft_size}, x_real.options());
634
+ const int K = ceil(fft_size / (1.0 * 16 * (N_M / d_f_size)));
635
+
636
+ switch(d_f_size){
637
+ case 16:
638
+ butterfly_ifft_padded_cuda_kernel_16<<<gridDim, blockDim>>>(
639
+ static_cast<__half2 *>(x_real.data_ptr()),
640
+ static_cast<__half2 *>(x_imag.data_ptr()),
641
+ static_cast<complex_half_t *>(d_f.data_ptr()),
642
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
643
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
644
+ static_cast<__half2 *>(out_real.data_ptr()),
645
+ out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
646
+ B,
647
+ H,
648
+ fft_size
649
+ );
650
+ break;
651
+ case 32:
652
+ switch (K)
653
+ {
654
+ case 1:
655
+ butterfly_ifft_padded_cuda_kernel_32<1><<<gridDim, blockDim>>>(
656
+ static_cast<__half2 *>(x_real.data_ptr()),
657
+ static_cast<__half2 *>(x_imag.data_ptr()),
658
+ static_cast<complex_half_t *>(d_f.data_ptr()),
659
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
660
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
661
+ static_cast<__half2 *>(out_real.data_ptr()),
662
+ out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
663
+ B,
664
+ H,
665
+ fft_size
666
+ );
667
+ break;
668
+ case 2:
669
+ butterfly_ifft_padded_cuda_kernel_32<2><<<gridDim, blockDim>>>(
670
+ static_cast<__half2 *>(x_real.data_ptr()),
671
+ static_cast<__half2 *>(x_imag.data_ptr()),
672
+ static_cast<complex_half_t *>(d_f.data_ptr()),
673
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
674
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
675
+ static_cast<__half2 *>(out_real.data_ptr()),
676
+ out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
677
+ B,
678
+ H,
679
+ fft_size
680
+ );
681
+ break;
682
+ default:
683
+ printf("Invalid K: %d\n", K);
684
+ break;
685
+ }
686
+ break;
687
+
688
+ case 64:
689
+ gridDim.z = H / TILE_H;
690
+ switch (K)
691
+ {
692
+ case 1:
693
+ cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_64<TILE_H, 1>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
694
+ butterfly_ifft_padded_cuda_kernel_64<TILE_H, 1><<<gridDim, blockDim, 65536>>>(
695
+ static_cast<__half2 *>(x_real.data_ptr()),
696
+ static_cast<__half2 *>(x_imag.data_ptr()),
697
+ static_cast<complex_half_t *>(d_f.data_ptr()),
698
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
699
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
700
+ static_cast<__half2 *>(out_real.data_ptr()),
701
+ out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
702
+ B,
703
+ H,
704
+ fft_size);
705
+ break;
706
+
707
+ case 2:
708
+ cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_64<TILE_H, 2>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
709
+ butterfly_ifft_padded_cuda_kernel_64<TILE_H, 2><<<gridDim, blockDim, 65536>>>(
710
+ static_cast<__half2 *>(x_real.data_ptr()),
711
+ static_cast<__half2 *>(x_imag.data_ptr()),
712
+ static_cast<complex_half_t *>(d_f.data_ptr()),
713
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
714
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
715
+ static_cast<__half2 *>(out_real.data_ptr()),
716
+ out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
717
+ B,
718
+ H,
719
+ fft_size);
720
+ break;
721
+
722
+ case 3:
723
+ cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_64<TILE_H, 3>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
724
+ butterfly_ifft_padded_cuda_kernel_64<TILE_H, 3><<<gridDim, blockDim, 65536>>>(
725
+ static_cast<__half2 *>(x_real.data_ptr()),
726
+ static_cast<__half2 *>(x_imag.data_ptr()),
727
+ static_cast<complex_half_t *>(d_f.data_ptr()),
728
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
729
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
730
+ static_cast<__half2 *>(out_real.data_ptr()),
731
+ out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
732
+ B,
733
+ H,
734
+ fft_size);
735
+ break;
736
+
737
+ case 4:
738
+ cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_64<TILE_H, 4>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
739
+ butterfly_ifft_padded_cuda_kernel_64<TILE_H, 4><<<gridDim, blockDim, 65536>>>(
740
+ static_cast<__half2 *>(x_real.data_ptr()),
741
+ static_cast<__half2 *>(x_imag.data_ptr()),
742
+ static_cast<complex_half_t *>(d_f.data_ptr()),
743
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
744
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
745
+ static_cast<__half2 *>(out_real.data_ptr()),
746
+ out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
747
+ B,
748
+ H,
749
+ fft_size);
750
+ break;
751
+
752
+ default:
753
+ break;
754
+ }
755
+
756
+ break;
757
+ case 128:
758
+ blockDim.x = 32;
759
+ blockDim.y = 8;
760
+ gridDim.x = 256 / (32 * 1024/ (N_M / d_f_size));
761
+ gridDim.z = H / TILE_H;
762
+
763
+ switch (K)
764
+ {
765
+ case 1:
766
+ cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 1>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
767
+
768
+ butterfly_ifft_padded_cuda_kernel_128<TILE_H, 1><<<gridDim, blockDim, 65536 * 2>>>(
769
+ static_cast<__half2 *>(x_real.data_ptr()),
770
+ static_cast<__half2 *>(x_imag.data_ptr()),
771
+ static_cast<complex_half_t *>(d_f.data_ptr()),
772
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
773
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
774
+ static_cast<__half2 *>(out_real.data_ptr()),
775
+ out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
776
+ B,
777
+ H,
778
+ fft_size);
779
+ break;
780
+
781
+ case 2:
782
+ cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 2>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
783
+
784
+ butterfly_ifft_padded_cuda_kernel_128<TILE_H, 2><<<gridDim, blockDim, 65536 * 2>>>(
785
+ static_cast<__half2 *>(x_real.data_ptr()),
786
+ static_cast<__half2 *>(x_imag.data_ptr()),
787
+ static_cast<complex_half_t *>(d_f.data_ptr()),
788
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
789
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
790
+ static_cast<__half2 *>(out_real.data_ptr()),
791
+ out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
792
+ B,
793
+ H,
794
+ fft_size);
795
+ break;
796
+
797
+ case 3:
798
+ cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 3>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
799
+
800
+ butterfly_ifft_padded_cuda_kernel_128<TILE_H, 3><<<gridDim, blockDim, 65536 * 2>>>(
801
+ static_cast<__half2 *>(x_real.data_ptr()),
802
+ static_cast<__half2 *>(x_imag.data_ptr()),
803
+ static_cast<complex_half_t *>(d_f.data_ptr()),
804
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
805
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
806
+ static_cast<__half2 *>(out_real.data_ptr()),
807
+ out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
808
+ B,
809
+ H,
810
+ fft_size);
811
+ break;
812
+
813
+ case 4:
814
+ cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 4>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
815
+
816
+ butterfly_ifft_padded_cuda_kernel_128<TILE_H, 4><<<gridDim, blockDim, 65536 * 2>>>(
817
+ static_cast<__half2 *>(x_real.data_ptr()),
818
+ static_cast<__half2 *>(x_imag.data_ptr()),
819
+ static_cast<complex_half_t *>(d_f.data_ptr()),
820
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
821
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
822
+ static_cast<__half2 *>(out_real.data_ptr()),
823
+ out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
824
+ B,
825
+ H,
826
+ fft_size);
827
+ break;
828
+
829
+ case 5:
830
+ cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 5>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
831
+
832
+ butterfly_ifft_padded_cuda_kernel_128<TILE_H, 5><<<gridDim, blockDim, 65536 * 2>>>(
833
+ static_cast<__half2 *>(x_real.data_ptr()),
834
+ static_cast<__half2 *>(x_imag.data_ptr()),
835
+ static_cast<complex_half_t *>(d_f.data_ptr()),
836
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
837
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
838
+ static_cast<__half2 *>(out_real.data_ptr()),
839
+ out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
840
+ B,
841
+ H,
842
+ fft_size);
843
+ break;
844
+
845
+ case 6:
846
+ cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 6>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
847
+
848
+ butterfly_ifft_padded_cuda_kernel_128<TILE_H, 6><<<gridDim, blockDim, 65536 * 2>>>(
849
+ static_cast<__half2 *>(x_real.data_ptr()),
850
+ static_cast<__half2 *>(x_imag.data_ptr()),
851
+ static_cast<complex_half_t *>(d_f.data_ptr()),
852
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
853
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
854
+ static_cast<__half2 *>(out_real.data_ptr()),
855
+ out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
856
+ B,
857
+ H,
858
+ fft_size);
859
+ break;
860
+
861
+ case 7:
862
+ cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 7>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
863
+
864
+ butterfly_ifft_padded_cuda_kernel_128<TILE_H, 7><<<gridDim, blockDim, 65536 * 2>>>(
865
+ static_cast<__half2 *>(x_real.data_ptr()),
866
+ static_cast<__half2 *>(x_imag.data_ptr()),
867
+ static_cast<complex_half_t *>(d_f.data_ptr()),
868
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
869
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
870
+ static_cast<__half2 *>(out_real.data_ptr()),
871
+ out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
872
+ B,
873
+ H,
874
+ fft_size);
875
+ break;
876
+
877
+ case 8:
878
+ cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
879
+
880
+ butterfly_ifft_padded_cuda_kernel_128<TILE_H, 8><<<gridDim, blockDim, 65536 * 2>>>(
881
+ static_cast<__half2 *>(x_real.data_ptr()),
882
+ static_cast<__half2 *>(x_imag.data_ptr()),
883
+ static_cast<complex_half_t *>(d_f.data_ptr()),
884
+ static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
885
+ static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
886
+ static_cast<__half2 *>(out_real.data_ptr()),
887
+ out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
888
+ B,
889
+ H,
890
+ fft_size);
891
+ break;
892
+
893
+ default:
894
+ printf("Invalid K: %d\n", K);
895
+ break;
896
+ }
897
+ break;
898
+
899
+ default:
900
+ printf("Invalid d_f_size: %d\n", d_f_size);
901
+ break;
902
+ }
903
+
904
+ return out_real;
905
+ }
overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_padded_ifft_cuda_bf16.cu ADDED
@@ -0,0 +1,917 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2023 Dan Fu, Hermann Kumbong
2
+
3
+ #include <torch/extension.h>
4
+
5
+ #include <vector>
6
+ #include <stdio.h>
7
+ #include <mma.h>
8
+ #include <cuda_fp16.h>
9
+ #include <cuda_bf16.h>
10
+ #include "shared.h"
11
+
12
+ using namespace nvcuda;
13
+
14
+ template <int TILE_H, int K>
15
+ __global__ void butterfly_ifft_padded_cuda_kernel_64(
16
+ const __nv_bfloat162 *__restrict__ x_real,
17
+ const __nv_bfloat162 *__restrict__ x_imag,
18
+ const __nv_bfloat162 *__restrict__ d_f_real,
19
+ const __nv_bfloat162 *__restrict__ d_f_imag,
20
+ const __nv_bfloat162 *__restrict__ twiddle_factors_real,
21
+ const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
22
+ __nv_bfloat162 *__restrict__ out_real,
23
+ __nv_bfloat162 *__restrict__ out_gate,
24
+ uint B,
25
+ uint H,
26
+ int M)
27
+ {
28
+ const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
29
+ const int out_offset = blockIdx.y * H * M/2 + blockIdx.z * TILE_H * M/2;
30
+ const int in_offset = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * TILE_H * 64 * 32 * gridDim.x;
31
+ int idx;
32
+ int t_offset;
33
+ int out_t_offset;
34
+ int shared_offset;
35
+ const int N = 64;
36
+
37
+ extern __shared__ __nv_bfloat16 x_real_shared[];
38
+ __nv_bfloat16 *x_imag_shared = &x_real_shared[N * N];
39
+ __nv_bfloat16 *d_f_real_shared = &x_imag_shared[N * N];
40
+ __nv_bfloat16 *d_f_imag_shared = &d_f_real_shared[N * N];
41
+ __nv_bfloat16 *twiddles_real_shared = &d_f_imag_shared[N * N];
42
+ __nv_bfloat16 *twiddles_imag_shared = &twiddles_real_shared[N * N];
43
+ float *out_real_shared = reinterpret_cast<float*>(&twiddles_imag_shared[N * N]);
44
+
45
+ __nv_bfloat16 tmp_real, tmp_imag;
46
+
47
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[K][4];
48
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[K][4];
49
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[4];
50
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[4];
51
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_real[4];
52
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_imag[4];
53
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[K];
54
+
55
+ // #pragma unroll
56
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
57
+ {
58
+ idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
59
+ shared_offset = i * 32 + threadIdx.x;
60
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[idx];
61
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[idx];
62
+
63
+ // #pragma unroll
64
+ shared_offset = i * 32 + threadIdx.x;
65
+ reinterpret_cast<__nv_bfloat162 *>(d_f_real_shared)[shared_offset] = d_f_real[shared_offset];
66
+ reinterpret_cast<__nv_bfloat162 *>(d_f_imag_shared)[shared_offset] = d_f_imag[shared_offset];
67
+ }
68
+
69
+ __syncthreads();
70
+
71
+ for (int i = 0; i < 4; i++)
72
+ {
73
+ if(i < K){
74
+ #pragma unroll
75
+ for (int j = 0; j < 4; j++)
76
+ {
77
+ wmma::load_matrix_sync(a_frag_real[i][j], d_f_real_shared + j * N * 16 + i * 16, N);
78
+ wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag_shared + j * N * 16 + i * 16, N);
79
+ }
80
+ }
81
+ wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + i * N * 16 + threadIdx.y * 16, N);
82
+ wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + i * N * 16 + threadIdx.y * 16, N);
83
+ }
84
+
85
+ for (int t = 0; t < TILE_H; t++)
86
+ {
87
+
88
+ out_t_offset = t * M/2;
89
+ t_offset = t * 64 * 32 * gridDim.x;
90
+
91
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
92
+ {
93
+ idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
94
+ shared_offset = i * 32 + threadIdx.x;
95
+ reinterpret_cast<__nv_bfloat162 *>(x_real_shared)[shared_offset] = x_real[idx + in_offset + t_offset];
96
+ reinterpret_cast<__nv_bfloat162 *>(x_imag_shared)[shared_offset] = x_imag[idx + in_offset + t_offset];
97
+ }
98
+
99
+ __syncthreads();
100
+
101
+ for (int i = 0; i < 4; i++)
102
+ {
103
+ wmma::load_matrix_sync(b_frag_real[i], x_real_shared + i * N * 16 + threadIdx.y * 16, N);
104
+ wmma::load_matrix_sync(b_frag_imag[i], x_imag_shared + i * N * 16 + threadIdx.y * 16, N);
105
+ }
106
+
107
+ for (int j = 0; j < 4; j++)
108
+ {
109
+ for (int k = 0; k < tw_frag_real[j].num_elements; k++)
110
+ {
111
+ tmp_real = __hsub(__hmul(tw_frag_real[j].x[k], b_frag_real[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_imag[j].x[k]));
112
+ tmp_imag = __hadd(__hmul(tw_frag_real[j].x[k], b_frag_imag[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_real[j].x[k]));
113
+ b_frag_real[j].x[k] = tmp_real;
114
+ b_frag_imag[j].x[k] = tmp_imag;
115
+ }
116
+ }
117
+
118
+ for (int i = 0; i < K; i++)
119
+ {
120
+ wmma::fill_fragment(acc_frag_real[i], 0.0f);
121
+
122
+ // bd
123
+ #pragma unroll
124
+ for (int k = 0; k < 4; k++)
125
+ {
126
+ wmma::mma_sync(acc_frag_real[i], a_frag_imag[i][k], b_frag_imag[k], acc_frag_real[i]);
127
+ }
128
+
129
+ for (int k = 0; k < acc_frag_real[i].num_elements; k++)
130
+ {
131
+ acc_frag_real[i].x[k] = - acc_frag_real[i].x[k];
132
+ }
133
+ }
134
+
135
+ for (int i = 0; i < K; i++)
136
+ {
137
+ // ac - bd
138
+ #pragma unroll
139
+ for (int k = 0; k < 4; k++)
140
+ {
141
+ wmma::mma_sync(acc_frag_real[i], a_frag_real[i][k], b_frag_real[k], acc_frag_real[i]);
142
+ }
143
+ }
144
+
145
+ #pragma unroll
146
+ for (int i = 0; i < K; i++)
147
+ {
148
+ wmma::store_matrix_sync(out_real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
149
+ }
150
+
151
+ __syncthreads();
152
+
153
+ #pragma unroll
154
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
155
+ {
156
+ idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
157
+ shared_offset = i * 32 + threadIdx.x;
158
+
159
+ if(idx < max_idx){
160
+ if(out_gate != nullptr)
161
+ out_real[out_offset + out_t_offset + idx] = __hmul2(__float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[shared_offset]), out_gate[out_offset + out_t_offset + idx]);
162
+ else
163
+ out_real[out_offset + out_t_offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[shared_offset]);
164
+ }
165
+ }
166
+
167
+ __syncthreads();
168
+ }
169
+ }
170
+
171
+
172
+ template <int K>
173
+ __global__ void butterfly_ifft_padded_cuda_kernel_32(
174
+ const __nv_bfloat162 *__restrict__ x_real,
175
+ const __nv_bfloat162 *__restrict__ x_imag,
176
+ const __nv_bfloat16 *__restrict__ d_f_real,
177
+ const __nv_bfloat16 *__restrict__ d_f_imag,
178
+ const __nv_bfloat162 *__restrict__ twiddle_factors_real,
179
+ const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
180
+ __nv_bfloat162 *__restrict__ out_real,
181
+ __nv_bfloat162 *__restrict__ out_gate,
182
+ uint B,
183
+ uint H,
184
+ int M)
185
+ {
186
+ const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
187
+ const int N = 32;
188
+ int idx;
189
+ int shared_offset;
190
+
191
+ const int out_offset = blockIdx.y * H * M / 2 + blockIdx.z * M / 2;
192
+ const int in_offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x;
193
+
194
+
195
+ __shared__ __nv_bfloat16 x_real_shared[32 * 64];
196
+ __shared__ __nv_bfloat16 x_imag_shared[32 * 64];
197
+ __shared__ __nv_bfloat16 d_f_real_shared[32 * 32];
198
+ __shared__ __nv_bfloat16 d_f_imag_shared[32 * 32];
199
+ __shared__ __nv_bfloat16 twiddles_real_shared[32 * 64];
200
+ __shared__ __nv_bfloat16 twiddles_imag_shared[32 * 64];
201
+ __shared__ float out_real_shared[32 * 64];
202
+
203
+ // #pragma unroll
204
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
205
+ {
206
+ idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
207
+ int shared_offset = i * 32 + threadIdx.x;
208
+
209
+ reinterpret_cast<__nv_bfloat162 *>(x_real_shared)[shared_offset] = x_real[in_offset + idx];
210
+ reinterpret_cast<__nv_bfloat162 *>(x_imag_shared)[shared_offset] = x_imag[in_offset + idx];
211
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[idx];
212
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[idx];
213
+
214
+ // #pragma unroll
215
+ shared_offset = i * 32 + threadIdx.x;
216
+ d_f_real_shared[shared_offset] = d_f_real[shared_offset];
217
+ d_f_imag_shared[shared_offset] = d_f_imag[shared_offset];
218
+ }
219
+
220
+ __syncthreads();
221
+
222
+ if (threadIdx.y < N/16)
223
+ {
224
+ __nv_bfloat16 tmp_real, tmp_imag;
225
+
226
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[K][2];
227
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[K][2];
228
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[2][2];
229
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[2][2];
230
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_real[2][2];
231
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_imag[2][2];
232
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[K][2];
233
+
234
+ int t = threadIdx.y * 32;
235
+
236
+ for (int i = 0; i < 2; i++)
237
+ {
238
+ for (int j = 0; j < 2; j++)
239
+ {
240
+ if(i < K){
241
+ wmma::load_matrix_sync(a_frag_real[i][j], d_f_real_shared + j * N * 16 + i * 16, N);
242
+ wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag_shared + j * N * 16 + i * 16, N);
243
+ }
244
+ wmma::load_matrix_sync(b_frag_real[i][j], x_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
245
+ wmma::load_matrix_sync(b_frag_imag[i][j], x_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
246
+ wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
247
+ wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
248
+ }
249
+ }
250
+
251
+ for (int i = 0; i < 2; i++)
252
+ {
253
+ for (int j = 0; j < 2; j++)
254
+ {
255
+ for (int k = 0; k < tw_frag_real[i][j].num_elements; k++)
256
+ {
257
+ tmp_real = __hsub(__hmul(tw_frag_real[i][j].x[k], b_frag_real[i][j].x[k]), __hmul(tw_frag_imag[i][j].x[k], b_frag_imag[i][j].x[k]));
258
+ tmp_imag = __hadd(__hmul(tw_frag_real[i][j].x[k], b_frag_imag[i][j].x[k]), __hmul(tw_frag_imag[i][j].x[k], b_frag_real[i][j].x[k]));
259
+ b_frag_real[i][j].x[k] = tmp_real;
260
+ b_frag_imag[i][j].x[k] = tmp_imag;
261
+ }
262
+ }
263
+ }
264
+
265
+ for (int i = 0; i < K; i++)
266
+ {
267
+ for (int j = 0; j < 2; j++)
268
+ {
269
+ wmma::fill_fragment(acc_frag_real[i][j], 0.0f);
270
+
271
+ // bd
272
+ for (int k = 0; k < 2; k++)
273
+ {
274
+ wmma::mma_sync(acc_frag_real[i][j], a_frag_imag[i][k], b_frag_imag[k][j], acc_frag_real[i][j]);
275
+ }
276
+
277
+ for (int k = 0; k < acc_frag_real[i][j].num_elements; k++)
278
+ {
279
+ acc_frag_real[i][j].x[k] = - acc_frag_real[i][j].x[k];
280
+ }
281
+ }
282
+ }
283
+
284
+ for (int i = 0; i < K; i++)
285
+ {
286
+ for (int j = 0; j < 2; j++)
287
+ {
288
+ // ac - bd
289
+ for (int k = 0; k < 2; k++)
290
+ {
291
+ wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag_real[k][j], acc_frag_real[i][j]);
292
+ }
293
+ }
294
+ }
295
+
296
+ for (int i = 0; i < K; i++)
297
+ {
298
+ for (int j = 0; j < 2; j++)
299
+ {
300
+ wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
301
+ }
302
+ }
303
+ }
304
+
305
+ __syncthreads();
306
+
307
+ #pragma unroll
308
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
309
+ {
310
+ idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
311
+ shared_offset = i * 32 + threadIdx.x;
312
+
313
+ if(idx < max_idx){
314
+ if(out_gate != nullptr){
315
+ out_real[idx + out_offset] = __hmul2(__float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[shared_offset]), out_gate[idx + out_offset]);
316
+ }else{
317
+ out_real[idx + out_offset] = __float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[shared_offset]);
318
+ }
319
+ }
320
+
321
+ }
322
+ }
323
+
324
+
325
+ template <int TILE_H, int K>
326
+ __global__ void butterfly_ifft_padded_cuda_kernel_128(
327
+ const __nv_bfloat162 *__restrict__ x_real,
328
+ const __nv_bfloat162 *__restrict__ x_imag,
329
+ const __nv_bfloat162 *__restrict__ d_f_real,
330
+ const __nv_bfloat162 *__restrict__ d_f_imag,
331
+ const __nv_bfloat162 *__restrict__ twiddle_factors_real,
332
+ const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
333
+ __nv_bfloat162 *__restrict__ out_real,
334
+ __nv_bfloat162 *__restrict__ out_gate,
335
+ uint B,
336
+ uint H,
337
+ int M)
338
+ {
339
+ const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
340
+ const int out_offset = blockIdx.y * H * M/2 + blockIdx.z * TILE_H * M/2;
341
+ const int in_offset = blockIdx.y * H * 128 * 32 * 2 * gridDim.x + blockIdx.z * TILE_H * 128 * 32 * 2 * gridDim.x;
342
+ const int N = 128;
343
+ int idx;
344
+ int t_offset;
345
+ int out_t_offset;
346
+ int shared_offset;
347
+
348
+
349
+ extern __shared__ __nv_bfloat16 real_shared[];
350
+ __nv_bfloat16 *imag_shared = &real_shared[128 * 128];
351
+ __nv_bfloat16 *real_shared_2 = &imag_shared[128 * 128];
352
+ __nv_bfloat16 *imag_shared_2 = &real_shared_2[128 * 128];
353
+
354
+ __nv_bfloat16 tmp_real, tmp_imag;
355
+
356
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag[K][8];
357
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[8];
358
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[8];
359
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_real[8];
360
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_imag[8];
361
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[K];
362
+
363
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
364
+ {
365
+ for(int j=0; j< 2; j++){
366
+ shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
367
+ reinterpret_cast<__nv_bfloat162*>(real_shared_2)[shared_offset] = d_f_real[shared_offset];
368
+ reinterpret_cast<__nv_bfloat162*>(imag_shared_2)[shared_offset] = d_f_imag[shared_offset];
369
+ }
370
+ }
371
+
372
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
373
+ {
374
+ for(int j=0; j< 2; j++){
375
+ idx = i * 32 * 2 * gridDim.x + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
376
+ shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
377
+ reinterpret_cast<__nv_bfloat162*>(real_shared)[shared_offset] = twiddle_factors_real[idx];
378
+ reinterpret_cast<__nv_bfloat162*>(imag_shared)[shared_offset] = twiddle_factors_imag[idx];
379
+ }
380
+ }
381
+
382
+ __syncthreads();
383
+
384
+
385
+ for (int i = 0; i < 8; i++){
386
+ wmma::load_matrix_sync(tw_frag_real[i], real_shared + i * 128 * 16 + threadIdx.y * 16, 128);
387
+ wmma::load_matrix_sync(tw_frag_imag[i], imag_shared + i * 128 * 16 + threadIdx.y * 16, 128);
388
+ }
389
+
390
+
391
+ for (int t = 0; t < TILE_H; t++)
392
+ {
393
+
394
+ out_t_offset = t * M/2;
395
+ t_offset = t * 128 * 32 * 2 * gridDim.x;
396
+
397
+ for (int i = 0; i < K; i++){
398
+ for (int j = 0; j < 8; j++){
399
+ wmma::load_matrix_sync(a_frag[i][j], imag_shared_2 + j * 128 * 16 + i * 16, 128);
400
+ }
401
+ }
402
+
403
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
404
+ {
405
+ for(int j=0; j< 2; j++){
406
+ idx = i * 32 * 2 * gridDim.x + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
407
+ shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
408
+ reinterpret_cast<__nv_bfloat162*>(real_shared)[shared_offset] = x_real[idx + in_offset + t_offset];
409
+ reinterpret_cast<__nv_bfloat162*>(imag_shared)[shared_offset] = x_imag[idx + in_offset + t_offset];
410
+ }
411
+ }
412
+
413
+ __syncthreads();
414
+
415
+ for (int i = 0; i < 8; i++)
416
+ {
417
+ wmma::load_matrix_sync(b_frag_real[i], real_shared + i * N * 16 + threadIdx.y * 16, N);
418
+ wmma::load_matrix_sync(b_frag_imag[i], imag_shared + i * N * 16 + threadIdx.y * 16, N);
419
+ }
420
+
421
+
422
+ __syncthreads();
423
+
424
+ for (int j = 0; j < 8; j++)
425
+ {
426
+ for (int k = 0; k < tw_frag_real[j].num_elements; k++)
427
+ {
428
+ tmp_real = __hsub(__hmul(tw_frag_real[j].x[k], b_frag_real[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_imag[j].x[k]));
429
+ tmp_imag = __hadd(__hmul(tw_frag_real[j].x[k], b_frag_imag[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_real[j].x[k]));
430
+ b_frag_real[j].x[k] = tmp_real;
431
+ b_frag_imag[j].x[k] = tmp_imag;
432
+ }
433
+ }
434
+
435
+ __syncthreads();
436
+
437
+ for (int i = 0; i < K; i++)
438
+ {
439
+ wmma::fill_fragment(acc_frag_real[i], 0.0f);
440
+
441
+ // bd
442
+ #pragma unroll
443
+ for (int k = 0; k < 8; k++)
444
+ {
445
+ wmma::mma_sync(acc_frag_real[i], a_frag[i][k], b_frag_imag[k], acc_frag_real[i]);
446
+ }
447
+
448
+ for (int k = 0; k < acc_frag_real[i].num_elements; k++)
449
+ {
450
+ acc_frag_real[i].x[k] = -acc_frag_real[i].x[k];
451
+ }
452
+ }
453
+
454
+ for (int i = 0; i < K; i++){
455
+ for (int j = 0; j < 8; j++){
456
+ wmma::load_matrix_sync(a_frag[i][j], real_shared_2 + j * 128 * 16 + i * 16, 128);
457
+ }
458
+ }
459
+
460
+ for (int i = 0; i < K; i++)
461
+ {
462
+ // ac - bd
463
+ #pragma unroll
464
+ for (int k = 0; k < 8; k++)
465
+ {
466
+ wmma::mma_sync(acc_frag_real[i], a_frag[i][k], b_frag_real[k], acc_frag_real[i]);
467
+ }
468
+ }
469
+
470
+ __syncthreads();
471
+
472
+ #pragma unroll
473
+ for (int i = 0; i < K; i++)
474
+ {
475
+ //wmma::store_matrix_sync(real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
476
+ wmma::store_matrix_sync(reinterpret_cast<float*>(real_shared) + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
477
+ }
478
+
479
+ __syncthreads();
480
+
481
+ #pragma unroll
482
+ for (int i = threadIdx.y; i < N; i+=blockDim.y)
483
+ {
484
+ for(int j=0; j< 2; j++){
485
+ idx = i * 32 * 2 * gridDim.x + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
486
+ shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
487
+ if(idx < max_idx){
488
+ if(out_gate != nullptr){
489
+ out_real[idx + out_offset + out_t_offset] = __hmul2(__float22bfloat162_rn(reinterpret_cast<float2*>(real_shared)[shared_offset]), out_gate[idx + out_offset + out_t_offset]);
490
+ }else{
491
+ out_real[idx + out_offset + out_t_offset] = __float22bfloat162_rn(reinterpret_cast<float2*>(real_shared)[shared_offset]);
492
+ }
493
+ }
494
+ }
495
+ }
496
+
497
+ __syncthreads();
498
+ }
499
+ }
500
+
501
+
502
+ __global__ void butterfly_ifft_padded_cuda_kernel_16(
503
+ const __nv_bfloat162 *__restrict__ x_real,
504
+ const __nv_bfloat162 *__restrict__ x_imag,
505
+ const __nv_bfloat16 *__restrict__ d_f_real,
506
+ const __nv_bfloat16 *__restrict__ d_f_imag,
507
+ const __nv_bfloat162 *__restrict__ twiddle_factors_real,
508
+ const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
509
+ __nv_bfloat162 *__restrict__ out_real,
510
+ __nv_bfloat162 *__restrict__ out_gate,
511
+ uint B,
512
+ uint H,
513
+ int M)
514
+ {
515
+ const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
516
+ const int N = 16;
517
+ const int out_offset = blockIdx.y * H * M / 2 + blockIdx.z * M / 2;
518
+ const int offset = blockIdx.y * H * N * blockDim.x * gridDim.x + blockIdx.z * N * blockDim.x * gridDim.x;
519
+
520
+ __shared__ __nv_bfloat16 x_real_shared[N * 64];
521
+ __shared__ __nv_bfloat16 x_imag_shared[N * 64];
522
+ __shared__ __nv_bfloat16 twiddles_real_shared[N * 64];
523
+ __shared__ __nv_bfloat16 twiddles_imag_shared[N * 64];
524
+ __shared__ float out_real_shared[N * 64];
525
+
526
+ // #pragma unroll
527
+ for (int i = threadIdx.y; i < N; i++)
528
+ {
529
+ int idx = i * blockDim.x * gridDim.x + blockIdx.x * blockDim.x + threadIdx.x;
530
+ int shared_offset = i * blockDim.x + threadIdx.x;
531
+ reinterpret_cast<__nv_bfloat162 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
532
+ reinterpret_cast<__nv_bfloat162 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
533
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[idx];
534
+ reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[idx];
535
+ }
536
+
537
+ __syncthreads();
538
+
539
+ if (threadIdx.y < 4)
540
+ {
541
+ __nv_bfloat16 tmp_real, tmp_imag;
542
+
543
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real;
544
+ wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag;
545
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real;
546
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag;
547
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_real;
548
+ wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_imag;
549
+ wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real;
550
+
551
+ wmma::load_matrix_sync(a_frag_real, d_f_real, N);
552
+ wmma::load_matrix_sync(a_frag_imag, d_f_imag, N);
553
+ wmma::load_matrix_sync(b_frag_real, x_real_shared + threadIdx.y * 16, 64);
554
+ wmma::load_matrix_sync(b_frag_imag, x_imag_shared + threadIdx.y * 16, 64);
555
+ wmma::load_matrix_sync(tw_frag_real, twiddles_real_shared + threadIdx.y * 16, 64);
556
+ wmma::load_matrix_sync(tw_frag_imag, twiddles_imag_shared + threadIdx.y * 16, 64);
557
+
558
+
559
+ for (int k = 0; k < tw_frag_real.num_elements; k++)
560
+ {
561
+ tmp_real = __hsub(__hmul(tw_frag_real.x[k], b_frag_real.x[k]), __hmul(tw_frag_imag.x[k], b_frag_imag.x[k]));
562
+ tmp_imag = __hadd(__hmul(tw_frag_real.x[k], b_frag_imag.x[k]), __hmul(tw_frag_imag.x[k], b_frag_real.x[k]));
563
+ b_frag_real.x[k] = tmp_real;
564
+ b_frag_imag.x[k] = tmp_imag;
565
+ }
566
+
567
+
568
+
569
+ wmma::fill_fragment(acc_frag_real, 0.0f);
570
+
571
+ wmma::mma_sync(acc_frag_real, a_frag_imag, b_frag_imag, acc_frag_real);
572
+
573
+ for(int k=0; k< acc_frag_real.num_elements; k++){
574
+ acc_frag_real.x[k] = - acc_frag_real.x[k];
575
+ }
576
+
577
+ wmma::mma_sync(acc_frag_real, a_frag_real, b_frag_real, acc_frag_real);
578
+
579
+ wmma::store_matrix_sync(out_real_shared + threadIdx.y * 16, acc_frag_real, 64, wmma::mem_row_major);
580
+
581
+ }
582
+
583
+ __syncthreads();
584
+
585
+ #pragma unroll
586
+ for (int i = threadIdx.y; i < N; i++)
587
+ {
588
+ int idx = i * blockDim.x * gridDim.x + blockIdx.x * blockDim.x + threadIdx.x;
589
+ if(idx < max_idx){
590
+ if(out_gate != nullptr){
591
+ out_real[out_offset + idx] = __hmul2(__float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[i * 32 + threadIdx.x]), out_gate[out_offset + idx]);
592
+ }else{
593
+ out_real[out_offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[i * 32 + threadIdx.x]);
594
+ }
595
+ }
596
+ }
597
+ }
598
+
599
+
600
+ torch::Tensor butterfly_ifft_padded_bf16_cuda(
601
+ torch::Tensor x_real,
602
+ torch::Tensor x_imag,
603
+ torch::Tensor d_f_real,
604
+ torch::Tensor d_f_imag,
605
+ torch::Tensor twiddle_factors_real,
606
+ torch::Tensor twiddle_factors_imag,
607
+ int fft_size,
608
+ std::optional<at::Tensor> out_gate = std::nullopt
609
+ )
610
+ {
611
+
612
+ uint B = x_real.size(0);
613
+ uint H = x_real.size(1);
614
+ uint N_M = x_real.size(2);
615
+ const int d_f_size = d_f_real.size(0);
616
+ // const int TILE_SIZE = 16;
617
+
618
+ dim3 gridDim;
619
+ dim3 blockDim;
620
+
621
+ // uint N = x_real.size(2);
622
+ gridDim.y = B;
623
+
624
+ blockDim.x = 32;
625
+ blockDim.y = 4;
626
+ gridDim.x = 512 / (32 * 1024/ (N_M / d_f_size));
627
+ gridDim.z = H;
628
+
629
+ const int TILE_H = 16;
630
+ torch::Tensor out_real = torch::empty({B, H, fft_size}, x_real.options());
631
+ const int K = ceil(fft_size / (1.0 * 16 * (N_M / d_f_size)));
632
+
633
+ switch(d_f_size){
634
+ case 16:
635
+ butterfly_ifft_padded_cuda_kernel_16<<<gridDim, blockDim>>>(
636
+ static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
637
+ static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
638
+ static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
639
+ static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
640
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
641
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
642
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
643
+ out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
644
+ B,
645
+ H,
646
+ fft_size
647
+ );
648
+ break;
649
+ case 32:
650
+ switch (K)
651
+ {
652
+ case 1:
653
+ butterfly_ifft_padded_cuda_kernel_32<1><<<gridDim, blockDim>>>(
654
+ static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
655
+ static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
656
+ static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
657
+ static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
658
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
659
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
660
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
661
+ out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
662
+ B,
663
+ H,
664
+ fft_size
665
+ );
666
+ break;
667
+ case 2:
668
+ butterfly_ifft_padded_cuda_kernel_32<2><<<gridDim, blockDim>>>(
669
+ static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
670
+ static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
671
+ static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
672
+ static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
673
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
674
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
675
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
676
+ out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
677
+ B,
678
+ H,
679
+ fft_size
680
+ );
681
+ break;
682
+ default:
683
+ printf("Invalid K: %d\n", K);
684
+ break;
685
+ }
686
+ break;
687
+
688
+ case 64:
689
+ gridDim.z = H / TILE_H;
690
+ switch (K)
691
+ {
692
+ case 1:
693
+ cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_64<TILE_H, 1>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
694
+ butterfly_ifft_padded_cuda_kernel_64<TILE_H, 1><<<gridDim, blockDim, 65536>>>(
695
+ static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
696
+ static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
697
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
698
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
699
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
700
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
701
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
702
+ out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
703
+ B,
704
+ H,
705
+ fft_size);
706
+ break;
707
+
708
+ case 2:
709
+ cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_64<TILE_H, 2>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
710
+ butterfly_ifft_padded_cuda_kernel_64<TILE_H, 2><<<gridDim, blockDim, 65536>>>(
711
+ static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
712
+ static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
713
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
714
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
715
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
716
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
717
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
718
+ out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
719
+ B,
720
+ H,
721
+ fft_size);
722
+ break;
723
+
724
+ case 3:
725
+ cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_64<TILE_H, 3>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
726
+ butterfly_ifft_padded_cuda_kernel_64<TILE_H, 3><<<gridDim, blockDim, 65536>>>(
727
+ static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
728
+ static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
729
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
730
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
731
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
732
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
733
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
734
+ out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
735
+ B,
736
+ H,
737
+ fft_size);
738
+ break;
739
+
740
+ case 4:
741
+ cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_64<TILE_H, 4>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
742
+ butterfly_ifft_padded_cuda_kernel_64<TILE_H, 4><<<gridDim, blockDim, 65536>>>(
743
+ static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
744
+ static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
745
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
746
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
747
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
748
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
749
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
750
+ out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
751
+ B,
752
+ H,
753
+ fft_size);
754
+ break;
755
+
756
+ default:
757
+ break;
758
+ }
759
+
760
+ break;
761
+ case 128:
762
+ blockDim.x = 32;
763
+ blockDim.y = 8;
764
+ gridDim.x = 256 / (32 * 1024/ (N_M / d_f_size));
765
+ gridDim.z = H / TILE_H;
766
+
767
+ switch (K)
768
+ {
769
+ case 1:
770
+ cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 1>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
771
+
772
+ butterfly_ifft_padded_cuda_kernel_128<TILE_H, 1><<<gridDim, blockDim, 65536 * 2>>>(
773
+ static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
774
+ static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
775
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
776
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
777
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
778
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
779
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
780
+ out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
781
+ B,
782
+ H,
783
+ fft_size);
784
+ break;
785
+
786
+ case 2:
787
+ cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 2>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
788
+
789
+ butterfly_ifft_padded_cuda_kernel_128<TILE_H, 2><<<gridDim, blockDim, 65536 * 2>>>(
790
+ static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
791
+ static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
792
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
793
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
794
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
795
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
796
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
797
+ out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
798
+ B,
799
+ H,
800
+ fft_size);
801
+ break;
802
+
803
+ case 3:
804
+ cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 3>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
805
+
806
+ butterfly_ifft_padded_cuda_kernel_128<TILE_H, 3><<<gridDim, blockDim, 65536 * 2>>>(
807
+ static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
808
+ static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
809
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
810
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
811
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
812
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
813
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
814
+ out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
815
+ B,
816
+ H,
817
+ fft_size);
818
+ break;
819
+
820
+ case 4:
821
+ cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 4>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
822
+
823
+ butterfly_ifft_padded_cuda_kernel_128<TILE_H, 4><<<gridDim, blockDim, 65536 * 2>>>(
824
+ static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
825
+ static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
826
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
827
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
828
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
829
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
830
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
831
+ out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
832
+ B,
833
+ H,
834
+ fft_size);
835
+ break;
836
+
837
+ case 5:
838
+ cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 5>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
839
+
840
+ butterfly_ifft_padded_cuda_kernel_128<TILE_H, 5><<<gridDim, blockDim, 65536 * 2>>>(
841
+ static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
842
+ static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
843
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
844
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
845
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
846
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
847
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
848
+ out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
849
+ B,
850
+ H,
851
+ fft_size);
852
+ break;
853
+
854
+ case 6:
855
+ cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 6>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
856
+
857
+ butterfly_ifft_padded_cuda_kernel_128<TILE_H, 6><<<gridDim, blockDim, 65536 * 2>>>(
858
+ static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
859
+ static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
860
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
861
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
862
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
863
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
864
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
865
+ out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
866
+ B,
867
+ H,
868
+ fft_size);
869
+ break;
870
+
871
+ case 7:
872
+ cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 7>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
873
+
874
+ butterfly_ifft_padded_cuda_kernel_128<TILE_H, 7><<<gridDim, blockDim, 65536 * 2>>>(
875
+ static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
876
+ static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
877
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
878
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
879
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
880
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
881
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
882
+ out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
883
+ B,
884
+ H,
885
+ fft_size);
886
+ break;
887
+
888
+ case 8:
889
+ cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
890
+
891
+ butterfly_ifft_padded_cuda_kernel_128<TILE_H, 8><<<gridDim, blockDim, 65536 * 2>>>(
892
+ static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
893
+ static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
894
+ static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
895
+ static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
896
+ static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
897
+ static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
898
+ static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
899
+ out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
900
+ B,
901
+ H,
902
+ fft_size);
903
+ break;
904
+
905
+ default:
906
+ printf("Invalid K: %d\n", K);
907
+ break;
908
+ }
909
+ break;
910
+
911
+ default:
912
+ printf("Invalid d_f_size: %d\n", d_f_size);
913
+ break;
914
+ }
915
+
916
+ return out_real;
917
+ }
overlay/kernels/cuda/flashfftconv/csrc/butterfly/shared.h ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2023 Dan Fu, Hermann Kumbong
2
+
3
+ #include <torch/extension.h>
4
+
5
+ #include <vector>
6
+ #include <stdio.h>
7
+ #include <mma.h>
8
+ #include <cuda_fp16.h>
9
+ #include <cub/block/block_load.cuh>
10
+ #include <cub/block/block_store.cuh>
11
+ using namespace nvcuda;
12
+
13
+ using complex_half_t = typename c10::complex<at::Half>;
14
+ using complex_bhalf_t = typename c10::complex<at::BFloat16>;
15
+
16
+ #define WMMA_M 16
17
+ #define WMMA_N 16
18
+ #define WMMA_K 16
19
+ #define WARP_SIZE 32
20
+
21
+ #ifndef MONARCH_CUDA_H_
22
+ #define MONARCH_CUDA_H_
23
+
24
+ __device__ __forceinline__ float2
25
+
26
+ operator+( float2 lhs, float2 rhs)
27
+
28
+ {
29
+
30
+ float2 res = { lhs.x + rhs.x , lhs.y + rhs.y };
31
+
32
+ return res;
33
+
34
+ }
35
+
36
+
37
+ __device__ __forceinline__ float2
38
+
39
+ operator-( float2 lhs, float2 rhs)
40
+
41
+ {
42
+
43
+ float2 res = { lhs.x - rhs.x , lhs.y - rhs.y };
44
+
45
+ return res;
46
+
47
+ }
48
+
49
+ __device__ __forceinline__ float2
50
+
51
+ operator*( float2 lhs, float2 rhs)
52
+
53
+ {
54
+
55
+ float2 res = { lhs.x * rhs.x , lhs.y * rhs.y };
56
+
57
+ return res;
58
+
59
+ }
60
+ #endif
overlay/kernels/cuda/flashfftconv/csrc/conv1d/conv1d.h ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2023 Dan Fu, Hermann Kumbong
2
+
3
+ #include <torch/extension.h>
4
+
5
+ #include <vector>
6
+
7
+
8
+ #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
9
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
10
+ #define CHECK_IS_HALF_OR_BFLOAT_OR_FLOAT(x) TORCH_CHECK(x.dtype() == torch::kFloat16 || x.dtype() == torch::kBFloat16 || x.dtype() == torch::kFloat32, #x " must be float16 or bfloat16 or float32")
11
+ #define CHECK_SAME_TYPE(x, y) TORCH_CHECK(x.dtype() == y.dtype(), #x " and " #y " must have the same dtype")
12
+
13
+ #define CHECK_INPUT(x) \
14
+ CHECK_CUDA(x); \
15
+ CHECK_CONTIGUOUS(x); \
16
+ CHECK_IS_HALF_OR_BFLOAT_OR_FLOAT(x)
17
+
18
+ torch::Tensor conv1d_cuda_bhl(
19
+ torch::Tensor u,
20
+ torch::Tensor weight,
21
+ torch::Tensor bias,
22
+ uint padding);
23
+
24
+ torch::Tensor conv1d_cuda_blh(
25
+ torch::Tensor u,
26
+ torch::Tensor weight,
27
+ torch::Tensor bias,
28
+ uint padding);
29
+
30
+ std::vector<torch::Tensor> conv1d_backward_bhl_cuda(
31
+ torch::Tensor dout,
32
+ torch::Tensor input,
33
+ torch::Tensor weight,
34
+ torch::Tensor bias,
35
+ uint padding
36
+ );
37
+
38
+ std::vector<torch::Tensor> conv1d_backward_blh_cuda(
39
+ torch::Tensor dout,
40
+ torch::Tensor input,
41
+ torch::Tensor weight,
42
+ torch::Tensor bias,
43
+ uint padding
44
+ );
45
+
46
+
47
+ torch::Tensor conv1d_fwd(
48
+ torch::Tensor u,
49
+ torch::Tensor weight,
50
+ torch::Tensor bias,
51
+ uint padding,
52
+ bool is_bhl)
53
+ {
54
+ CHECK_INPUT(u);
55
+ CHECK_INPUT(weight);
56
+ CHECK_INPUT(bias);
57
+ CHECK_SAME_TYPE(weight, bias);
58
+
59
+ int k;
60
+
61
+ if(is_bhl){
62
+ k = weight.size(1);
63
+ }else{
64
+ k = weight.size(0);
65
+ }
66
+
67
+ TORCH_CHECK(k % 2 == 1, "Filter size must be odd number");
68
+
69
+ if(is_bhl){
70
+ return conv1d_cuda_bhl(u, weight, bias, padding);
71
+ }else{
72
+ return conv1d_cuda_blh(u, weight, bias, padding);
73
+ }
74
+ }
75
+
76
+ std::vector<torch::Tensor> conv1d_bwd(
77
+ torch::Tensor dout,
78
+ torch::Tensor input,
79
+ torch::Tensor weight,
80
+ torch::Tensor bias,
81
+ uint padding,
82
+ bool is_bhl)
83
+ {
84
+ CHECK_INPUT(dout);
85
+ CHECK_INPUT(input);
86
+ CHECK_INPUT(weight);
87
+ CHECK_INPUT(bias);
88
+ CHECK_SAME_TYPE(weight, bias);
89
+ CHECK_SAME_TYPE(dout, input);
90
+
91
+ if(is_bhl){
92
+ return conv1d_backward_bhl_cuda(dout, input, weight, bias, padding);
93
+ } else{
94
+ return conv1d_backward_blh_cuda(dout, input, weight, bias, padding);
95
+ }
96
+ }
overlay/kernels/cuda/flashfftconv/csrc/conv1d/conv1d_bhl.cu ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2023 Dan Fu, Hermann Kumbong
2
+
3
+ // Simple 1D depthwise convolution implementation with dilation and stride = 1
4
+ #include "shared.h"
5
+
6
+ const uint BX = 256;
7
+ const uint BY = 1;
8
+ const uint BZ = 1;
9
+
10
+ const uint TILE_SIZE_L = 4;
11
+ const uint TILE_SIZE_D = 1;
12
+
13
+ template<typename T, typename U>
14
+ __forceinline__ __device__ T _conv1d_k_3(const T* u, const U* weights, const U* bias, uint padding, uint l, uint d, uint L, uint D, uint K)
15
+ {
16
+ T tmp;
17
+ T weight;
18
+
19
+ set_value(&tmp, bias[d]);
20
+
21
+ int idx = l - padding;
22
+
23
+ if(idx >= 0 && idx < L){
24
+ set_value(&weight, weights[0]);
25
+ tmp = __hfma(u[d * L + idx], weight, tmp);
26
+ }
27
+
28
+ idx++;
29
+ if(idx >= 0 && idx < L){
30
+ set_value(&weight, weights[1]);
31
+ tmp = __hfma(u[d * L + idx], weight, tmp);
32
+ }
33
+
34
+ idx++;
35
+ if(idx >= 0 && idx < L){
36
+ set_value(&weight, weights[2]);
37
+ tmp = __hfma(u[d * L + idx], weight, tmp);
38
+ }
39
+
40
+ return tmp;
41
+ }
42
+
43
+ template<typename T, typename U>
44
+ __global__ void conv1d_kernel(
45
+ const T *__restrict__ u,
46
+ const U *__restrict__ weights,
47
+ const U *__restrict__ bias,
48
+ T *__restrict__ out,
49
+ uint padding,
50
+ uint B,
51
+ uint L,
52
+ uint D,
53
+ uint K,
54
+ uint L_out
55
+ )
56
+ {
57
+ const int b = blockIdx.z * blockDim.z + threadIdx.z;
58
+ const int d = blockIdx.y * blockDim.y * TILE_SIZE_D + threadIdx.y;
59
+ const int l_offset = blockIdx.x * blockDim.x * TILE_SIZE_L + threadIdx.x;
60
+
61
+ T tmp;
62
+ T weight;
63
+
64
+ int idx;
65
+ int l;
66
+
67
+ for(int l_tile = 0; l_tile < TILE_SIZE_L; l_tile++){
68
+ l = l_offset + l_tile * blockDim.x;
69
+
70
+ set_value(&tmp, bias[d]);
71
+
72
+ if(d < D && l < L_out && b < B){
73
+ if(K == 3){
74
+ out[b * L_out * D + d * L_out + l] = _conv1d_k_3(u + b * L * D, weights + d * K, bias, padding, l, d, L, D, K);
75
+ } else{
76
+ for(int k = 0; k < K; k++){
77
+ idx = l - padding + k;
78
+ if(idx >= 0 && idx < L){
79
+ set_value(&weight, weights[d * K + k]);
80
+ tmp = __hfma(u[b * L_out * D + d * L + idx], weight, tmp);
81
+ }
82
+ }
83
+ out[b * L_out * D + d * L_out + l] = tmp;
84
+
85
+ }
86
+ }
87
+ }
88
+
89
+ }
90
+
91
+ torch::Tensor conv1d_cuda_bhl(
92
+ torch::Tensor u,
93
+ torch::Tensor weight,
94
+ torch::Tensor bias,
95
+ uint padding)
96
+ {
97
+ const uint b = u.size(0);
98
+ const uint d = u.size(1);
99
+ const uint l = u.size(2);
100
+
101
+
102
+ const uint k = weight.size(1);
103
+
104
+ uint l_out = (l + 2 * padding - k + 1);
105
+
106
+ dim3 blockDims(BX, BY, BZ);
107
+
108
+ dim3 gridDims(ceil(l_out * 1.0 / (BX * TILE_SIZE_L) ), ceil((d * 1.0) / (BY * TILE_SIZE_D)), ceil((b * 1.0) / BZ));
109
+
110
+ torch::Tensor out = torch::empty({b, d, l_out}, u.options());
111
+
112
+ DISPATCH_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), weight.scalar_type(),
113
+ "depthwise conv 1d fwd bhl",
114
+ ([&]
115
+ { conv1d_kernel<input_t, weight_t><<<gridDims, blockDims>>>(
116
+ static_cast<input_t *>(u.data_ptr()),
117
+ static_cast<weight_t *>(weight.data_ptr()),
118
+ static_cast<weight_t *>(bias.data_ptr()),
119
+ static_cast<input_t *>(out.data_ptr()),
120
+ padding,
121
+ b,
122
+ l,
123
+ d,
124
+ k,
125
+ l_out
126
+ );
127
+ }
128
+ )
129
+ );
130
+
131
+ return out;
132
+ }
overlay/kernels/cuda/flashfftconv/csrc/conv1d/conv1d_blh.cu ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2023 Dan Fu, Hermann Kumbong
2
+
3
+ // Simple 1D depthwise convolution implementation with dilation and stride = 1
4
+
5
+ #include "shared.h"
6
+
7
+ //For max perf, tune for your GPU and batch size, and datatype etc
8
+ const uint BX = 512;
9
+ const uint BY = 1;
10
+ const uint BZ = 1;
11
+
12
+ const uint TILE_SIZE_Y = 4;
13
+ const uint TILE_SIZE_X = 2;
14
+
15
+ // Trick to do padding in place without actually creating a new tensor
16
+ __forceinline__ __device__ __half2 get_u(const __half2 *__restrict__ u, uint L_eff, uint l, uint p, uint b, uint k, uint d, uint L, uint D, uint K)
17
+ {
18
+ return l + k < p || l + k > L_eff - (p + 1) ? __float2half2_rn(0.0f) : u[b * L * D + (l + k - p) * D + d];
19
+ }
20
+
21
+
22
+ __forceinline__ __device__ __nv_bfloat162 get_u(const __nv_bfloat162 *__restrict__ u, uint L_eff, uint l, uint p, uint b, uint k, uint d, uint L, uint D, uint K)
23
+ {
24
+ return l + k < p || l + k > L_eff - (p + 1) ? __float2bfloat162_rn(0.0f) : u[b * L * D + (l + k - p) * D + d];
25
+ }
26
+
27
+ __forceinline__ __device__ float2 get_u(const float2 *__restrict__ u, uint L_eff, uint l, uint p, uint b, uint k, uint d, uint L, uint D, uint K)
28
+ {
29
+ return l + k < p || l + k > L_eff - (p + 1) ? make_float2(0.0f, 0.0f) : u[b * L * D + (l + k - p) * D + d];
30
+ }
31
+
32
+
33
+ //manually unrolling loop for k = 3 leads to good perf, can easily extend for other values of k if need be
34
+ template<typename T, typename U>
35
+ __forceinline__ __device__ T _conv1d_k_3(const T* u, const U* weights, const U* bias, T* out, uint padding, uint b, uint l, uint d, uint t, uint L, uint D, uint K, uint L_eff, uint L_out)
36
+ {
37
+
38
+ T tmp;
39
+ T weight;
40
+ set_value(&tmp, bias[d]);
41
+
42
+ set_value(&weight, weights[0 * D + d]);
43
+ tmp = __hfma2(get_u(u, L_eff, l + t, padding, b, 0, d, L, D, K), weight, tmp);
44
+
45
+ set_value(&weight, weights[1 * D + d]);
46
+ tmp = __hfma2(get_u(u, L_eff, l + t, padding, b, 1, d, L, D, K), weight, tmp);
47
+
48
+ set_value(&weight, weights[2 * D + d]);
49
+ out[b * D * L_out + (l + t) * D + d] = __hfma2(get_u(u, L_eff, l + t, padding, b, 2, d, L, D, K), weight, tmp);
50
+
51
+ }
52
+
53
+ template<typename T, typename U>
54
+ __global__ void conv1d_kernel_k_3(
55
+ const T *__restrict__ u,
56
+ const U *__restrict__ weights,
57
+ const U *__restrict__ bias,
58
+ T *__restrict__ out,
59
+ uint padding,
60
+ uint B,
61
+ uint L,
62
+ uint L_out,
63
+ uint L_eff,
64
+ uint D,
65
+ uint K)
66
+ {
67
+ const int d_block = blockIdx.x * blockDim.x * TILE_SIZE_X;
68
+ const int l = blockIdx.y * blockDim.y * TILE_SIZE_Y + threadIdx.y * TILE_SIZE_Y;
69
+ const int b = blockIdx.z * blockDim.z + threadIdx.z;
70
+
71
+ int d;
72
+
73
+ #pragma unroll
74
+ for (int i = 0; i < TILE_SIZE_X; i++)
75
+ {
76
+ d = d_block + threadIdx.x + i * BX;
77
+
78
+ if (d < D && b < B){
79
+ #pragma unroll
80
+ for (int t = 0; t < TILE_SIZE_Y; t++){
81
+ if (l + t < L_eff - K + 1)
82
+ {
83
+ _conv1d_k_3(u, weights, bias, out, padding, b, l, d, t, L, D, K, L_eff, L_out);
84
+ }
85
+ }
86
+ }
87
+ }
88
+ }
89
+
90
+ template<typename T, typename U>
91
+ __global__ void conv1d_kernel(
92
+ const T *__restrict__ u,
93
+ const U *__restrict__ weights,
94
+ const U *__restrict__ bias,
95
+ T *__restrict__ out,
96
+ uint padding,
97
+ uint B,
98
+ uint L,
99
+ uint L_out,
100
+ uint L_eff,
101
+ uint D,
102
+ uint K)
103
+ {
104
+ const int d_block = blockIdx.x * blockDim.x * TILE_SIZE_X;
105
+ const int l = blockIdx.y * blockDim.y * TILE_SIZE_Y + threadIdx.y * TILE_SIZE_Y;
106
+ const int b = blockIdx.z * blockDim.z + threadIdx.z;
107
+
108
+ int d;
109
+ T tmp;
110
+ T weight;
111
+
112
+ #pragma unroll
113
+ for (int i = 0; i < TILE_SIZE_X; i++)
114
+ {
115
+ d = d_block + threadIdx.x + i * BX;
116
+
117
+ if (d < D && b < B){
118
+ #pragma unroll
119
+ for (int t = 0; t < TILE_SIZE_Y; t++){
120
+ if (l + t < L_eff - K + 1)
121
+ {
122
+ set_value(&tmp, bias[d]);
123
+
124
+ for(int k = 0; k < K; k++){
125
+ set_value(&weight, weights[k * D + d]);
126
+
127
+ tmp = __hfma2(get_u(u, L_eff, l + t, padding, b, k, d, L, D, K), weight, tmp);
128
+ }
129
+ out[b * D * L_out + (l + t) * D + d] = tmp;
130
+ }
131
+ }
132
+ }
133
+ }
134
+ }
135
+
136
+ torch::Tensor conv1d_cuda_blh(
137
+ torch::Tensor u,
138
+ torch::Tensor weight,
139
+ torch::Tensor bias,
140
+ uint padding)
141
+ {
142
+ const uint b = u.size(0);
143
+ const uint l = u.size(1);
144
+ const uint d = u.size(2);
145
+
146
+ const uint k = weight.size(0);
147
+
148
+ uint l_eff = l + 2 * padding;
149
+
150
+
151
+
152
+ dim3 blockDims(BX, BY, BZ);
153
+
154
+ dim3 gridDims(ceil(d * 1.0 / (BX * TILE_SIZE_X * 2) ), ceil((l_eff - k + 1) * 1.0 / (BY * TILE_SIZE_Y)), ceil(b * 1.0 / BZ));
155
+
156
+
157
+ uint l_out = (l + 2 * padding - k + 1);
158
+
159
+ torch::Tensor out = torch::empty({b, l_out, d}, u.options());
160
+
161
+ //calling seperate kernels for k=3 and k!=3 leads to better perf
162
+ if(k==3){
163
+ DISPATCH_FLOAT2_AND_HALF2_AND_BF162(u.scalar_type(), weight.scalar_type(),
164
+ "depthwise conv 1d fwd blh",
165
+ ([&]
166
+ { conv1d_kernel_k_3<input_t, weight_t><<<gridDims, blockDims>>>(
167
+ static_cast<input_t *>(u.data_ptr()),
168
+ static_cast<weight_t *>(weight.data_ptr()),
169
+ static_cast<weight_t *>(bias.data_ptr()),
170
+ static_cast<input_t *>(out.data_ptr()),
171
+ padding,
172
+ b,
173
+ l,
174
+ l_out,
175
+ l_eff,
176
+ ceil(d/2),
177
+ k);
178
+ }
179
+ )
180
+ );
181
+ }else{
182
+ DISPATCH_FLOAT2_AND_HALF2_AND_BF162(u.scalar_type(), weight.scalar_type(),
183
+ "depthwise conv 1d fwd blh",
184
+ ([&]
185
+ { conv1d_kernel<input_t, weight_t><<<gridDims, blockDims>>>(
186
+ static_cast<input_t *>(u.data_ptr()),
187
+ static_cast<weight_t *>(weight.data_ptr()),
188
+ static_cast<weight_t *>(bias.data_ptr()),
189
+ static_cast<input_t *>(out.data_ptr()),
190
+ padding,
191
+ b,
192
+ l,
193
+ l_out,
194
+ l_eff,
195
+ ceil(d/2),
196
+ k);
197
+ }
198
+ )
199
+ );
200
+ }
201
+ return out;
202
+ }
overlay/kernels/cuda/flashfftconv/csrc/conv1d/conv1d_bwd_cuda_bhl.cu ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2023 Dan Fu, Hermann Kumbong
2
+ #include "shared.h"
3
+
4
+ const uint BX = 128;
5
+ const uint BY = 1;
6
+ const uint BZ = 1;
7
+
8
+ const uint TILE_SIZE = 4;
9
+
10
+ template <typename input_t, typename weight_t>
11
+ __global__ void conv1d_backward_kernel(
12
+ const input_t* __restrict__ dout,
13
+ const input_t* __restrict__ u,
14
+ const weight_t* __restrict__ weights,
15
+ input_t* __restrict__ du,
16
+ input_t* __restrict__ dk,
17
+ uint B,
18
+ uint L,
19
+ uint D,
20
+ uint K,
21
+ uint P
22
+ )
23
+ {
24
+ const int b = blockIdx.z;
25
+ const int d = blockIdx.y;
26
+ const int l = blockIdx.x;
27
+
28
+ //construct the du matrix
29
+ if(b < B && d < D && l == 0){
30
+ for(int j = threadIdx.x; j < L; j += blockDim.x)
31
+ {
32
+ input_t sum;
33
+ set_value(&sum, 0.0f);
34
+ input_t weight;
35
+
36
+ for(int k = 0; k < K ; k++)
37
+ {
38
+ int idx = - P + k + j;
39
+
40
+ if(idx >= 0 && idx < L){
41
+ set_value(&weight, weights[d * K + K - (k +1)]);
42
+ sum = __hfma(dout[b * D * L + d * L + idx], weight, sum);
43
+ }
44
+ }
45
+ du[b * D * L + d * L + j] = sum;
46
+ }
47
+ }
48
+
49
+ const int k = blockIdx.x;
50
+ input_t tmp;
51
+ //construct the dk matrix
52
+ if(b < B && d < D && k < K)
53
+ {
54
+ for(int j = threadIdx.x; j < L; j += blockDim.x)
55
+ {
56
+ if(k - P + j < 0 || k - P + j >= L){
57
+ set_value(&dk[b * D * K * L + d * K * L + k * L + j], 0.0f);
58
+
59
+ }else{
60
+ set_value(&dk[b * D * K * L + d * K * L + k * L + j], u[b * D * L + d * L + k - P + j]);
61
+ }
62
+ }
63
+ }
64
+
65
+ }
66
+
67
+ std::vector<torch::Tensor> conv1d_backward_bhl_cuda(
68
+ torch::Tensor dout,
69
+ torch::Tensor u,
70
+ torch::Tensor weight,
71
+ torch::Tensor bias,
72
+ uint padding)
73
+ {
74
+ const uint b = u.size(0);
75
+ const uint d = u.size(1);
76
+ const uint l = u.size(2);
77
+
78
+ const uint k = weight.squeeze().size(1);
79
+
80
+ dim3 blockDims(BX, 1, 1);
81
+
82
+ dim3 gridDims(l, d, b);
83
+
84
+ torch::Tensor du = torch::empty({b, d, l}, u.options());
85
+ torch::Tensor dk = torch::empty({b, d, k, l}, dout.options());
86
+ torch::Tensor dbias = dout.sum(-1).sum(0);
87
+
88
+ DISPATCH_FLOAT_AND_HALF_AND_BF16(dout.scalar_type(), weight.scalar_type(),
89
+ "depthwise conv 1d backward bhl",
90
+ ([&]
91
+ { conv1d_backward_kernel<input_t, weight_t><<<gridDims, blockDims>>>(
92
+ static_cast<input_t *>(dout.data_ptr()),
93
+ static_cast<input_t *>(u.data_ptr()),
94
+ static_cast<weight_t *>(weight.data_ptr()),
95
+ static_cast<input_t *>(du.data_ptr()),
96
+ static_cast<input_t *>(dk.data_ptr()),
97
+ b,
98
+ l,
99
+ d,
100
+ k,
101
+ padding);
102
+ }
103
+ )
104
+ );
105
+ return {du, torch::matmul(dk, dout.unsqueeze(-1)).squeeze(-1).sum(0).to(weight.type()), dbias};
106
+ }