Update Feather a10g-large training runtime image
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .dockerignore +16 -0
- .guardian_trigger_20260512_211050 +1 -0
- .rebuild_sentry +1 -0
- FORCE_REBUILD +3 -0
- README.md +6 -5
- REBUILD_FLAG_1778645488 +0 -0
- entrypoint.py +1 -20
- overlay/.dockerignore +20 -0
- overlay/BUILD_STAMP +1 -0
- overlay/harness/benchmark_validity.py +210 -0
- overlay/harness/tps_manifest_validity.py +209 -0
- overlay/htm_rust/.cargo/config.toml +2 -0
- overlay/htm_rust/.claude/CLAUDE.md +0 -0
- overlay/htm_rust/.letta/claude/conversations.json +6 -0
- overlay/htm_rust/.letta/claude/session-c892b9c9-7fe5-4f14-8157-ec8740e965d1.json +0 -0
- overlay/htm_rust/Cargo.lock +42 -0
- overlay/htm_rust/Cargo.toml +3 -1
- overlay/htm_rust/DLB_PERKS_IMPLEMENTATION_PLAN.md +194 -0
- overlay/htm_rust/bench_gpu.py +81 -0
- overlay/htm_rust/docs/GPU_HTM.md +302 -0
- overlay/htm_rust/src/gpu/fused.rs +58 -10
- overlay/htm_rust/src/gpu/mod.rs +134 -1
- overlay/htm_rust/src/lib.rs +27 -0
- overlay/htm_rust/src/region.rs +2 -0
- overlay/htm_rust/src/sp.rs +5 -1
- overlay/htm_rust/src/tm.rs +6 -2
- overlay/htm_rust/uv.lock +8 -0
- overlay/hydra/model.py +96 -8
- overlay/hydra/optimizer.py +118 -44
- overlay/hydra/training.py +66 -25
- overlay/kernels/__init__.py +0 -0
- overlay/kernels/cuda/decode_kernels.cu +10 -0
- overlay/kernels/cuda/flashfftconv/LICENSE +201 -0
- overlay/kernels/cuda/flashfftconv/README.md +57 -0
- overlay/kernels/cuda/flashfftconv/UPSTREAM_COMMIT +1 -0
- overlay/kernels/cuda/flashfftconv/csrc/.gitignore +10 -0
- overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly.h +374 -0
- overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_cuda.cu +699 -0
- overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_cuda_bf16.cu +725 -0
- overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_ifft_cuda.cu +723 -0
- overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_ifft_cuda_bf16.cu +705 -0
- overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_padded_cuda.cu +871 -0
- overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_padded_cuda_bf16.cu +897 -0
- overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_padded_ifft_cuda.cu +905 -0
- overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_padded_ifft_cuda_bf16.cu +917 -0
- overlay/kernels/cuda/flashfftconv/csrc/butterfly/shared.h +60 -0
- overlay/kernels/cuda/flashfftconv/csrc/conv1d/conv1d.h +96 -0
- overlay/kernels/cuda/flashfftconv/csrc/conv1d/conv1d_bhl.cu +132 -0
- overlay/kernels/cuda/flashfftconv/csrc/conv1d/conv1d_blh.cu +202 -0
- overlay/kernels/cuda/flashfftconv/csrc/conv1d/conv1d_bwd_cuda_bhl.cu +106 -0
.dockerignore
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Keep HF runtime image context deterministic and small.
|
| 2 |
+
**/__pycache__/
|
| 3 |
+
**/*.py[cod]
|
| 4 |
+
**/.pytest_cache/
|
| 5 |
+
**/.mypy_cache/
|
| 6 |
+
**/.ruff_cache/
|
| 7 |
+
**/.venv/
|
| 8 |
+
**/target/
|
| 9 |
+
**/logs/
|
| 10 |
+
**/*.log
|
| 11 |
+
**/*.out
|
| 12 |
+
**/*.pt
|
| 13 |
+
**/*.safetensors
|
| 14 |
+
**/*.parquet
|
| 15 |
+
**/*.npz
|
| 16 |
+
**/.git/
|
.guardian_trigger_20260512_211050
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Guardian forced rebuild at 2026-05-12T21:10:50.366196
|
.rebuild_sentry
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
FORCE_REBUILD_e9883655-cf86-4724-84bd-68740a3feefb
|
FORCE_REBUILD
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FORCE_SPACE_REBUILD=$(date -u +%s)
|
| 2 |
+
# This flag forces the Space image to rebuild with the latest overlay code
|
| 3 |
+
# containing the retina_contrastive fix
|
README.md
CHANGED
|
@@ -1,10 +1,11 @@
|
|
| 1 |
---
|
| 2 |
-
title: Feather
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
|
|
|
| 7 |
pinned: false
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Feather H200 Runtime Slim
|
| 3 |
+
emoji: 📚
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
pinned: false
|
| 9 |
---
|
| 10 |
|
| 11 |
+
Feather runtime image used as a Docker Space source for Hugging Face Jobs.
|
REBUILD_FLAG_1778645488
ADDED
|
File without changes
|
entrypoint.py
CHANGED
|
@@ -217,25 +217,6 @@ def _run_training_subprocess(cmd: list[str]) -> int:
|
|
| 217 |
def run_job_mode() -> int:
|
| 218 |
os.chdir(REPO_ROOT)
|
| 219 |
|
| 220 |
-
# Guardian: force contrastive_rank=0 and disk-patch sdr_semantic.py
|
| 221 |
-
os.environ["HYDRA_CONTRASTIVE_RANK"] = "0"
|
| 222 |
-
_sdr_path = REPO_ROOT / 'subsystems' / 'sdr_semantic.py'
|
| 223 |
-
if _sdr_path.exists():
|
| 224 |
-
_text = _sdr_path.read_text()
|
| 225 |
-
if 'retina_contrastive' not in _text:
|
| 226 |
-
print('[guardian] patching sdr_semantic.py on disk ...', flush=True)
|
| 227 |
-
_text = _text.replace(
|
| 228 |
-
'super().__init__()\n' +
|
| 229 |
-
' # Audit 2026-05-13: allow disabling',
|
| 230 |
-
'super().__init__()\n' +
|
| 231 |
-
' self.retina_contrastive = None # guardian patch\n' +
|
| 232 |
-
' # Audit 2026-05-13: allow disabling',
|
| 233 |
-
)
|
| 234 |
-
_sdr_path.write_text(_text)
|
| 235 |
-
print('[guardian] patched sdr_semantic.py on disk', flush=True)
|
| 236 |
-
print('[guardian] HYDRA_CONTRASTIVE_RANK=0 enforced for checkpoint compat', flush=True)
|
| 237 |
-
|
| 238 |
-
|
| 239 |
# Dynamic live patch from GitHub to bypass Space build errors
|
| 240 |
GIT_REF = os.environ.get('FEATHER_GIT_REF')
|
| 241 |
if GIT_REF:
|
|
@@ -307,4 +288,4 @@ def main() -> int:
|
|
| 307 |
|
| 308 |
|
| 309 |
if __name__ == '__main__':
|
| 310 |
-
raise SystemExit(main())
|
|
|
|
| 217 |
def run_job_mode() -> int:
|
| 218 |
os.chdir(REPO_ROOT)
|
| 219 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
# Dynamic live patch from GitHub to bypass Space build errors
|
| 221 |
GIT_REF = os.environ.get('FEATHER_GIT_REF')
|
| 222 |
if GIT_REF:
|
|
|
|
| 288 |
|
| 289 |
|
| 290 |
if __name__ == '__main__':
|
| 291 |
+
raise SystemExit(main())
|
overlay/.dockerignore
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.git
|
| 2 |
+
.github
|
| 3 |
+
.venv
|
| 4 |
+
.remember
|
| 5 |
+
.letta
|
| 6 |
+
.claude
|
| 7 |
+
__pycache__
|
| 8 |
+
*.pyc
|
| 9 |
+
*.pyo
|
| 10 |
+
*.pyd
|
| 11 |
+
*.log
|
| 12 |
+
run_*.log
|
| 13 |
+
run*.log
|
| 14 |
+
*.txt
|
| 15 |
+
WORKER_COMPLETE
|
| 16 |
+
autoresearch_loop.log
|
| 17 |
+
data/
|
| 18 |
+
state_store/
|
| 19 |
+
htm_rust/target/
|
| 20 |
+
hydra-core/target/
|
overlay/BUILD_STAMP
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
1778646814_120314
|
overlay/harness/benchmark_validity.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Benchmark validity and comparable-group helpers for HYDRA scorecards.
|
| 2 |
+
|
| 3 |
+
This module deliberately separates benchmark validity from model quality. A run
|
| 4 |
+
can be useful diagnostic evidence while still being invalid for promotion if its
|
| 5 |
+
corpus or eval protocol differs from the baseline.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import hashlib
|
| 11 |
+
import json
|
| 12 |
+
from copy import deepcopy
|
| 13 |
+
from typing import Any
|
| 14 |
+
|
| 15 |
+
PUBLIC_FULL_BLEND_ID = "public_full_blend_v0"
|
| 16 |
+
PUBLIC_FULL_BLEND_WEIGHTS = {
|
| 17 |
+
"fineweb-edu": 0.55,
|
| 18 |
+
"wikipedia": 0.25,
|
| 19 |
+
"cosmopedia": 0.15,
|
| 20 |
+
"fineweb": 0.05,
|
| 21 |
+
}
|
| 22 |
+
GATED_OR_PRIVATE_MARKERS = (
|
| 23 |
+
"stack-v2",
|
| 24 |
+
"nemotron-math",
|
| 25 |
+
"nemotron-specialized",
|
| 26 |
+
"nvidia/nemotron",
|
| 27 |
+
"Nemotron-CC-Math",
|
| 28 |
+
"Nemotron-Pretraining-Specialized",
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _text_blob(row: dict[str, Any]) -> str:
|
| 33 |
+
return json.dumps(row, sort_keys=True, default=str)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _ablation(row: dict[str, Any]) -> dict[str, Any]:
|
| 37 |
+
ablation = row.get("ablation")
|
| 38 |
+
return ablation if isinstance(ablation, dict) else {}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _has_public_full_blend(row: dict[str, Any]) -> bool:
|
| 42 |
+
ablation = _ablation(row)
|
| 43 |
+
corpus_profile = str(row.get("corpus_profile") or "").lower()
|
| 44 |
+
corpus_standard = str(ablation.get("corpus_standard") or row.get("corpus_standard") or "").lower()
|
| 45 |
+
notes = str(row.get("notes") or "").lower()
|
| 46 |
+
blend_weights = row.get("full_blend_weights")
|
| 47 |
+
single_config = str(
|
| 48 |
+
ablation.get("HYDRA_NEMOTRON_SINGLE_CONFIG")
|
| 49 |
+
or row.get("HYDRA_NEMOTRON_SINGLE_CONFIG")
|
| 50 |
+
or ""
|
| 51 |
+
).strip().lower()
|
| 52 |
+
|
| 53 |
+
has_full_blend_marker = (
|
| 54 |
+
row.get("HYDRA_USE_FULL_BLEND") == "1"
|
| 55 |
+
or row.get("HYDRA_USE_FULL_BLEND") == 1
|
| 56 |
+
or row.get("HYDRA_USE_FULL_BLEND") is True
|
| 57 |
+
or "hydra_use_full_blend=1" in corpus_standard
|
| 58 |
+
or corpus_profile == PUBLIC_FULL_BLEND_ID
|
| 59 |
+
or blend_weights == PUBLIC_FULL_BLEND_WEIGHTS
|
| 60 |
+
or "public benchmark blend" in corpus_standard
|
| 61 |
+
or "public full-blend" in notes
|
| 62 |
+
or "full-blend eval settings" in notes
|
| 63 |
+
)
|
| 64 |
+
single_config_is_blank = single_config in {"", "<unset>", "none", "null"}
|
| 65 |
+
return bool(has_full_blend_marker and single_config_is_blank)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _uses_private_or_gated_corpus(row: dict[str, Any]) -> bool:
|
| 69 |
+
blob = _text_blob(row).lower()
|
| 70 |
+
return any(marker.lower() in blob for marker in GATED_OR_PRIVATE_MARKERS)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _eval_tokens(row: dict[str, Any]) -> int | None:
|
| 74 |
+
raw = row.get("eval_tokens")
|
| 75 |
+
if raw in (None, ""):
|
| 76 |
+
return None
|
| 77 |
+
try:
|
| 78 |
+
return int(raw)
|
| 79 |
+
except (TypeError, ValueError):
|
| 80 |
+
return None
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _eval_batch(row: dict[str, Any]) -> int | None:
|
| 84 |
+
raw = row.get("eval_batch", 1)
|
| 85 |
+
if raw in (None, ""):
|
| 86 |
+
return None
|
| 87 |
+
try:
|
| 88 |
+
return int(raw)
|
| 89 |
+
except (TypeError, ValueError):
|
| 90 |
+
return None
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _eval_protocol(row: dict[str, Any]) -> str:
|
| 94 |
+
val_source = str(row.get("val_source") or "").lower()
|
| 95 |
+
row_type = str(row.get("type") or "").lower()
|
| 96 |
+
if "fresh_checkpoint_eval" in val_source or "fresh_checkpoint_eval" in row_type:
|
| 97 |
+
return "fresh_checkpoint_eval"
|
| 98 |
+
if "in_process" in val_source or "in_process" in row_type:
|
| 99 |
+
return "in_process_eval"
|
| 100 |
+
return val_source or row_type or "unknown_eval"
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def _gpu_flavor(row: dict[str, Any]) -> str:
|
| 104 |
+
return str(row.get("gpu_flavor") or row.get("FEATHER_HF_FLAVOR") or "a10g-large").lower()
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def _runtime_profile(row: dict[str, Any]) -> str:
|
| 108 |
+
return str(
|
| 109 |
+
row.get("runtime_profile")
|
| 110 |
+
or row.get("FEATHER_HF_RUNTIME_PROFILE")
|
| 111 |
+
or "a10-compromise-telemetry"
|
| 112 |
+
).lower()
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def benchmark_invalid_reason(row: dict[str, Any]) -> str:
|
| 116 |
+
"""Return an empty string when a row is benchmark-valid."""
|
| 117 |
+
if row.get("crashed") is True:
|
| 118 |
+
return "run crashed"
|
| 119 |
+
if row.get("metrics_write_failed") is True and row.get("val_bpb") in (None, 0, 0.0):
|
| 120 |
+
return "metrics missing or failed"
|
| 121 |
+
val_bpb = row.get("val_bpb")
|
| 122 |
+
try:
|
| 123 |
+
if val_bpb is None or float(val_bpb) <= 0:
|
| 124 |
+
return "missing positive val_bpb"
|
| 125 |
+
except (TypeError, ValueError):
|
| 126 |
+
return "missing positive val_bpb"
|
| 127 |
+
if not _has_public_full_blend(row):
|
| 128 |
+
return "not public full blend / full blend invariant missing"
|
| 129 |
+
if _uses_private_or_gated_corpus(row):
|
| 130 |
+
return "uses private/gated corpus marker"
|
| 131 |
+
if _eval_tokens(row) is None:
|
| 132 |
+
return "missing eval_tokens"
|
| 133 |
+
if _eval_batch(row) is None:
|
| 134 |
+
return "missing eval_batch"
|
| 135 |
+
if _eval_protocol(row) != "fresh_checkpoint_eval":
|
| 136 |
+
return "not fresh checkpoint eval"
|
| 137 |
+
return ""
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def comparable_group_id(row: dict[str, Any]) -> str:
|
| 141 |
+
"""Build a stable comparable-group identifier from protocol fields only.
|
| 142 |
+
|
| 143 |
+
Deliberately excludes checkpoint/model/ablation identities so architecture
|
| 144 |
+
variants can be compared when corpus and eval protocol match.
|
| 145 |
+
"""
|
| 146 |
+
parts = {
|
| 147 |
+
"corpus": PUBLIC_FULL_BLEND_ID if _has_public_full_blend(row) else "non_public_or_unknown_corpus",
|
| 148 |
+
"eval_protocol": _eval_protocol(row),
|
| 149 |
+
"eval_tokens": _eval_tokens(row),
|
| 150 |
+
"eval_batch": _eval_batch(row),
|
| 151 |
+
"gpu_flavor": _gpu_flavor(row),
|
| 152 |
+
"runtime_profile": _runtime_profile(row),
|
| 153 |
+
}
|
| 154 |
+
digest = hashlib.sha1(json.dumps(parts, sort_keys=True).encode()).hexdigest()[:10]
|
| 155 |
+
return "cmp_" + digest
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def normalize_scorecard_row(row: dict[str, Any]) -> dict[str, Any]:
|
| 159 |
+
"""Return a row copy annotated with v0 benchmark validity metadata."""
|
| 160 |
+
normalized = deepcopy(row)
|
| 161 |
+
invalid_reason = benchmark_invalid_reason(normalized)
|
| 162 |
+
normalized["benchmark_valid"] = not invalid_reason
|
| 163 |
+
normalized["benchmark_status"] = "comparable" if not invalid_reason else "diagnostic"
|
| 164 |
+
normalized["invalid_reason"] = invalid_reason
|
| 165 |
+
normalized["corpus_profile"] = PUBLIC_FULL_BLEND_ID if _has_public_full_blend(normalized) else "non_public_or_unknown"
|
| 166 |
+
normalized["full_blend_weights"] = PUBLIC_FULL_BLEND_WEIGHTS if _has_public_full_blend(normalized) else None
|
| 167 |
+
normalized["eval_tokens"] = _eval_tokens(normalized)
|
| 168 |
+
normalized["eval_batch"] = _eval_batch(normalized)
|
| 169 |
+
normalized["eval_protocol"] = _eval_protocol(normalized)
|
| 170 |
+
normalized["gpu_flavor"] = _gpu_flavor(normalized)
|
| 171 |
+
normalized["runtime_profile"] = _runtime_profile(normalized)
|
| 172 |
+
normalized["comparable_group_id"] = comparable_group_id(normalized)
|
| 173 |
+
return normalized
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def are_comparable(left: dict[str, Any], right: dict[str, Any]) -> bool:
|
| 177 |
+
left_n = normalize_scorecard_row(left)
|
| 178 |
+
right_n = normalize_scorecard_row(right)
|
| 179 |
+
return bool(
|
| 180 |
+
left_n["benchmark_valid"]
|
| 181 |
+
and right_n["benchmark_valid"]
|
| 182 |
+
and left_n["comparable_group_id"] == right_n["comparable_group_id"]
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def compare_candidate(candidate: dict[str, Any], baseline: dict[str, Any]) -> dict[str, Any]:
|
| 187 |
+
"""Compare two scorecard rows with validity-first promotion semantics."""
|
| 188 |
+
candidate_n = normalize_scorecard_row(candidate)
|
| 189 |
+
baseline_n = normalize_scorecard_row(baseline)
|
| 190 |
+
if not candidate_n["benchmark_valid"]:
|
| 191 |
+
return {"decision": "invalid_candidate", "reason": candidate_n["invalid_reason"]}
|
| 192 |
+
if not baseline_n["benchmark_valid"]:
|
| 193 |
+
return {"decision": "invalid_baseline", "reason": baseline_n["invalid_reason"]}
|
| 194 |
+
if candidate_n["comparable_group_id"] != baseline_n["comparable_group_id"]:
|
| 195 |
+
return {
|
| 196 |
+
"decision": "not_comparable",
|
| 197 |
+
"reason": (
|
| 198 |
+
"comparable_group_id mismatch: "
|
| 199 |
+
f"candidate={candidate_n['comparable_group_id']} "
|
| 200 |
+
f"baseline={baseline_n['comparable_group_id']}"
|
| 201 |
+
),
|
| 202 |
+
}
|
| 203 |
+
delta_bpb = float(candidate_n["val_bpb"]) - float(baseline_n["val_bpb"])
|
| 204 |
+
if delta_bpb < 0:
|
| 205 |
+
decision = "promote_candidate"
|
| 206 |
+
elif delta_bpb > 0:
|
| 207 |
+
decision = "keep_baseline"
|
| 208 |
+
else:
|
| 209 |
+
decision = "tie_requires_replication"
|
| 210 |
+
return {"decision": decision, "delta_bpb": delta_bpb, "reason": "same comparable_group_id"}
|
overlay/harness/tps_manifest_validity.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""TPS/profiling manifest validity helpers for Feather kernel-fusion sweeps.
|
| 2 |
+
|
| 3 |
+
This module is the TPS-side sibling of ``harness.benchmark_validity``. It does
|
| 4 |
+
not decide model quality; it decides whether a row is valid evidence for max-TPS
|
| 5 |
+
promotion versus attribution/diagnostic evidence. The rules are intentionally
|
| 6 |
+
conservative because profiling flags and CPU fallbacks can make fast-looking rows
|
| 7 |
+
incomparable or unfaithful.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
from copy import deepcopy
|
| 13 |
+
from typing import Any
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
A10_FLAVORS = {"a10g-small", "a10g-large", "a10g-largex2", "a10g-largex4"}
|
| 17 |
+
PROFILE_TRUE = {"1", "true", "yes", "on"}
|
| 18 |
+
PROFILE_FALSE = {"0", "false", "no", "off", ""}
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _as_bool(value: Any, *, default: bool = False) -> bool:
|
| 22 |
+
if isinstance(value, bool):
|
| 23 |
+
return value
|
| 24 |
+
if value is None:
|
| 25 |
+
return default
|
| 26 |
+
text = str(value).strip().lower()
|
| 27 |
+
if text in PROFILE_TRUE:
|
| 28 |
+
return True
|
| 29 |
+
if text in PROFILE_FALSE:
|
| 30 |
+
return False
|
| 31 |
+
return default
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _int_or_none(value: Any) -> int | None:
|
| 35 |
+
if value in (None, ""):
|
| 36 |
+
return None
|
| 37 |
+
try:
|
| 38 |
+
return int(value)
|
| 39 |
+
except (TypeError, ValueError):
|
| 40 |
+
return None
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _float_or_none(value: Any) -> float | None:
|
| 44 |
+
if value in (None, ""):
|
| 45 |
+
return None
|
| 46 |
+
try:
|
| 47 |
+
return float(value)
|
| 48 |
+
except (TypeError, ValueError):
|
| 49 |
+
return None
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _nested(row: dict[str, Any], key: str) -> dict[str, Any]:
|
| 53 |
+
value = row.get(key)
|
| 54 |
+
return value if isinstance(value, dict) else {}
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _env(row: dict[str, Any]) -> dict[str, Any]:
|
| 58 |
+
return _nested(row, "env")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _receipts(row: dict[str, Any]) -> dict[str, Any]:
|
| 62 |
+
return _nested(row, "receipts") or _nested(row, "receipts_required")
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _hardware(row: dict[str, Any]) -> dict[str, Any]:
|
| 66 |
+
return _nested(row, "hardware")
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def _profile_forward_enabled(row: dict[str, Any]) -> bool:
|
| 70 |
+
env = _env(row)
|
| 71 |
+
receipts = _receipts(row)
|
| 72 |
+
if "profile_forward" in receipts:
|
| 73 |
+
return _as_bool(receipts.get("profile_forward"))
|
| 74 |
+
return _as_bool(env.get("HYDRA_PROFILE_FORWARD"))
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _tps_window(row: dict[str, Any]) -> dict[str, Any]:
|
| 78 |
+
receipts = _receipts(row)
|
| 79 |
+
window = receipts.get("training_tps_window") or row.get("training_tps_window") or row.get("tps_window")
|
| 80 |
+
return window if isinstance(window, dict) else {}
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _median_tps(row: dict[str, Any]) -> float | None:
|
| 84 |
+
window = _tps_window(row)
|
| 85 |
+
return _float_or_none(window.get("median") or row.get("median_tps") or row.get("tps"))
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _flavor(row: dict[str, Any]) -> str:
|
| 89 |
+
hardware = _hardware(row)
|
| 90 |
+
receipts = _receipts(row)
|
| 91 |
+
return str(
|
| 92 |
+
hardware.get("flavor")
|
| 93 |
+
or receipts.get("flavor_verified")
|
| 94 |
+
or row.get("gpu_flavor")
|
| 95 |
+
or row.get("FEATHER_HF_FLAVOR")
|
| 96 |
+
or ""
|
| 97 |
+
).strip().lower()
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _duplicate_count(row: dict[str, Any]) -> int | None:
|
| 101 |
+
check = row.get("duplicate_active_job_check")
|
| 102 |
+
if not isinstance(check, dict):
|
| 103 |
+
return None
|
| 104 |
+
return _int_or_none(check.get("active_matching_jobs"))
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def _scale_free_a10g_invalid_reasons(row: dict[str, Any]) -> list[str]:
|
| 108 |
+
"""Return fail-closed reasons for bounded A10G scale-free HTM proof rows."""
|
| 109 |
+
env = _env(row)
|
| 110 |
+
reasons: list[str] = []
|
| 111 |
+
if _flavor(row) not in A10_FLAVORS:
|
| 112 |
+
return reasons
|
| 113 |
+
proof_requested = (
|
| 114 |
+
_as_bool(env.get("HYDRA_HTM_STRICT_SCALE_FREE"), default=False)
|
| 115 |
+
or str(row.get("runtime_profile") or "").strip().lower() in {"optimal-strict", "a10g-scale-free-proof"}
|
| 116 |
+
)
|
| 117 |
+
if not proof_requested:
|
| 118 |
+
return reasons
|
| 119 |
+
|
| 120 |
+
if env.get("HYDRA_TARGET_SHARDS") not in {"0", 0}:
|
| 121 |
+
reasons.append("scale-free A10G proof requires HYDRA_TARGET_SHARDS=0")
|
| 122 |
+
if env.get("HYDRA_HTM_STRICT_SCALE_FREE") != "1":
|
| 123 |
+
reasons.append("scale-free A10G proof requires HYDRA_HTM_STRICT_SCALE_FREE=1")
|
| 124 |
+
region_pool = _int_or_none(env.get("HYDRA_HTM_REGION_POOL_SIZE"))
|
| 125 |
+
chunk_b = _int_or_none(env.get("HYDRA_HTM_CHUNK_B"))
|
| 126 |
+
if region_pool is None:
|
| 127 |
+
reasons.append("scale-free A10G proof requires HYDRA_HTM_REGION_POOL_SIZE")
|
| 128 |
+
elif region_pool > 4:
|
| 129 |
+
reasons.append("scale-free A10G proof requires HYDRA_HTM_REGION_POOL_SIZE<=4")
|
| 130 |
+
if chunk_b is None:
|
| 131 |
+
reasons.append("scale-free A10G proof requires HYDRA_HTM_CHUNK_B")
|
| 132 |
+
elif region_pool is not None and chunk_b > region_pool:
|
| 133 |
+
reasons.append("scale-free A10G proof requires HYDRA_HTM_CHUNK_B<=HYDRA_HTM_REGION_POOL_SIZE")
|
| 134 |
+
if env.get("HYDRA_TOKEN_CACHE_GB") not in {"0", 0}:
|
| 135 |
+
reasons.append("scale-free A10G proof requires HYDRA_TOKEN_CACHE_GB=0")
|
| 136 |
+
if env.get("HYDRA_DISABLE_TOKEN_CACHE") != "1":
|
| 137 |
+
reasons.append("scale-free A10G proof requires HYDRA_DISABLE_TOKEN_CACHE=1")
|
| 138 |
+
for key in (
|
| 139 |
+
"HYDRA_HTM_REGION_POOL_SIZE_FROM_VRAM",
|
| 140 |
+
"HYDRA_HTM_SCALE_TO_VRAM",
|
| 141 |
+
"HYDRA_VRAM_TOPOLOGY_SCALE",
|
| 142 |
+
"FEATHER_VRAM_TOPOLOGY_SCALE",
|
| 143 |
+
):
|
| 144 |
+
if _as_bool(env.get(key), default=False):
|
| 145 |
+
reasons.append(f"scale-free A10G proof forbids VRAM-derived topology scaling: {key}")
|
| 146 |
+
return reasons
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def tps_manifest_invalid_reasons(row: dict[str, Any]) -> list[str]:
|
| 150 |
+
"""Return all reasons a row cannot be used as max-TPS promotion evidence."""
|
| 151 |
+
reasons: list[str] = []
|
| 152 |
+
env = _env(row)
|
| 153 |
+
receipts = _receipts(row)
|
| 154 |
+
flavor = _flavor(row)
|
| 155 |
+
|
| 156 |
+
if row.get("crashed") is True:
|
| 157 |
+
reasons.append("run crashed")
|
| 158 |
+
if flavor not in A10_FLAVORS:
|
| 159 |
+
reasons.append(f"not A10G flavor: {flavor or 'missing'}")
|
| 160 |
+
if _profile_forward_enabled(row):
|
| 161 |
+
reasons.append("profile_forward enabled; attribution-only overhead row")
|
| 162 |
+
if _median_tps(row) is None:
|
| 163 |
+
reasons.append("missing training TPS window median")
|
| 164 |
+
duplicate_count = _duplicate_count(row)
|
| 165 |
+
if duplicate_count is None:
|
| 166 |
+
reasons.append("duplicate active job check missing")
|
| 167 |
+
elif duplicate_count > 0:
|
| 168 |
+
reasons.append(f"duplicate active Feather A10G jobs present: {duplicate_count}")
|
| 169 |
+
|
| 170 |
+
faithful_profile = "faithful" in str(row.get("runtime_profile") or "").lower()
|
| 171 |
+
htm_gpu_verified = _as_bool(receipts.get("htm_gpu_verified"), default=False)
|
| 172 |
+
force_htm_cpu = _as_bool(env.get("HYDRA_FORCE_HTM_CPU"), default=False)
|
| 173 |
+
if faithful_profile and (force_htm_cpu or not htm_gpu_verified):
|
| 174 |
+
reasons.append("faithful row lacks HTM GPU verification or uses CPU fallback")
|
| 175 |
+
if faithful_profile and env.get("HYDRA_HTM_FUSED") != "1":
|
| 176 |
+
reasons.append("faithful row missing HYDRA_HTM_FUSED=1")
|
| 177 |
+
if faithful_profile and env.get("HYDRA_HTM_BATCHED_FUSED") != "1":
|
| 178 |
+
reasons.append("faithful row missing HYDRA_HTM_BATCHED_FUSED=1")
|
| 179 |
+
if _as_bool(env.get("HYDRA_USE_NEMOTRON"), default=False) and env.get("HYDRA_TARGET_SHARDS") not in {"0", 0}:
|
| 180 |
+
reasons.append("Nemotron streaming TPS row must use HYDRA_TARGET_SHARDS=0")
|
| 181 |
+
if env.get("HYDRA_TOKEN_CACHE_GB") not in {"0", 0, None}:
|
| 182 |
+
reasons.append("token cache enabled/materializing during TPS row")
|
| 183 |
+
reasons.extend(_scale_free_a10g_invalid_reasons(row))
|
| 184 |
+
return reasons
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def tps_manifest_invalid_reason(row: dict[str, Any]) -> str:
|
| 188 |
+
return "; ".join(tps_manifest_invalid_reasons(row))
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def normalize_tps_manifest(row: dict[str, Any]) -> dict[str, Any]:
|
| 192 |
+
"""Return a copy annotated with TPS/profiling validity metadata."""
|
| 193 |
+
normalized = deepcopy(row)
|
| 194 |
+
reasons = tps_manifest_invalid_reasons(normalized)
|
| 195 |
+
profile_forward = _profile_forward_enabled(normalized)
|
| 196 |
+
normalized["tps_valid"] = not reasons
|
| 197 |
+
if not reasons:
|
| 198 |
+
status = "promotion_candidate"
|
| 199 |
+
elif profile_forward or str(normalized.get("metric_role") or "").lower() == "profile":
|
| 200 |
+
status = "attribution_only"
|
| 201 |
+
else:
|
| 202 |
+
status = "diagnostic"
|
| 203 |
+
normalized["tps_status"] = status
|
| 204 |
+
normalized["invalid_reason"] = "; ".join(reasons)
|
| 205 |
+
normalized["gpu_flavor"] = _flavor(normalized)
|
| 206 |
+
normalized["median_tps"] = _median_tps(normalized)
|
| 207 |
+
normalized["profile_forward"] = profile_forward
|
| 208 |
+
normalized["duplicate_active_job_count"] = _duplicate_count(normalized)
|
| 209 |
+
return normalized
|
overlay/htm_rust/.cargo/config.toml
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[target.x86_64-unknown-linux-gnu]
|
| 2 |
+
linker = "/usr/bin/cc"
|
overlay/htm_rust/.claude/CLAUDE.md
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
overlay/htm_rust/.letta/claude/conversations.json
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"c892b9c9-7fe5-4f14-8157-ec8740e965d1": {
|
| 3 |
+
"conversationId": "conv-b42ddc79-3745-4edf-b165-4281a8961d3b",
|
| 4 |
+
"agentId": "agent-2cc00bdf-45f5-4725-bb56-7b4ab142153e"
|
| 5 |
+
}
|
| 6 |
+
}
|
overlay/htm_rust/.letta/claude/session-c892b9c9-7fe5-4f14-8157-ec8740e965d1.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
overlay/htm_rust/Cargo.lock
CHANGED
|
@@ -8,6 +8,15 @@ version = "1.5.0"
|
|
| 8 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 9 |
checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
[[package]]
|
| 12 |
name = "cfg-if"
|
| 13 |
version = "1.0.4"
|
|
@@ -44,12 +53,14 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
|
|
| 44 |
name = "htm_rust"
|
| 45 |
version = "0.1.0"
|
| 46 |
dependencies = [
|
|
|
|
| 47 |
"cudarc",
|
| 48 |
"ndarray",
|
| 49 |
"numpy",
|
| 50 |
"pyo3",
|
| 51 |
"rand",
|
| 52 |
"rand_xoshiro",
|
|
|
|
| 53 |
]
|
| 54 |
|
| 55 |
[[package]]
|
|
@@ -301,6 +312,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
| 301 |
checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa"
|
| 302 |
dependencies = [
|
| 303 |
"rand_core",
|
|
|
|
| 304 |
]
|
| 305 |
|
| 306 |
[[package]]
|
|
@@ -321,6 +333,36 @@ version = "1.0.22"
|
|
| 321 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 322 |
checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
|
| 323 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
[[package]]
|
| 325 |
name = "syn"
|
| 326 |
version = "2.0.117"
|
|
|
|
| 8 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 9 |
checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
|
| 10 |
|
| 11 |
+
[[package]]
|
| 12 |
+
name = "bincode"
|
| 13 |
+
version = "1.3.3"
|
| 14 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 15 |
+
checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad"
|
| 16 |
+
dependencies = [
|
| 17 |
+
"serde",
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
[[package]]
|
| 21 |
name = "cfg-if"
|
| 22 |
version = "1.0.4"
|
|
|
|
| 53 |
name = "htm_rust"
|
| 54 |
version = "0.1.0"
|
| 55 |
dependencies = [
|
| 56 |
+
"bincode",
|
| 57 |
"cudarc",
|
| 58 |
"ndarray",
|
| 59 |
"numpy",
|
| 60 |
"pyo3",
|
| 61 |
"rand",
|
| 62 |
"rand_xoshiro",
|
| 63 |
+
"serde",
|
| 64 |
]
|
| 65 |
|
| 66 |
[[package]]
|
|
|
|
| 312 |
checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa"
|
| 313 |
dependencies = [
|
| 314 |
"rand_core",
|
| 315 |
+
"serde",
|
| 316 |
]
|
| 317 |
|
| 318 |
[[package]]
|
|
|
|
| 333 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 334 |
checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
|
| 335 |
|
| 336 |
+
[[package]]
|
| 337 |
+
name = "serde"
|
| 338 |
+
version = "1.0.228"
|
| 339 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 340 |
+
checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e"
|
| 341 |
+
dependencies = [
|
| 342 |
+
"serde_core",
|
| 343 |
+
"serde_derive",
|
| 344 |
+
]
|
| 345 |
+
|
| 346 |
+
[[package]]
|
| 347 |
+
name = "serde_core"
|
| 348 |
+
version = "1.0.228"
|
| 349 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 350 |
+
checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad"
|
| 351 |
+
dependencies = [
|
| 352 |
+
"serde_derive",
|
| 353 |
+
]
|
| 354 |
+
|
| 355 |
+
[[package]]
|
| 356 |
+
name = "serde_derive"
|
| 357 |
+
version = "1.0.228"
|
| 358 |
+
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 359 |
+
checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79"
|
| 360 |
+
dependencies = [
|
| 361 |
+
"proc-macro2",
|
| 362 |
+
"quote",
|
| 363 |
+
"syn",
|
| 364 |
+
]
|
| 365 |
+
|
| 366 |
[[package]]
|
| 367 |
name = "syn"
|
| 368 |
version = "2.0.117"
|
overlay/htm_rust/Cargo.toml
CHANGED
|
@@ -15,7 +15,9 @@ pyo3 = { version = "0.22", features = ["extension-module"] }
|
|
| 15 |
numpy = "0.22"
|
| 16 |
ndarray = "0.16"
|
| 17 |
rand = "0.8"
|
| 18 |
-
rand_xoshiro = "0.6"
|
|
|
|
|
|
|
| 19 |
# cudarc: CUDA Rust bindings with dynamic-loading (no link-time dep on libcuda).
|
| 20 |
# Kernels are embedded as PTX and JIT-compiled at runtime.
|
| 21 |
cudarc = { version = "0.12", default-features = false, features = ["dynamic-linking", "driver", "cuda-12010"], optional = true }
|
|
|
|
| 15 |
numpy = "0.22"
|
| 16 |
ndarray = "0.16"
|
| 17 |
rand = "0.8"
|
| 18 |
+
rand_xoshiro = { version = "0.6", features = ["serde1"] }
|
| 19 |
+
serde = { version = "1", features = ["derive"] }
|
| 20 |
+
bincode = "1.3"
|
| 21 |
# cudarc: CUDA Rust bindings with dynamic-loading (no link-time dep on libcuda).
|
| 22 |
# Kernels are embedded as PTX and JIT-compiled at runtime.
|
| 23 |
cudarc = { version = "0.12", default-features = false, features = ["dynamic-linking", "driver", "cuda-12010"], optional = true }
|
overlay/htm_rust/DLB_PERKS_IMPLEMENTATION_PLAN.md
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# HTM-on-H200 Performance Plan: Persistent Kernel + Hopper Cluster mbarrier
|
| 2 |
+
|
| 3 |
+
**Goal:** Drive HTM forward from 400ms → ~40-80ms (5-10×) → tps 38k → 200-400k
|
| 4 |
+
**Hardware:** NVIDIA H200, 132 SMs, sm_90a, CUDA 12.4+
|
| 5 |
+
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+
## The Real Bottleneck (established)
|
| 9 |
+
|
| 10 |
+
```
|
| 11 |
+
Current batched cooperative kernel (grid=(16,8,1)=128 blocks):
|
| 12 |
+
htm_launch = 400-440 ms ← hard wall
|
| 13 |
+
tps = 35-38 k
|
| 14 |
+
```
|
| 15 |
+
|
| 16 |
+
**Why we can't beat it with cooperative launch:**
|
| 17 |
+
- Cooperative kernels serialize at the device level (1 cooperative kernel at a time).
|
| 18 |
+
- H200 grid cap = 132 blocks (1 block/SM at block=1024). For B=8 regions batched: 16 blocks/region ceiling.
|
| 19 |
+
- Work × grid = constant: reshuffling blocks doesn't help.
|
| 20 |
+
|
| 21 |
+
**Why software DLB barrier made it worse (measured 650ms, 23k tps):**
|
| 22 |
+
- 128 blocks × 3 barriers/timestep × 2048 timesteps × ~5-10µs coordinator poll = ~300ms pure overhead.
|
| 23 |
+
- L2-contention tax (documented 20× slowdown on H200 vs 3060 for software atomic spin).
|
| 24 |
+
|
| 25 |
+
**The two paths that actually scale on H200 (per research):**
|
| 26 |
+
|
| 27 |
+
| Path | Pattern | Expected |
|
| 28 |
+
|------|---------|----------|
|
| 29 |
+
| **A** | PERKS-style persistent kernel + in-kernel turnstile | 1.3–1.8× = ~280-330 ms |
|
| 30 |
+
| **B** | Hopper Cluster mbarrier (hardware sync + TMA multicast) | 5–10× = ~40-80 ms |
|
| 31 |
+
|
| 32 |
+
Path B wins. It uses *hardware* primitives that match cooperative launch's speed while not being subject to the device-level serialization.
|
| 33 |
+
|
| 34 |
+
---
|
| 35 |
+
|
| 36 |
+
## Architecture: Cluster-Mapped HTM (Design 2 from research)
|
| 37 |
+
|
| 38 |
+
**Mapping:** Each of our 8 HTM regions → one Hopper Thread Block Cluster of 16 SMs
|
| 39 |
+
- Cluster size: 16 blocks (= current per-region grid_x)
|
| 40 |
+
- Total: 8 clusters × 16 SMs = 128 SMs used, 4 SMs spare
|
| 41 |
+
- Grid launch: `grid = (16, 8, 1)`, `cluster = (16, 1, 1)` — batched identically to today but with `CUDA_CLUSTER` launch attribute
|
| 42 |
+
|
| 43 |
+
**Per-cluster sync primitives (replace grid.sync()):**
|
| 44 |
+
|
| 45 |
+
1. **Intra-cluster barrier:** `cluster::sync()` — hardware-level, ~10-40 ns (vs software atomic ~100-500 ns)
|
| 46 |
+
2. **Cluster-distributed shared memory:** each SM in cluster can directly `cuda::memcpy_async` from another SM's smem
|
| 47 |
+
3. **TMA multicast (`cp.async.bulk.tensor ... multicast`):** one TMA descriptor propagates input SDRs / column activations to all 16 SMs in cluster in a single DMA
|
| 48 |
+
|
| 49 |
+
**Between clusters (8 regions):** independent — each region updates its own state and its own cluster's mbarriers. Multiple clusters run concurrently at hardware-scheduler level, bounded only by SM count (fits because 8 × 16 = 128 ≤ 132).
|
| 50 |
+
|
| 51 |
+
**Inside the kernel body:** T=2048 timesteps run in a persistent loop. Hot state (boost, active_duty, inhibition_threshold, cell_active/winner bitsets) stays in registers / cluster-shared smem across timesteps — no per-timestep DRAM round-trip.
|
| 52 |
+
|
| 53 |
+
---
|
| 54 |
+
|
| 55 |
+
## Task Plan (Detailed, Dependency-Ordered)
|
| 56 |
+
|
| 57 |
+
### Phase 1 — Feasibility & Setup (no GPU risk)
|
| 58 |
+
|
| 59 |
+
**T1. Cluster launch feasibility probe**
|
| 60 |
+
- Query `cuDeviceGetAttribute` for `CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR` and `CU_DEVICE_ATTRIBUTE_CLUSTER_LAUNCH`
|
| 61 |
+
- Verify H200 supports cluster launch with `cluster_size=16`
|
| 62 |
+
- Source: `cudarc::driver::result::launch_kernel_ex` with `CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION`
|
| 63 |
+
- Files: `htm_rust/src/gpu/fused.rs` — add probe at FusedState::new
|
| 64 |
+
|
| 65 |
+
**T2. Enable sm_90a PTX compilation + `--device-c` for rdc link**
|
| 66 |
+
- Current build.rs targets `sm_90`. Need `sm_90a` to access cluster intrinsics
|
| 67 |
+
- Add `-arch=sm_90a -rdc=true` to nvcc invocation
|
| 68 |
+
- Files: `htm_rust/build.rs`
|
| 69 |
+
|
| 70 |
+
**T3. Update cudarc version to 0.12 minimum**
|
| 71 |
+
- Current 0.12. Verify `result::launch_kernel_ex` and `CUkernelNodeAttrValue` are available
|
| 72 |
+
- If not, upgrade to latest 0.13+
|
| 73 |
+
- Files: `htm_rust/Cargo.toml`
|
| 74 |
+
|
| 75 |
+
### Phase 2 — Cluster mbarrier primitive (isolated, testable)
|
| 76 |
+
|
| 77 |
+
**T4. Rewrite `fused_grid_barrier` as cluster barrier**
|
| 78 |
+
- Replace my DLB software barrier + `cg::grid_group::sync()` with:
|
| 79 |
+
```cpp
|
| 80 |
+
namespace cg = cooperative_groups;
|
| 81 |
+
auto cluster = cg::this_cluster(); // sm_90a intrinsic
|
| 82 |
+
cluster.sync(); // hardware barrier
|
| 83 |
+
```
|
| 84 |
+
- No more `flags[]` array, no spin-wait, no `__nanosleep`
|
| 85 |
+
- Files: `htm_rust/src/gpu/kernels/htm_fused_step.cu:117-160`
|
| 86 |
+
- Reference: CUTLASS `include/cutlass/pipeline/sm90_pipeline.hpp`
|
| 87 |
+
|
| 88 |
+
**T5. Delete `barrier_counters` allocation + plumbing**
|
| 89 |
+
- No longer needed with cluster barrier
|
| 90 |
+
- Files: `htm_rust/src/gpu/fused.rs` — remove `barrier_counters` field, FusedPtrs field, alloc
|
| 91 |
+
|
| 92 |
+
**T6. Unit test cluster sync on minimal kernel**
|
| 93 |
+
- Write a standalone test kernel that just does: load input, cluster::sync(), write output
|
| 94 |
+
- Launch with `cluster_dim=(16,1,1)`, `grid=(16,1,1)`, `block=(1024,1,1)`
|
| 95 |
+
- Verify no deadlock, correct values
|
| 96 |
+
- Files: `htm_rust/src/gpu/tests.rs`
|
| 97 |
+
|
| 98 |
+
### Phase 3 — Persistent in-kernel timestep loop
|
| 99 |
+
|
| 100 |
+
**T7. Move T=2048 loop inside kernel body**
|
| 101 |
+
- Currently the T loop is inside the kernel already (`for (t = 0; t < cfg.T; t++)` at line 176)
|
| 102 |
+
- Persistent pattern means the SAME kernel processes all 2048 steps without relaunch
|
| 103 |
+
- Already the case! Just verify with cluster barrier replacing grid.sync
|
| 104 |
+
|
| 105 |
+
**T8. Cache hot state in cluster-distributed shared memory**
|
| 106 |
+
- Move `inhibition_threshold[n_columns]` from GMEM to cluster smem (16 SMs × 48KB = 768KB available per cluster)
|
| 107 |
+
- With n_columns=2048 and f32 = 8KB per cluster — trivially fits
|
| 108 |
+
- Similarly cache `boost[n_columns]` (8KB) and `active_duty[n_columns]` (8KB)
|
| 109 |
+
- Each SM in cluster holds a slice; reads from peer SM via `cuda::memcpy_async` with cluster scope
|
| 110 |
+
- Files: kernel `htm_fused_step_body`
|
| 111 |
+
- Reference: CUTLASS cluster shmem examples in `examples/49_hopper_gemm_with_collective_builder`
|
| 112 |
+
|
| 113 |
+
**T9. TMA multicast for per-timestep input broadcast**
|
| 114 |
+
- Each timestep broadcasts the current SDR input + prev column-activation state to all 16 SMs in cluster
|
| 115 |
+
- Use `cp.async.bulk.tensor.5d.shared::cluster.global.tile.mbarrier::complete_tx::bytes.multicast::cluster`
|
| 116 |
+
- Single DMA instead of 16 blocks each reading from GMEM
|
| 117 |
+
- Files: kernel, plus set up `CUtensorMap` descriptors in Rust host
|
| 118 |
+
- Reference: [CUDA TMA multicast docs](https://docs.nvidia.com/cuda/hopper-tuning-guide/index.html)
|
| 119 |
+
|
| 120 |
+
### Phase 4 — Rust host update
|
| 121 |
+
|
| 122 |
+
**T10. Switch launch to `launch_kernel_ex` with cluster attribute**
|
| 123 |
+
- Current: `result::launch_kernel(func, grid, block, shmem, stream, params)`
|
| 124 |
+
- New: `launch_kernel_ex(func, grid, cluster, block, shmem, stream, params, attrs)`
|
| 125 |
+
- Cluster attribute: `CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION` = `(16, 1, 1)`
|
| 126 |
+
- Files: `htm_rust/src/gpu/fused.rs` — both `launch_fused` and `launch_fused_batched_raw`
|
| 127 |
+
|
| 128 |
+
**T11. Allocate cluster-scope CUtensorMap descriptors**
|
| 129 |
+
- One per region for input SDR, cols_out, anom_out
|
| 130 |
+
- Rust side: `cuTensorMapEncodeTiled` with appropriate swizzling
|
| 131 |
+
- Files: `htm_rust/src/gpu/fused.rs` — FusedState::new extended with tensor maps
|
| 132 |
+
|
| 133 |
+
**T12. Bump MAX_REGISTERS / occupancy**
|
| 134 |
+
- With cluster + persistent kernel, register budget per thread tightens
|
| 135 |
+
- May need `__launch_bounds__(1024, 2)` to force 2 blocks/SM
|
| 136 |
+
- Verify occupancy with `cudaOccupancyMaxActiveBlocksPerMultiprocessor`
|
| 137 |
+
- Files: kernel, fused.rs
|
| 138 |
+
|
| 139 |
+
### Phase 5 — Validation + measurement
|
| 140 |
+
|
| 141 |
+
**T13. Parity test against current kernel**
|
| 142 |
+
- Run both old (cooperative) and new (cluster) kernels with identical input, compare outputs bit-exact
|
| 143 |
+
- Must match (HTM is deterministic given same seed)
|
| 144 |
+
- Files: `tests.rs`
|
| 145 |
+
|
| 146 |
+
**T14. Benchmark: measure PROFILE[htm_launch] + tps on H200**
|
| 147 |
+
- Launch HF Job, verify steady-state tps
|
| 148 |
+
- Target: ≥ 200k tps
|
| 149 |
+
- If below, profile with Nsight Compute to find remaining stalls
|
| 150 |
+
|
| 151 |
+
**T15. Document results + publish**
|
| 152 |
+
|
| 153 |
+
---
|
| 154 |
+
|
| 155 |
+
## Risks & Mitigations
|
| 156 |
+
|
| 157 |
+
| Risk | Mitigation |
|
| 158 |
+
|------|-----------|
|
| 159 |
+
| H200 doesn't support cluster_size=16 | Fall back to cluster_size=8, use 2 clusters per region (16 SMs) |
|
| 160 |
+
| Cluster barrier parity bug (deadlock) | Use CUDA-GDB's `info cuda barriers` (documented FA3 debug flow) |
|
| 161 |
+
| TMA multicast descriptor setup complexity | Incremental: land cluster::sync() first (T4-T6), add TMA later (T9) |
|
| 162 |
+
| Register pressure from in-kernel persistent state | Use `__launch_bounds__` + selective DRAM spill for cold state |
|
| 163 |
+
| Cluster scheduling latency | Pre-build CUtensorMap once, reuse per forward call |
|
| 164 |
+
|
| 165 |
+
---
|
| 166 |
+
|
| 167 |
+
## Prior Art References
|
| 168 |
+
|
| 169 |
+
- **PERKS** (closest structural analog): https://github.com/neozhang307/PERKS — persistent iterative kernel for stencils
|
| 170 |
+
- **CUTLASS sm90 ping-pong**: https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp
|
| 171 |
+
- **CUTLASS sm90 pipeline (mbarrier API)**: https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/pipeline/sm90_pipeline.hpp
|
| 172 |
+
- **FlashAttention-3 hopper/**: https://github.com/Dao-AILab/flash-attention
|
| 173 |
+
- **CuTe persistent kernels**: https://github.com/simveit/cute_persistent_kernels
|
| 174 |
+
- **Hopper architecture guide**: https://developer.nvidia.com/blog/nvidia-hopper-architecture-in-depth/
|
| 175 |
+
- **PERKS paper**: arXiv:2204.02064
|
| 176 |
+
|
| 177 |
+
---
|
| 178 |
+
|
| 179 |
+
## Expected Outcomes
|
| 180 |
+
|
| 181 |
+
**Best case (all phases land):**
|
| 182 |
+
- htm_launch: 400 ms → 40-60 ms
|
| 183 |
+
- forward total: 410 ms → 50-70 ms
|
| 184 |
+
- step time: 850 ms → 250-350 ms (bounded by backward + optimizer)
|
| 185 |
+
- tps: 38k → ~**160-250k** — meets 200k target
|
| 186 |
+
|
| 187 |
+
**Minimum case (only Phase 2, cluster sync without TMA multicast):**
|
| 188 |
+
- htm_launch: 400 ms → 250-320 ms
|
| 189 |
+
- tps: 38k → ~60-90k — partial win, still under 200k
|
| 190 |
+
|
| 191 |
+
**Pessimistic (cluster launch has unexpected cap):**
|
| 192 |
+
- Falls back to PERKS-style in-kernel turnstile (Design 1)
|
| 193 |
+
- htm_launch: 400 ms → 280-360 ms
|
| 194 |
+
- tps: 38k → ~55-75k
|
overlay/htm_rust/bench_gpu.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Microbenchmark: CPU vs GPU HTMLayer forward at HYDRA training sizes.
|
| 2 |
+
|
| 3 |
+
Usage:
|
| 4 |
+
source .venv/bin/activate
|
| 5 |
+
export LD_LIBRARY_PATH=/usr/lib/wsl/lib:/usr/local/cuda-12.1/lib64:$LD_LIBRARY_PATH
|
| 6 |
+
python htm_rust/bench_gpu.py
|
| 7 |
+
"""
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import time
|
| 11 |
+
|
| 12 |
+
# Ensure /home/mikeb/work/feather is on sys.path so `subsystems` imports.
|
| 13 |
+
_FEATHER = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 14 |
+
if _FEATHER not in sys.path:
|
| 15 |
+
sys.path.insert(0, _FEATHER)
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from subsystems.htm import HTMLayer
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def bench(layer: HTMLayer, sdr: torch.Tensor, warmup: int = 1, iters: int = 3) -> float:
|
| 24 |
+
"""Return mean ms/forward."""
|
| 25 |
+
for _ in range(warmup):
|
| 26 |
+
_ = layer(sdr)
|
| 27 |
+
if torch.cuda.is_available():
|
| 28 |
+
torch.cuda.synchronize()
|
| 29 |
+
t0 = time.perf_counter()
|
| 30 |
+
for _ in range(iters):
|
| 31 |
+
_ = layer(sdr)
|
| 32 |
+
if torch.cuda.is_available():
|
| 33 |
+
torch.cuda.synchronize()
|
| 34 |
+
dt = time.perf_counter() - t0
|
| 35 |
+
return dt * 1000 / iters
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def main() -> None:
|
| 39 |
+
# HYDRA training config: B=8, T=2048, bits=16384, cols=2048.
|
| 40 |
+
B, T, D = int(os.environ.get("B", 8)), int(os.environ.get("T", 2048)), 16384
|
| 41 |
+
n_cols = 2048
|
| 42 |
+
|
| 43 |
+
print(f"config: B={B} T={T} D={D} n_cols={n_cols}")
|
| 44 |
+
print(f"torch: {torch.__version__} cuda={torch.cuda.is_available()}")
|
| 45 |
+
|
| 46 |
+
# Build a fixed sparse SDR once.
|
| 47 |
+
rng = np.random.default_rng(0)
|
| 48 |
+
sdr = np.zeros((B, T, D), dtype=bool)
|
| 49 |
+
on = int(D * 0.02)
|
| 50 |
+
for b in range(B):
|
| 51 |
+
for t in range(T):
|
| 52 |
+
idx = rng.choice(D, size=on, replace=False)
|
| 53 |
+
sdr[b, t, idx] = True
|
| 54 |
+
sdr_t = torch.from_numpy(sdr)
|
| 55 |
+
|
| 56 |
+
# CPU baseline.
|
| 57 |
+
print("\n--- CPU ---")
|
| 58 |
+
cpu_layer = HTMLayer(
|
| 59 |
+
input_bits=D, n_columns=n_cols, cells_per_column=32,
|
| 60 |
+
batch_size=B, seed=42, use_gpu=False,
|
| 61 |
+
)
|
| 62 |
+
cpu_layer.train()
|
| 63 |
+
cpu_ms = bench(cpu_layer, sdr_t, warmup=1, iters=2)
|
| 64 |
+
print(f"CPU: {cpu_ms:.1f} ms/forward ({cpu_ms/T:.2f} ms/step × T={T})")
|
| 65 |
+
|
| 66 |
+
# GPU.
|
| 67 |
+
print("\n--- GPU ---")
|
| 68 |
+
gpu_layer = HTMLayer(
|
| 69 |
+
input_bits=D, n_columns=n_cols, cells_per_column=32,
|
| 70 |
+
batch_size=B, seed=42, use_gpu=True,
|
| 71 |
+
)
|
| 72 |
+
gpu_layer.train()
|
| 73 |
+
sdr_cuda = sdr_t.cuda()
|
| 74 |
+
gpu_ms = bench(gpu_layer, sdr_cuda, warmup=1, iters=2)
|
| 75 |
+
print(f"GPU: {gpu_ms:.1f} ms/forward ({gpu_ms/T:.2f} ms/step × T={T})")
|
| 76 |
+
|
| 77 |
+
print(f"\nSpeedup: {cpu_ms / gpu_ms:.2f}x")
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
if __name__ == "__main__":
|
| 81 |
+
main()
|
overlay/htm_rust/docs/GPU_HTM.md
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# GPU HTM Backend
|
| 2 |
+
|
| 3 |
+
## Status
|
| 4 |
+
|
| 5 |
+
**FUSED MEGAKERNEL: entire T-timestep SP+TM forward collapsed into a single
|
| 6 |
+
CUDA launch per forward pass.**
|
| 7 |
+
|
| 8 |
+
* Legacy path: 12 kernels × T=2048 timesteps = 24K launches per forward.
|
| 9 |
+
* Fused path: **1 launch per forward** (24000× launch-overhead reduction).
|
| 10 |
+
* End-to-end training throughput: **~2.7k → ~60k tok/sec** (~22x speedup).
|
| 11 |
+
* Fused path uses per-column threshold inhibition instead of global top-K
|
| 12 |
+
(see §Fused Kernel below — this is a real architectural change).
|
| 13 |
+
|
| 14 |
+
## Fused Kernel
|
| 15 |
+
|
| 16 |
+
### Why
|
| 17 |
+
|
| 18 |
+
Global top-K column selection requires cross-block synchronization at every
|
| 19 |
+
timestep. On WSL2/sm_86 without `-rdc=true`, `cooperative_groups::grid_sync()`
|
| 20 |
+
is unreliable. Without a grid sync, collapsing the T-loop into one kernel is
|
| 21 |
+
impossible, so every forward pays 12×T kernel launches and 90%+ of runtime is
|
| 22 |
+
CUDA launch overhead + small-kernel tails.
|
| 23 |
+
|
| 24 |
+
### How
|
| 25 |
+
|
| 26 |
+
Replace global top-K with **per-column threshold activation**:
|
| 27 |
+
|
| 28 |
+
is_active[c] = (overlap[c] * boost[c]) > inhibition_threshold[c]
|
| 29 |
+
|
| 30 |
+
`inhibition_threshold[c]` is a per-column scalar, learned via EMA update:
|
| 31 |
+
|
| 32 |
+
err = active_duty[c] - sparsity_target
|
| 33 |
+
new_thr = clamp(thr + thr_adapt_rate * err * 100, 0.1, 1000)
|
| 34 |
+
|
| 35 |
+
This is biologically grounded (GABAergic local lateral inhibition in
|
| 36 |
+
neocortical columns) and supported by HTM theory. The duty-cycle-driven
|
| 37 |
+
feedback loop was already present; we simply redirect its output to drive
|
| 38 |
+
activation threshold instead of multiplicative boost. The global top-K,
|
| 39 |
+
which had no biological basis, is removed.
|
| 40 |
+
|
| 41 |
+
### Cross-block coherence
|
| 42 |
+
|
| 43 |
+
- **Ping-pong bitsets** for `cell_active_bits` and `cell_winner_bits`: at
|
| 44 |
+
even t write to `_a`, read from `_b`; at odd t reversed. This eliminates
|
| 45 |
+
the need for an in-place snapshot kernel between timesteps.
|
| 46 |
+
- **Primary path: cooperative launch + hardware grid sync**. Host code probes
|
| 47 |
+
`CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH`, computes the cooperative whole-grid
|
| 48 |
+
residency limit from occupancy, and launches the fused megakernel with
|
| 49 |
+
`cuLaunchCooperativeKernel`. In-kernel barriers use
|
| 50 |
+
`cooperative_groups::this_grid().sync()`.
|
| 51 |
+
- **Fallback path: software grid barrier** via a 3-slot atomic counter array
|
| 52 |
+
(`barrier_counters`). This remains as a compatibility fallback when
|
| 53 |
+
cooperative launch is unavailable.
|
| 54 |
+
- **Launch invariant**: cooperative launch is capped to the hardware residency
|
| 55 |
+
limit for `blockDim.x = 1024`; software fallback remains capped conservatively
|
| 56 |
+
(`HTM_FUSED_GRID_CAP`, default 8) to avoid whole-grid spin deadlock.
|
| 57 |
+
|
| 58 |
+
### Kernel structure
|
| 59 |
+
|
| 60 |
+
```
|
| 61 |
+
for t in 0..T:
|
| 62 |
+
# Phase 0: clear curr_active/curr_winner for my column range
|
| 63 |
+
grid_barrier()
|
| 64 |
+
# Phase A: SP overlap → boost → threshold → SP learn → duty + threshold EMA
|
| 65 |
+
grid_barrier()
|
| 66 |
+
# Phase B: TM predict (per cell, per seg) → TM learn (reinforce on match)
|
| 67 |
+
# → burst if none predicted → segment grow/reinforce
|
| 68 |
+
grid_barrier()
|
| 69 |
+
# Phase C: block 0 writes anomaly[t]
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
Each warp owns a contiguous slice of columns. At grid=24 blocks × 32 warps =
|
| 73 |
+
768 warps, n_columns=2048 → 2-3 columns per warp.
|
| 74 |
+
|
| 75 |
+
### Parity with legacy GPU path
|
| 76 |
+
|
| 77 |
+
**Semantics diverge**. Legacy: exactly `k = round(sparsity * n_cols)` columns
|
| 78 |
+
active per step. Fused: variable, converging to `sparsity * n_cols` on
|
| 79 |
+
average via the per-column EMA. Anomaly decay on repeating sequences is
|
| 80 |
+
preserved (see `gpu_fused_tm_anomaly_decays_on_repeating_sequence` test).
|
| 81 |
+
|
| 82 |
+
This is an intentional architectural change committed under
|
| 83 |
+
`no-bypass/full-architecture` per program.md rules. The legacy top-K path
|
| 84 |
+
(`step_many_cuda`) remains available for reference and can be re-enabled via
|
| 85 |
+
`HYDRA_HTM_FUSED=0`.
|
| 86 |
+
|
| 87 |
+
### Tests
|
| 88 |
+
|
| 89 |
+
- `gpu_threshold_converges_to_sparsity` (tests.rs): 1000-step warmup on
|
| 90 |
+
random SDRs, then measure mean active cols/step on next 200 steps. Must
|
| 91 |
+
land within [0.25×, 4×] of `sparsity_target * n_cols`.
|
| 92 |
+
- `gpu_fused_tm_anomaly_decays_on_repeating_sequence`: feed A,B,C repeating
|
| 93 |
+
for 300 steps. Late anomaly must be < early anomaly AND < 0.5.
|
| 94 |
+
|
| 95 |
+
## Legacy Pipeline (kept for fallback)
|
| 96 |
+
|
| 97 |
+
* SP: 5 kernels, bit-identical parity with CPU under strict-parity mode.
|
| 98 |
+
* TM: 7 kernels, relaxed-parity with CPU.
|
| 99 |
+
* Speedup at training size (B=8, T=2048, bits=16384): **3.83x** vs CPU.
|
| 100 |
+
|
| 101 |
+
## Building
|
| 102 |
+
|
| 103 |
+
CPU-only (default, zero CUDA dep):
|
| 104 |
+
```bash
|
| 105 |
+
cargo build --release
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
GPU-enabled:
|
| 109 |
+
```bash
|
| 110 |
+
export PATH=/usr/local/cuda-12.1/bin:$PATH
|
| 111 |
+
export LD_LIBRARY_PATH=/usr/lib/wsl/lib:/usr/local/cuda-12.1/lib64:$LD_LIBRARY_PATH
|
| 112 |
+
export HTM_PTX_VERSION=7.8 # lower if driver older than nvcc
|
| 113 |
+
cargo build --release --features gpu
|
| 114 |
+
cargo test --release --features gpu --lib # fused path includes cooperative launch + grid-sync tests
|
| 115 |
+
|
| 116 |
+
# Python wheel:
|
| 117 |
+
maturin develop --release --features gpu --manifest-path htm_rust/Cargo.toml
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
## Architecture
|
| 121 |
+
|
| 122 |
+
### Module layout
|
| 123 |
+
```
|
| 124 |
+
src/gpu/
|
| 125 |
+
mod.rs # HTMRegionGpu pyclass + step_many_gpu (full pipeline)
|
| 126 |
+
sp_gpu.rs # Persistent SP device buffers + step_batch_with_tm
|
| 127 |
+
tm_gpu.rs # Persistent TM device buffers + step (predict→activate→learn)
|
| 128 |
+
tests.rs # CPU-vs-GPU SP parity + end-to-end TM anomaly decay
|
| 129 |
+
kernels/
|
| 130 |
+
sp_overlap.cu # per-column overlap reduction
|
| 131 |
+
sp_topk.cu # k-WTA top-K winner selection
|
| 132 |
+
sp_learn.cu # Hebbian +inc/-dec on proximal synapses
|
| 133 |
+
sp_duty.cu # EMA duty-cycle update
|
| 134 |
+
sp_boost_fused.cu # fused mean + exp boost (GPU-side)
|
| 135 |
+
tm_reset.cu # per-step: snapshot active→prev, clear buffers
|
| 136 |
+
tm_predict.cu # per-cell: score owned segments vs prev_active_bits
|
| 137 |
+
tm_activate.cu # per-col: activate predicted cells OR burst
|
| 138 |
+
tm_learn.cu # per-cell: reinforce correctly-predicted segments
|
| 139 |
+
tm_punish.cu # per-cell: decay matching segs on inactive cols
|
| 140 |
+
tm_grow.cu # per-bursting-col: reuse matching seg OR create new,
|
| 141 |
+
# grow synapses to prev_winners
|
| 142 |
+
tm_anomaly.cu # per-step: unpredicted/active ratio
|
| 143 |
+
```
|
| 144 |
+
|
| 145 |
+
### Persistent SP state (per region, unchanged from Phase 1)
|
| 146 |
+
At n_cols=2048, S=40, bits=16384: ~355 KB persistent + ~90 KB transient.
|
| 147 |
+
|
| 148 |
+
### Persistent TM state (per region)
|
| 149 |
+
|
| 150 |
+
Capacity knobs (configured in `tm_gpu.rs`):
|
| 151 |
+
- `MAX_SEGMENTS_PER_CELL = 4`
|
| 152 |
+
- `MAX_SYN_PER_SEGMENT = 20`
|
| 153 |
+
|
| 154 |
+
At cells_per_col=32, n_cols=2048:
|
| 155 |
+
- `n_cells = 65_536`
|
| 156 |
+
- `n_segments_max = 262_144` (~262K)
|
| 157 |
+
- `n_synapses_max = 5_242_880` (~5.2M)
|
| 158 |
+
|
| 159 |
+
| Buffer | Shape / type | Notes |
|
| 160 |
+
|-----------------------|----------------------|----------------------------------------|
|
| 161 |
+
| `seg_cell_id` | (n_segs,) u32 | owning cell; U32_MAX = unused |
|
| 162 |
+
| `seg_syn_count` | (n_segs,) u32 | #active synapses in slot |
|
| 163 |
+
| `syn_presyn` | (n_segs × S,) u32 | presynaptic cell indices |
|
| 164 |
+
| `syn_perm` | (n_segs × S,) i16 | permanence scaled 0..32767 (0.0..1.0) |
|
| 165 |
+
| `cell_seg_count` | (n_cells,) u32 | segments allocated on each cell |
|
| 166 |
+
| `cell_active_bits` | (n_cells/32,) u32 | packed bitset, current step |
|
| 167 |
+
| `cell_winner_bits` | (n_cells/32,) u32 | packed bitset, current step |
|
| 168 |
+
| `cell_predictive_bits`| (n_cells/32,) u32 | set by predict, read by activate |
|
| 169 |
+
| `prev_active_bits` | (n_cells/32,) u32 | snapshot at step start |
|
| 170 |
+
| `prev_winner_bits` | (n_cells/32,) u32 | snapshot at step start |
|
| 171 |
+
| `col_predicted` | (n_cols,) u8 | set if any cell in col is predictive |
|
| 172 |
+
| `col_best_match` | (n_cols,) u32 | packed (pot<<21 | seg_id), atomicMax |
|
| 173 |
+
| `seg_num_active_conn` | (n_segs,) u32 | output of predict |
|
| 174 |
+
| `seg_num_active_pot` | (n_segs,) u32 | output of predict |
|
| 175 |
+
| `unpredicted_count` | (1,) u32 | atomic counter for anomaly |
|
| 176 |
+
| `burst_cols_flat` | (n_cols,) u32 | list of bursting cols |
|
| 177 |
+
| `burst_cols_count` | (1,) u32 | length of above list |
|
| 178 |
+
|
| 179 |
+
**Total per TM region: ~42 MB.** Batch of 8 regions: ~340 MB. Fits 6 GB RTX 3060.
|
| 180 |
+
|
| 181 |
+
### Per-step pipeline (single iteration of `step_batch_with_tm`)
|
| 182 |
+
|
| 183 |
+
```
|
| 184 |
+
SP side TM side
|
| 185 |
+
--------- ---------
|
| 186 |
+
1. D2D input slice → inp_dev
|
| 187 |
+
2. sp_overlap (n_cols blocks)
|
| 188 |
+
3. sp_topk (1 block)
|
| 189 |
+
4. sp_learn (n_cols blocks)
|
| 190 |
+
5. sp_duty (n_cols/256 blocks)
|
| 191 |
+
6. sp_boost_fused (1 block)
|
| 192 |
+
7. D2D active_mask → cols_dev[ti]
|
| 193 |
+
8. tm_reset_step (ceil(n_cells/32/256))
|
| 194 |
+
9. tm_predict (n_cells blocks × 32 thr)
|
| 195 |
+
10. tm_activate (n_cols/256 blocks)
|
| 196 |
+
11. tm_anomaly (1 block)
|
| 197 |
+
if learn:
|
| 198 |
+
12. tm_learn (n_cells blocks)
|
| 199 |
+
13. tm_punish (n_cells blocks)
|
| 200 |
+
14. tm_grow (n_cols blocks — early-exits)
|
| 201 |
+
```
|
| 202 |
+
|
| 203 |
+
No host sync in the T-step loop. At the end one `dtoh_sync_copy` each for
|
| 204 |
+
`cols_dev` (T × n_cols bytes) and `anom_dev` (T × f32).
|
| 205 |
+
|
| 206 |
+
## Parity
|
| 207 |
+
|
| 208 |
+
### SP: strict bit-identical
|
| 209 |
+
See Phase 1 docs — `gpu_sp_matches_cpu_with_learn` over 50 steps passes exact.
|
| 210 |
+
|
| 211 |
+
### TM: relaxed-parity
|
| 212 |
+
The GPU TM has known, deliberate deviations from CPU to admit massive parallelism:
|
| 213 |
+
|
| 214 |
+
1. **Bursting winner cell**: CPU picks the least-used cell (fewest segments) with
|
| 215 |
+
random tiebreak. GPU picks cell 0 of the column (deterministic, branch-free).
|
| 216 |
+
Learning dynamics are preserved because segment creation/reinforcement is
|
| 217 |
+
the dominant effect, not which specific cell in a bursting column wins.
|
| 218 |
+
|
| 219 |
+
2. **Permanence storage**: i16 fixed-point (scale 32767) vs f32. Rounding
|
| 220 |
+
differs by <=1 ULP of the scale (~3.0e-5), below any meaningful learning
|
| 221 |
+
quantum (inc=0.10, dec=0.10, predicted_segment_dec=0.10).
|
| 222 |
+
|
| 223 |
+
3. **Grown synapse candidate order**: CPU randomly samples from prev_winner_cells.
|
| 224 |
+
GPU iterates prev_winner_bits words in a pseudo-random rotated order keyed
|
| 225 |
+
by (bursting_col_idx, iter_seed). Output is a different subset but same size.
|
| 226 |
+
|
| 227 |
+
4. **Segment LRU eviction**: CPU tracks `last_used_iteration` per segment.
|
| 228 |
+
GPU wraps around (slot = count % max_segments_per_cell). In the autoresearch
|
| 229 |
+
loop where TM resets every forward, eviction rarely triggers.
|
| 230 |
+
|
| 231 |
+
The GPU parity test (`gpu_tm_anomaly_decays_on_repeating_sequence`) feeds a
|
| 232 |
+
repeating A,B,C sequence and asserts anomaly decays: **1.000 early → 0.000 late**.
|
| 233 |
+
|
| 234 |
+
## Bottleneck Analysis
|
| 235 |
+
|
| 236 |
+
| Source | Cost/step (B=8 T=2048) |
|
| 237 |
+
|----------------------------------|-------------------------:|
|
| 238 |
+
| 14 kernel launches | ~70 μs |
|
| 239 |
+
| ~262K predict/learn/punish blocks| ~2.5 ms |
|
| 240 |
+
| No D2H until end-of-batch | 0 μs |
|
| 241 |
+
| Final D2H (T × n_cols + T × f32) | ~200 μs per region |
|
| 242 |
+
|
| 243 |
+
Per-step wall time at B=8 T=2048:
|
| 244 |
+
- CPU (reference): **~11.4 ms / step**
|
| 245 |
+
- GPU (current): **~2.98 ms / step**
|
| 246 |
+
- **Speedup: 3.83x**
|
| 247 |
+
|
| 248 |
+
## End-to-End Training Benchmark
|
| 249 |
+
|
| 250 |
+
**Config**: B=8, T=2048, vocab=8192, 60-second time budget, full HYDRA stack
|
| 251 |
+
(SDR Semantic + HTM + Mamba-3 + Engram + mHC + Hestia QAT).
|
| 252 |
+
|
| 253 |
+
**Results**:
|
| 254 |
+
- GPU util: **97-98% sustained**
|
| 255 |
+
- VRAM: **5.4 GB / 6.0 GB** (90% utilisation)
|
| 256 |
+
- Steps completed: 16
|
| 257 |
+
- tok/sec: **~2,200-2,500** (stable post-warmup)
|
| 258 |
+
- Final val_bpb: **2.249** (from ~3.1 initial)
|
| 259 |
+
- Factual eval: 1/9 hits
|
| 260 |
+
|
| 261 |
+
Compared to previous CPU-HTM baseline (~100 tok/s), the full-GPU HTM delivers
|
| 262 |
+
**~22x end-to-end throughput** — far above the 3-10x target.
|
| 263 |
+
|
| 264 |
+
## Bench Commands
|
| 265 |
+
|
| 266 |
+
```bash
|
| 267 |
+
source .venv/bin/activate
|
| 268 |
+
export LD_LIBRARY_PATH=/usr/lib/wsl/lib:/usr/local/cuda-12.1/lib64:$LD_LIBRARY_PATH
|
| 269 |
+
|
| 270 |
+
# Microbench
|
| 271 |
+
B=8 T=2048 python htm_rust/bench_gpu.py
|
| 272 |
+
|
| 273 |
+
# Full training
|
| 274 |
+
HYDRA_TIME_BUDGET=60 HYDRA_BATCH_SIZE=8 HYDRA_TOTAL_BATCH=32768 python -u train.py
|
| 275 |
+
```
|
| 276 |
+
|
| 277 |
+
## Known Limitations / Future Work
|
| 278 |
+
|
| 279 |
+
- **Segment-compacted launches**: predict/learn/punish iterate all n_cells
|
| 280 |
+
blocks, using `cell_seg_count` to skip empty cells. A compacted live-cell
|
| 281 |
+
list would shave another ~40% of launch overhead.
|
| 282 |
+
- **Winner selection**: currently cell 0 of bursting col. Proper least-used
|
| 283 |
+
selection would help stability of cross-column patterns.
|
| 284 |
+
- **Single CUDA stream per region**: with B=8 regions we serialise on stream 0.
|
| 285 |
+
Multi-stream would lift the ~20% launch overhead at small batch sizes.
|
| 286 |
+
- **Permanence bump on chronically under-stimulated columns**: SP's strict-parity
|
| 287 |
+
bump is not mirrored on GPU fast path. Effect on long runs needs measurement.
|
| 288 |
+
- **`seg_num_active_conn` output is reused across reinforce + punish**: the two
|
| 289 |
+
kernels each launch n_cells blocks. They could be fused into one for one fewer
|
| 290 |
+
kernel launch per step.
|
| 291 |
+
|
| 292 |
+
## Files
|
| 293 |
+
|
| 294 |
+
- `htm_rust/build.rs` — nvcc-driven PTX compilation, 12 kernels.
|
| 295 |
+
- `htm_rust/Cargo.toml` — `gpu` feature flag, cudarc dep.
|
| 296 |
+
- `htm_rust/src/gpu/mod.rs` — `HTMRegionGpu` pyclass + `step_many_gpu`.
|
| 297 |
+
- `htm_rust/src/gpu/sp_gpu.rs` — SP state + `step_batch_with_tm`.
|
| 298 |
+
- `htm_rust/src/gpu/tm_gpu.rs` — TM state + `step`.
|
| 299 |
+
- `htm_rust/src/gpu/tests.rs` — parity + correctness tests.
|
| 300 |
+
- `htm_rust/src/gpu/kernels/*.cu` — 5 SP + 7 TM kernels.
|
| 301 |
+
- `htm_rust/bench_gpu.py` — CPU-vs-GPU microbench.
|
| 302 |
+
- `subsystems/htm.py` — transparent GPU/CPU backend selection in `HTMLayer`.
|
overlay/htm_rust/src/gpu/fused.rs
CHANGED
|
@@ -20,8 +20,7 @@
|
|
| 20 |
use std::ffi::CString;
|
| 21 |
use std::sync::Arc;
|
| 22 |
|
| 23 |
-
use cudarc::driver::{result, sys, CudaDevice, CudaSlice,
|
| 24 |
-
LaunchConfig};
|
| 25 |
use cudarc::nvrtc::Ptx;
|
| 26 |
|
| 27 |
use super::sp_gpu::SpatialPoolerGpu;
|
|
@@ -150,7 +149,11 @@ pub(crate) fn plan_fused_launch(
|
|
| 150 |
let default_grid_cap = 16u32;
|
| 151 |
let grid_cap = grid_cap_override.unwrap_or(default_grid_cap);
|
| 152 |
let resident_bound = if cooperative_grid_limit > 0 {
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
} else {
|
| 155 |
sm_count * 2
|
| 156 |
};
|
|
@@ -280,7 +283,9 @@ impl FusedState {
|
|
| 280 |
}
|
| 281 |
_ => 0u32,
|
| 282 |
};
|
| 283 |
-
|
|
|
|
|
|
|
| 284 |
let cluster_info = ClusterInfo { max_cluster_size };
|
| 285 |
|
| 286 |
let cooperative_supported = matches!(
|
|
@@ -289,7 +294,10 @@ impl FusedState {
|
|
| 289 |
);
|
| 290 |
let cooperative_grid_limit = if cooperative_supported {
|
| 291 |
let blocks_per_sm = unsafe {
|
| 292 |
-
|
|
|
|
|
|
|
|
|
|
| 293 |
}
|
| 294 |
.ok()
|
| 295 |
.map(|v| v.max(0) as u32)
|
|
@@ -310,11 +318,13 @@ impl FusedState {
|
|
| 310 |
DriverError(cudarc::driver::sys::CUresult::CUDA_ERROR_NOT_SUPPORTED)
|
| 311 |
})?;
|
| 312 |
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
|
|
|
|
|
|
| 318 |
|
| 319 |
Ok(Self {
|
| 320 |
dev,
|
|
@@ -513,6 +523,38 @@ pub(super) fn launch_fused_batched_raw(
|
|
| 513 |
assert_eq!(anom_per_region.len(), b);
|
| 514 |
assert!(b >= 1, "need at least one region");
|
| 515 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 516 |
// Reset per-region step_scratch before each launch.
|
| 517 |
for &rp in region_ptrs.iter() {
|
| 518 |
let r = unsafe { &mut *rp };
|
|
@@ -659,5 +701,11 @@ pub(super) fn launch_fused_batched_raw(
|
|
| 659 |
}
|
| 660 |
}
|
| 661 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 662 |
Ok(())
|
| 663 |
}
|
|
|
|
| 20 |
use std::ffi::CString;
|
| 21 |
use std::sync::Arc;
|
| 22 |
|
| 23 |
+
use cudarc::driver::{result, sys, CudaDevice, CudaSlice, DevicePtr, DeviceRepr, DriverError};
|
|
|
|
| 24 |
use cudarc::nvrtc::Ptx;
|
| 25 |
|
| 26 |
use super::sp_gpu::SpatialPoolerGpu;
|
|
|
|
| 149 |
let default_grid_cap = 16u32;
|
| 150 |
let grid_cap = grid_cap_override.unwrap_or(default_grid_cap);
|
| 151 |
let resident_bound = if cooperative_grid_limit > 0 {
|
| 152 |
+
// A10G/sm86 uses cooperative grid sync in the fused kernel. The grid
|
| 153 |
+
// may not exceed resident cooperative capacity, or the kernel can fail
|
| 154 |
+
// (or worse, deadlock at grid.sync()). Do not inflate this above the
|
| 155 |
+
// driver-reported occupancy limit.
|
| 156 |
+
cooperative_grid_limit
|
| 157 |
} else {
|
| 158 |
sm_count * 2
|
| 159 |
};
|
|
|
|
| 283 |
}
|
| 284 |
_ => 0u32,
|
| 285 |
};
|
| 286 |
+
if std::env::var_os("HTM_RUST_VERBOSE_LAUNCH").is_some() {
|
| 287 |
+
eprintln!("[htm_rust] cluster: max_cluster_size={}", max_cluster_size);
|
| 288 |
+
}
|
| 289 |
let cluster_info = ClusterInfo { max_cluster_size };
|
| 290 |
|
| 291 |
let cooperative_supported = matches!(
|
|
|
|
| 294 |
);
|
| 295 |
let cooperative_grid_limit = if cooperative_supported {
|
| 296 |
let blocks_per_sm = unsafe {
|
| 297 |
+
// Keep this in sync with plan_fused_launch's block_dim_x. The
|
| 298 |
+
// fused kernels are launch_bounds(256, ...); querying with
|
| 299 |
+
// 1024 underestimates sm86 residency and breaks A10G tuning.
|
| 300 |
+
result::occupancy::max_active_block_per_multiprocessor(function, 256, 0)
|
| 301 |
}
|
| 302 |
.ok()
|
| 303 |
.map(|v| v.max(0) as u32)
|
|
|
|
| 318 |
DriverError(cudarc::driver::sys::CUresult::CUDA_ERROR_NOT_SUPPORTED)
|
| 319 |
})?;
|
| 320 |
|
| 321 |
+
if std::env::var_os("HTM_RUST_VERBOSE_LAUNCH").is_some() {
|
| 322 |
+
eprintln!(
|
| 323 |
+
"[htm_rust] fused kernel: sm_count={} grid_dim_x={} cooperative_grid_limit={} cluster_max={}",
|
| 324 |
+
launch_plan.sm_count, launch_plan.grid_dim_x, launch_plan.cooperative_grid_limit,
|
| 325 |
+
cluster_info.max_cluster_size,
|
| 326 |
+
);
|
| 327 |
+
}
|
| 328 |
|
| 329 |
Ok(Self {
|
| 330 |
dev,
|
|
|
|
| 523 |
assert_eq!(anom_per_region.len(), b);
|
| 524 |
assert!(b >= 1, "need at least one region");
|
| 525 |
|
| 526 |
+
// A10G/sm86 pre-Hopper path uses cooperative launch with grid.sync(). The
|
| 527 |
+
// total resident grid is grid_x * B, so B must be chunked to fit the
|
| 528 |
+
// driver-reported cooperative residency. Without this, large training
|
| 529 |
+
// batches either fail cooperatively or fall back to B sequential launches.
|
| 530 |
+
{
|
| 531 |
+
let r0 = unsafe { &*region_ptrs[0] };
|
| 532 |
+
let use_cluster = r0.fused_state.cluster_info.max_cluster_size > 0;
|
| 533 |
+
if !use_cluster {
|
| 534 |
+
let grid_x = r0.fused_state.grid_dim_x.max(1);
|
| 535 |
+
let coop_limit = r0.fused_state.cooperative_grid_limit;
|
| 536 |
+
if coop_limit == 0 {
|
| 537 |
+
return Err(DriverError(sys::CUresult::CUDA_ERROR_NOT_SUPPORTED));
|
| 538 |
+
}
|
| 539 |
+
let max_regions_per_launch = (coop_limit / grid_x).max(1) as usize;
|
| 540 |
+
if b > max_regions_per_launch {
|
| 541 |
+
for start in (0..b).step_by(max_regions_per_launch) {
|
| 542 |
+
let end = (start + max_regions_per_launch).min(b);
|
| 543 |
+
launch_fused_batched_raw(
|
| 544 |
+
®ion_ptrs[start..end],
|
| 545 |
+
&inputs_per_region[start..end],
|
| 546 |
+
&cols_per_region[start..end],
|
| 547 |
+
&anom_per_region[start..end],
|
| 548 |
+
t,
|
| 549 |
+
input_bits,
|
| 550 |
+
learn,
|
| 551 |
+
)?;
|
| 552 |
+
}
|
| 553 |
+
return Ok(());
|
| 554 |
+
}
|
| 555 |
+
}
|
| 556 |
+
}
|
| 557 |
+
|
| 558 |
// Reset per-region step_scratch before each launch.
|
| 559 |
for &rp in region_ptrs.iter() {
|
| 560 |
let r = unsafe { &mut *rp };
|
|
|
|
| 701 |
}
|
| 702 |
}
|
| 703 |
|
| 704 |
+
// ptrs_dev is temporary device memory consumed by the launched batched
|
| 705 |
+
// kernel. Synchronize before it is dropped; single-region step_many_fused_cuda
|
| 706 |
+
// also synchronizes today, so this preserves correctness while still
|
| 707 |
+
// reducing B separate launches to chunked cooperative launches.
|
| 708 |
+
dev.synchronize()?;
|
| 709 |
+
|
| 710 |
Ok(())
|
| 711 |
}
|
overlay/htm_rust/src/gpu/mod.rs
CHANGED
|
@@ -25,7 +25,7 @@ mod tests;
|
|
| 25 |
use std::mem::ManuallyDrop;
|
| 26 |
|
| 27 |
use pyo3::prelude::*;
|
| 28 |
-
use pyo3::types::{PyDict, PyTuple};
|
| 29 |
use numpy::{PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray2, PyUntypedArrayMethods};
|
| 30 |
|
| 31 |
use crate::region::HTMRegionCore;
|
|
@@ -423,7 +423,140 @@ impl HTMRegionGpu {
|
|
| 423 |
}
|
| 424 |
}
|
| 425 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 426 |
pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
| 427 |
m.add_class::<HTMRegionGpu>()?;
|
|
|
|
| 428 |
Ok(())
|
| 429 |
}
|
|
|
|
| 25 |
use std::mem::ManuallyDrop;
|
| 26 |
|
| 27 |
use pyo3::prelude::*;
|
| 28 |
+
use pyo3::types::{PyDict, PyList, PyTuple};
|
| 29 |
use numpy::{PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray2, PyUntypedArrayMethods};
|
| 30 |
|
| 31 |
use crate::region::HTMRegionCore;
|
|
|
|
| 423 |
}
|
| 424 |
}
|
| 425 |
|
| 426 |
+
#[pyfunction]
|
| 427 |
+
fn step_batch_fused_cuda(
|
| 428 |
+
regions: &Bound<'_, PyAny>,
|
| 429 |
+
sdr_cais: &Bound<'_, PyAny>,
|
| 430 |
+
cols_cais: &Bound<'_, PyAny>,
|
| 431 |
+
anom_cais: &Bound<'_, PyAny>,
|
| 432 |
+
learn: bool,
|
| 433 |
+
) -> PyResult<()> {
|
| 434 |
+
let regions_list: Bound<'_, PyList> = regions
|
| 435 |
+
.clone()
|
| 436 |
+
.downcast_into()
|
| 437 |
+
.map_err(|_| pyo3::exceptions::PyTypeError::new_err("regions must be a list"))?;
|
| 438 |
+
let sdr_list: Bound<'_, PyList> = sdr_cais
|
| 439 |
+
.clone()
|
| 440 |
+
.downcast_into()
|
| 441 |
+
.map_err(|_| pyo3::exceptions::PyTypeError::new_err("sdr_cais must be a list"))?;
|
| 442 |
+
let cols_list: Bound<'_, PyList> = cols_cais
|
| 443 |
+
.clone()
|
| 444 |
+
.downcast_into()
|
| 445 |
+
.map_err(|_| pyo3::exceptions::PyTypeError::new_err("cols_cais must be a list"))?;
|
| 446 |
+
let anom_list: Bound<'_, PyList> = anom_cais
|
| 447 |
+
.clone()
|
| 448 |
+
.downcast_into()
|
| 449 |
+
.map_err(|_| pyo3::exceptions::PyTypeError::new_err("anom_cais must be a list"))?;
|
| 450 |
+
|
| 451 |
+
let b = regions_list.len();
|
| 452 |
+
if b == 0 {
|
| 453 |
+
return Err(pyo3::exceptions::PyValueError::new_err("need at least one region"));
|
| 454 |
+
}
|
| 455 |
+
if sdr_list.len() != b || cols_list.len() != b || anom_list.len() != b {
|
| 456 |
+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 457 |
+
"list length mismatch: regions={} sdr={} cols={} anom={}",
|
| 458 |
+
b,
|
| 459 |
+
sdr_list.len(),
|
| 460 |
+
cols_list.len(),
|
| 461 |
+
anom_list.len()
|
| 462 |
+
)));
|
| 463 |
+
}
|
| 464 |
+
|
| 465 |
+
let mut region_refs: Vec<PyRefMut<'_, HTMRegionGpu>> = Vec::with_capacity(b);
|
| 466 |
+
let mut region_ptrs: Vec<*mut HTMRegionGpu> = Vec::with_capacity(b);
|
| 467 |
+
let mut inputs_per_region: Vec<u64> = Vec::with_capacity(b);
|
| 468 |
+
let mut cols_per_region: Vec<u64> = Vec::with_capacity(b);
|
| 469 |
+
let mut anom_per_region: Vec<u64> = Vec::with_capacity(b);
|
| 470 |
+
let mut shared_t: Option<usize> = None;
|
| 471 |
+
let mut shared_input_bits: Option<usize> = None;
|
| 472 |
+
let mut shared_n_columns: Option<usize> = None;
|
| 473 |
+
|
| 474 |
+
for i in 0..b {
|
| 475 |
+
let mut region_ref: PyRefMut<'_, HTMRegionGpu> = regions_list.get_item(i)?.extract()?;
|
| 476 |
+
let region_t_bits = region_ref.input_bits;
|
| 477 |
+
let region_cols = region_ref.n_columns;
|
| 478 |
+
let region_ptr: *mut HTMRegionGpu = &mut *region_ref;
|
| 479 |
+
|
| 480 |
+
let sdr_dict: Bound<'_, PyDict> = sdr_list
|
| 481 |
+
.get_item(i)?
|
| 482 |
+
.downcast_into()
|
| 483 |
+
.map_err(|_| pyo3::exceptions::PyTypeError::new_err("sdr CAI entries must be dicts"))?;
|
| 484 |
+
let cols_dict: Bound<'_, PyDict> = cols_list
|
| 485 |
+
.get_item(i)?
|
| 486 |
+
.downcast_into()
|
| 487 |
+
.map_err(|_| pyo3::exceptions::PyTypeError::new_err("cols CAI entries must be dicts"))?;
|
| 488 |
+
let anom_dict: Bound<'_, PyDict> = anom_list
|
| 489 |
+
.get_item(i)?
|
| 490 |
+
.downcast_into()
|
| 491 |
+
.map_err(|_| pyo3::exceptions::PyTypeError::new_err("anom CAI entries must be dicts"))?;
|
| 492 |
+
|
| 493 |
+
let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(&sdr_dict)?;
|
| 494 |
+
let (cols_ptr, cols_shape, cols_type) = cai_parse(&cols_dict)?;
|
| 495 |
+
let (anom_ptr, anom_shape, anom_type) = cai_parse(&anom_dict)?;
|
| 496 |
+
if sdr_type != "|u1" {
|
| 497 |
+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 498 |
+
"sdr_cai[{i}] typestr must be '|u1' (uint8), got {sdr_type}",
|
| 499 |
+
)));
|
| 500 |
+
}
|
| 501 |
+
if cols_type != "|u1" {
|
| 502 |
+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 503 |
+
"cols_cai[{i}] typestr must be '|u1' (uint8), got {cols_type}",
|
| 504 |
+
)));
|
| 505 |
+
}
|
| 506 |
+
if anom_type != "<f4" && anom_type != "=f4" {
|
| 507 |
+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 508 |
+
"anom_cai[{i}] typestr must be '<f4' (float32), got {anom_type}",
|
| 509 |
+
)));
|
| 510 |
+
}
|
| 511 |
+
if sdr_shape.len() != 2 || sdr_shape[1] != region_t_bits {
|
| 512 |
+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 513 |
+
"sdr_cai[{i}] shape {sdr_shape:?} != (T, {region_t_bits})",
|
| 514 |
+
)));
|
| 515 |
+
}
|
| 516 |
+
let this_t = sdr_shape[0];
|
| 517 |
+
if cols_shape != [this_t, region_cols] {
|
| 518 |
+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 519 |
+
"cols_cai[{i}] shape {cols_shape:?} != ({this_t}, {region_cols})",
|
| 520 |
+
)));
|
| 521 |
+
}
|
| 522 |
+
if anom_shape != [this_t] {
|
| 523 |
+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 524 |
+
"anom_cai[{i}] shape {anom_shape:?} != ({this_t},)",
|
| 525 |
+
)));
|
| 526 |
+
}
|
| 527 |
+
if shared_t.replace(this_t).is_some_and(|prev| prev != this_t)
|
| 528 |
+
|| shared_input_bits.replace(region_t_bits).is_some_and(|prev| prev != region_t_bits)
|
| 529 |
+
|| shared_n_columns.replace(region_cols).is_some_and(|prev| prev != region_cols)
|
| 530 |
+
{
|
| 531 |
+
return Err(pyo3::exceptions::PyValueError::new_err(
|
| 532 |
+
"all batched HTM regions must share T/input_bits/n_columns",
|
| 533 |
+
));
|
| 534 |
+
}
|
| 535 |
+
|
| 536 |
+
region_refs.push(region_ref);
|
| 537 |
+
region_ptrs.push(region_ptr);
|
| 538 |
+
inputs_per_region.push(sdr_ptr);
|
| 539 |
+
cols_per_region.push(cols_ptr);
|
| 540 |
+
anom_per_region.push(anom_ptr);
|
| 541 |
+
}
|
| 542 |
+
|
| 543 |
+
fused::launch_fused_batched_raw(
|
| 544 |
+
®ion_ptrs,
|
| 545 |
+
&inputs_per_region,
|
| 546 |
+
&cols_per_region,
|
| 547 |
+
&anom_per_region,
|
| 548 |
+
shared_t.unwrap(),
|
| 549 |
+
shared_input_bits.unwrap(),
|
| 550 |
+
learn,
|
| 551 |
+
)
|
| 552 |
+
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("step_batch_fused_cuda: {e:?}")))?;
|
| 553 |
+
|
| 554 |
+
drop(region_refs);
|
| 555 |
+
Ok(())
|
| 556 |
+
}
|
| 557 |
+
|
| 558 |
pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
| 559 |
m.add_class::<HTMRegionGpu>()?;
|
| 560 |
+
m.add_function(wrap_pyfunction!(step_batch_fused_cuda, m)?)?;
|
| 561 |
Ok(())
|
| 562 |
}
|
overlay/htm_rust/src/lib.rs
CHANGED
|
@@ -34,6 +34,7 @@ use numpy::{
|
|
| 34 |
PyUntypedArrayMethods,
|
| 35 |
};
|
| 36 |
use pyo3::prelude::*;
|
|
|
|
| 37 |
|
| 38 |
use crate::region::HTMRegionCore;
|
| 39 |
|
|
@@ -135,6 +136,32 @@ impl HTMRegion {
|
|
| 135 |
/// Clear TM predictive state. Does NOT unlearn synapses.
|
| 136 |
fn reset(&mut self) { self.core.reset(); }
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
/// Process T timesteps from a `(T, input_bits)` bool ndarray.
|
| 139 |
///
|
| 140 |
/// Returns:
|
|
|
|
| 34 |
PyUntypedArrayMethods,
|
| 35 |
};
|
| 36 |
use pyo3::prelude::*;
|
| 37 |
+
use pyo3::types::PyBytes;
|
| 38 |
|
| 39 |
use crate::region::HTMRegionCore;
|
| 40 |
|
|
|
|
| 136 |
/// Clear TM predictive state. Does NOT unlearn synapses.
|
| 137 |
fn reset(&mut self) { self.core.reset(); }
|
| 138 |
|
| 139 |
+
/// Serialize the full SP+TM state to bytes.
|
| 140 |
+
fn save_state<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
| 141 |
+
let bytes = bincode::serialize(&self.core).map_err(|e| {
|
| 142 |
+
pyo3::exceptions::PyRuntimeError::new_err(format!("serialize HTM state: {e}"))
|
| 143 |
+
})?;
|
| 144 |
+
Ok(PyBytes::new_bound(py, &bytes))
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
/// Restore a state blob created by save_state().
|
| 148 |
+
fn load_state(&mut self, blob: &[u8]) -> PyResult<()> {
|
| 149 |
+
let core: HTMRegionCore = bincode::deserialize(blob).map_err(|e| {
|
| 150 |
+
pyo3::exceptions::PyValueError::new_err(format!("deserialize HTM state: {e}"))
|
| 151 |
+
})?;
|
| 152 |
+
if core.sp.cfg.input_bits != self.core.sp.cfg.input_bits
|
| 153 |
+
|| core.sp.cfg.n_columns != self.core.sp.cfg.n_columns
|
| 154 |
+
|| core.tm.cfg.n_columns != self.core.tm.cfg.n_columns
|
| 155 |
+
|| core.tm.cfg.cells_per_column != self.core.tm.cfg.cells_per_column
|
| 156 |
+
{
|
| 157 |
+
return Err(pyo3::exceptions::PyValueError::new_err(
|
| 158 |
+
"HTM state shape does not match this region",
|
| 159 |
+
));
|
| 160 |
+
}
|
| 161 |
+
self.core = core;
|
| 162 |
+
Ok(())
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
/// Process T timesteps from a `(T, input_bits)` bool ndarray.
|
| 166 |
///
|
| 167 |
/// Returns:
|
overlay/htm_rust/src/region.rs
CHANGED
|
@@ -2,7 +2,9 @@
|
|
| 2 |
|
| 3 |
use crate::sp::{SpatialPooler, SpatialPoolerConfig};
|
| 4 |
use crate::tm::{TemporalMemory, TemporalMemoryConfig};
|
|
|
|
| 5 |
|
|
|
|
| 6 |
pub struct HTMRegionCore {
|
| 7 |
pub sp: SpatialPooler,
|
| 8 |
pub tm: TemporalMemory,
|
|
|
|
| 2 |
|
| 3 |
use crate::sp::{SpatialPooler, SpatialPoolerConfig};
|
| 4 |
use crate::tm::{TemporalMemory, TemporalMemoryConfig};
|
| 5 |
+
use serde::{Deserialize, Serialize};
|
| 6 |
|
| 7 |
+
#[derive(Serialize, Deserialize)]
|
| 8 |
pub struct HTMRegionCore {
|
| 9 |
pub sp: SpatialPooler,
|
| 10 |
pub tm: TemporalMemory,
|
overlay/htm_rust/src/sp.rs
CHANGED
|
@@ -15,10 +15,11 @@ use rand::Rng;
|
|
| 15 |
use rand::SeedableRng;
|
| 16 |
use rand::seq::SliceRandom;
|
| 17 |
use rand_xoshiro::Xoshiro256PlusPlus;
|
|
|
|
| 18 |
|
| 19 |
/// A single proximal dendrite: a sparse set of potential synapses onto
|
| 20 |
/// specific input bit indices, with per-synapse permanence values.
|
| 21 |
-
#[derive(Clone)]
|
| 22 |
pub struct ProximalDendrite {
|
| 23 |
/// Indices into the input SDR. Length == potential_synapses.
|
| 24 |
pub inputs: Vec<u32>,
|
|
@@ -26,6 +27,7 @@ pub struct ProximalDendrite {
|
|
| 26 |
pub perms: Vec<f32>,
|
| 27 |
}
|
| 28 |
|
|
|
|
| 29 |
pub struct SpatialPoolerConfig {
|
| 30 |
pub input_bits: usize,
|
| 31 |
pub n_columns: usize,
|
|
@@ -64,6 +66,7 @@ impl Default for SpatialPoolerConfig {
|
|
| 64 |
}
|
| 65 |
}
|
| 66 |
|
|
|
|
| 67 |
pub struct SpatialPooler {
|
| 68 |
pub cfg: SpatialPoolerConfig,
|
| 69 |
pub columns: Vec<ProximalDendrite>,
|
|
@@ -265,6 +268,7 @@ mod tests {
|
|
| 265 |
use rand::Rng;
|
| 266 |
use rand::SeedableRng;
|
| 267 |
use rand_xoshiro::Xoshiro256PlusPlus;
|
|
|
|
| 268 |
|
| 269 |
#[test]
|
| 270 |
fn sp_sparsity_exact_2pct() {
|
|
|
|
| 15 |
use rand::SeedableRng;
|
| 16 |
use rand::seq::SliceRandom;
|
| 17 |
use rand_xoshiro::Xoshiro256PlusPlus;
|
| 18 |
+
use serde::{Deserialize, Serialize};
|
| 19 |
|
| 20 |
/// A single proximal dendrite: a sparse set of potential synapses onto
|
| 21 |
/// specific input bit indices, with per-synapse permanence values.
|
| 22 |
+
#[derive(Clone, Serialize, Deserialize)]
|
| 23 |
pub struct ProximalDendrite {
|
| 24 |
/// Indices into the input SDR. Length == potential_synapses.
|
| 25 |
pub inputs: Vec<u32>,
|
|
|
|
| 27 |
pub perms: Vec<f32>,
|
| 28 |
}
|
| 29 |
|
| 30 |
+
#[derive(Clone, Serialize, Deserialize)]
|
| 31 |
pub struct SpatialPoolerConfig {
|
| 32 |
pub input_bits: usize,
|
| 33 |
pub n_columns: usize,
|
|
|
|
| 66 |
}
|
| 67 |
}
|
| 68 |
|
| 69 |
+
#[derive(Serialize, Deserialize)]
|
| 70 |
pub struct SpatialPooler {
|
| 71 |
pub cfg: SpatialPoolerConfig,
|
| 72 |
pub columns: Vec<ProximalDendrite>,
|
|
|
|
| 268 |
use rand::Rng;
|
| 269 |
use rand::SeedableRng;
|
| 270 |
use rand_xoshiro::Xoshiro256PlusPlus;
|
| 271 |
+
use serde::{Deserialize, Serialize};
|
| 272 |
|
| 273 |
#[test]
|
| 274 |
fn sp_sparsity_exact_2pct() {
|
overlay/htm_rust/src/tm.rs
CHANGED
|
@@ -45,17 +45,18 @@
|
|
| 45 |
use rand::Rng;
|
| 46 |
use rand::SeedableRng;
|
| 47 |
use rand_xoshiro::Xoshiro256PlusPlus;
|
|
|
|
| 48 |
|
| 49 |
type CellIdx = u32;
|
| 50 |
type SegmentIdx = u32;
|
| 51 |
|
| 52 |
-
#[derive(Clone)]
|
| 53 |
pub struct Synapse {
|
| 54 |
pub presynaptic_cell: CellIdx,
|
| 55 |
pub permanence: f32,
|
| 56 |
}
|
| 57 |
|
| 58 |
-
#[derive(Clone)]
|
| 59 |
pub struct Segment {
|
| 60 |
pub cell: CellIdx,
|
| 61 |
pub synapses: Vec<Synapse>,
|
|
@@ -66,6 +67,7 @@ pub struct Segment {
|
|
| 66 |
pub last_used_iteration: u64,
|
| 67 |
}
|
| 68 |
|
|
|
|
| 69 |
pub struct TemporalMemoryConfig {
|
| 70 |
pub n_columns: usize,
|
| 71 |
pub cells_per_column: usize,
|
|
@@ -100,6 +102,7 @@ impl Default for TemporalMemoryConfig {
|
|
| 100 |
}
|
| 101 |
}
|
| 102 |
|
|
|
|
| 103 |
pub struct TemporalMemory {
|
| 104 |
pub cfg: TemporalMemoryConfig,
|
| 105 |
/// All segments in the region. Indexed by SegmentIdx.
|
|
@@ -485,6 +488,7 @@ mod tests {
|
|
| 485 |
use rand::Rng;
|
| 486 |
use rand::SeedableRng;
|
| 487 |
use rand_xoshiro::Xoshiro256PlusPlus;
|
|
|
|
| 488 |
|
| 489 |
#[test]
|
| 490 |
fn tm_learns_repeating_sequence() {
|
|
|
|
| 45 |
use rand::Rng;
|
| 46 |
use rand::SeedableRng;
|
| 47 |
use rand_xoshiro::Xoshiro256PlusPlus;
|
| 48 |
+
use serde::{Deserialize, Serialize};
|
| 49 |
|
| 50 |
type CellIdx = u32;
|
| 51 |
type SegmentIdx = u32;
|
| 52 |
|
| 53 |
+
#[derive(Clone, Serialize, Deserialize)]
|
| 54 |
pub struct Synapse {
|
| 55 |
pub presynaptic_cell: CellIdx,
|
| 56 |
pub permanence: f32,
|
| 57 |
}
|
| 58 |
|
| 59 |
+
#[derive(Clone, Serialize, Deserialize)]
|
| 60 |
pub struct Segment {
|
| 61 |
pub cell: CellIdx,
|
| 62 |
pub synapses: Vec<Synapse>,
|
|
|
|
| 67 |
pub last_used_iteration: u64,
|
| 68 |
}
|
| 69 |
|
| 70 |
+
#[derive(Clone, Serialize, Deserialize)]
|
| 71 |
pub struct TemporalMemoryConfig {
|
| 72 |
pub n_columns: usize,
|
| 73 |
pub cells_per_column: usize,
|
|
|
|
| 102 |
}
|
| 103 |
}
|
| 104 |
|
| 105 |
+
#[derive(Serialize, Deserialize)]
|
| 106 |
pub struct TemporalMemory {
|
| 107 |
pub cfg: TemporalMemoryConfig,
|
| 108 |
/// All segments in the region. Indexed by SegmentIdx.
|
|
|
|
| 488 |
use rand::Rng;
|
| 489 |
use rand::SeedableRng;
|
| 490 |
use rand_xoshiro::Xoshiro256PlusPlus;
|
| 491 |
+
use serde::{Deserialize, Serialize};
|
| 492 |
|
| 493 |
#[test]
|
| 494 |
fn tm_learns_repeating_sequence() {
|
overlay/htm_rust/uv.lock
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version = 1
|
| 2 |
+
revision = 3
|
| 3 |
+
requires-python = ">=3.11"
|
| 4 |
+
|
| 5 |
+
[[package]]
|
| 6 |
+
name = "htm-rust"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
source = { editable = "." }
|
overlay/hydra/model.py
CHANGED
|
@@ -49,18 +49,51 @@ from subsystems.sdr_semantic import SemanticFoldingSDR
|
|
| 49 |
from hydra.engram import GPUEngram
|
| 50 |
from hydra.htm_cache import htm_cache_key, htm_cache_matches
|
| 51 |
from hydra.hyena_block import HyenaBlock
|
|
|
|
| 52 |
# GDNBlock is imported lazily inside __init__ so the `fla` dependency is
|
| 53 |
# only required when HYDRA_GDN_LAYERS is actually non-empty. Baseline
|
| 54 |
# pure-Mamba3 runs continue to work without flash-linear-attention installed.
|
| 55 |
from hydra.optimizer import MuonAdamW
|
| 56 |
from hydra.sampled_softmax import UnigramSampler, sampled_softmax_loss
|
| 57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
def norm(x: torch.Tensor) -> torch.Tensor:
|
| 60 |
"""RMSNorm over the last dim — stateless, autocast-friendly."""
|
| 61 |
return F.rms_norm(x, (x.size(-1),))
|
| 62 |
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
class PostSemClawModel(nn.Module):
|
| 65 |
"""Full Post-SEM-Claw model assembly.
|
| 66 |
|
|
@@ -131,10 +164,7 @@ class PostSemClawModel(nn.Module):
|
|
| 131 |
n_heads=config.n_heads,
|
| 132 |
)
|
| 133 |
if Mamba3 is None:
|
| 134 |
-
|
| 135 |
-
"mamba_ssm is required for Mamba3 layers; set hyena_layers/gdn_layers "
|
| 136 |
-
"to cover every layer or run inside the HF runtime image."
|
| 137 |
-
)
|
| 138 |
block = Mamba3(
|
| 139 |
d_model=config.d_model,
|
| 140 |
d_state=config.d_state,
|
|
@@ -179,6 +209,22 @@ class PostSemClawModel(nn.Module):
|
|
| 179 |
n_columns=config.engram_n_columns,
|
| 180 |
max_ngram=3,
|
| 181 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
self.engram_layer_idx = config.engram_layer_idx
|
| 183 |
|
| 184 |
# Manifold-Constrained Hyper-Connections (one per Mamba-3 block).
|
|
@@ -398,12 +444,28 @@ class PostSemClawModel(nn.Module):
|
|
| 398 |
|
| 399 |
nn.init.normal_(self.htm_proj.weight, mean=0.0, std=s)
|
| 400 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 401 |
# Cast to bf16 to match Mamba3 dtype; Muon groups by shape so mixed
|
| 402 |
# dtypes in the same shape group would break lerp_ dtype checks.
|
| 403 |
self.wte.to(dtype=torch.bfloat16)
|
| 404 |
self.blocks.to(dtype=torch.bfloat16)
|
| 405 |
self.htm_proj.to(dtype=torch.bfloat16)
|
| 406 |
self.engram.to(dtype=torch.bfloat16)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
|
| 408 |
def set_bos_token_id(self, bos_id: int) -> None:
|
| 409 |
"""Inform the model of the tokenizer's BOS id so doc-separator
|
|
@@ -755,19 +817,25 @@ class PostSemClawModel(nn.Module):
|
|
| 755 |
# HYDRA_HTM_SUBSAMPLE=N (default 8). Set =1 for every-microbatch HTM.
|
| 756 |
_htm_sub = int(os.environ.get("HYDRA_HTM_SUBSAMPLE", "8"))
|
| 757 |
if not hasattr(self, '_htm_call_idx'):
|
| 758 |
-
self._htm_call_idx = 0
|
| 759 |
|
| 760 |
_run_htm = (self._htm_call_idx % _htm_sub == 0)
|
| 761 |
self._htm_call_idx += 1
|
| 762 |
|
| 763 |
if _run_htm:
|
| 764 |
-
htm_handle = self.htm.forward_async(sdr_binary)
|
| 765 |
else:
|
| 766 |
htm_handle = None
|
| 767 |
|
| 768 |
if _profile: _t_htm_async = _ev()
|
| 769 |
|
| 770 |
dense_emb = self.wte(idx) # (B, T, d_model) bf16
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 771 |
|
| 772 |
if _profile: _t_wte = _ev()
|
| 773 |
|
|
@@ -804,10 +872,19 @@ class PostSemClawModel(nn.Module):
|
|
| 804 |
and htm_cache_matches(self._htm_cache_key, sdr_binary.nonzero())
|
| 805 |
):
|
| 806 |
htm_out = self._htm_cache
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 807 |
else:
|
| 808 |
# Very first call with subsample > 1, OR MDLM is on, OR the SDR
|
| 809 |
# pattern has changed from the cached one under exact mode: run HTM.
|
| 810 |
-
htm_handle = self.htm.forward_async(sdr_binary)
|
| 811 |
htm_out = self.htm.forward_await(htm_handle)
|
| 812 |
self._htm_cache = htm_out.detach()
|
| 813 |
self._htm_cache_key = htm_cache_key(sdr_binary.nonzero())
|
|
@@ -880,7 +957,18 @@ class PostSemClawModel(nn.Module):
|
|
| 880 |
# tensor of shape (n_streams, B, T, d_model) — see
|
| 881 |
# subsystems/mhc_mini.ManifoldHyperConnection.
|
| 882 |
x_mid = mhc_layer.merge_streams(streams)
|
| 883 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 884 |
if os.environ.get("HYDRA_ENGRAM_RESET_STREAMS", "0") == "1":
|
| 885 |
streams = mhc_layer.init_streams(x_after_engram)
|
| 886 |
else:
|
|
|
|
| 49 |
from hydra.engram import GPUEngram
|
| 50 |
from hydra.htm_cache import htm_cache_key, htm_cache_matches
|
| 51 |
from hydra.hyena_block import HyenaBlock
|
| 52 |
+
from hydra.reality_bridge import RealityPoincareBridge
|
| 53 |
# GDNBlock is imported lazily inside __init__ so the `fla` dependency is
|
| 54 |
# only required when HYDRA_GDN_LAYERS is actually non-empty. Baseline
|
| 55 |
# pure-Mamba3 runs continue to work without flash-linear-attention installed.
|
| 56 |
from hydra.optimizer import MuonAdamW
|
| 57 |
from hydra.sampled_softmax import UnigramSampler, sampled_softmax_loss
|
| 58 |
|
| 59 |
+
try:
|
| 60 |
+
from subsystems.cantor_router import CantorRouter
|
| 61 |
+
except ModuleNotFoundError:
|
| 62 |
+
from archive.cantor_router import CantorRouter
|
| 63 |
+
|
| 64 |
|
| 65 |
def norm(x: torch.Tensor) -> torch.Tensor:
|
| 66 |
"""RMSNorm over the last dim — stateless, autocast-friendly."""
|
| 67 |
return F.rms_norm(x, (x.size(-1),))
|
| 68 |
|
| 69 |
|
| 70 |
+
def paired_slow_fast_orthogonality(w: torch.Tensor) -> torch.Tensor:
|
| 71 |
+
"""Penalty for aligned adjacent slow/fast vector pairs."""
|
| 72 |
+
n = (w.shape[0] // 2) * 2
|
| 73 |
+
if n == 0:
|
| 74 |
+
return w.new_zeros(())
|
| 75 |
+
slow = F.normalize(w[:n:2].float(), dim=-1, eps=1e-8)
|
| 76 |
+
fast = F.normalize(w[1:n:2].float(), dim=-1, eps=1e-8)
|
| 77 |
+
return (slow * fast).sum(dim=-1).square().mean().to(dtype=w.dtype)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def semantic_gaussian_mollify(
|
| 81 |
+
x: torch.Tensor,
|
| 82 |
+
std: float = 0.0,
|
| 83 |
+
training: bool = True,
|
| 84 |
+
eval_enabled: bool = False,
|
| 85 |
+
) -> torch.Tensor:
|
| 86 |
+
"""Optionally add train-time semantic Gaussian noise; disabled is identity."""
|
| 87 |
+
if std <= 0.0 or (not training and not eval_enabled):
|
| 88 |
+
return x
|
| 89 |
+
return x + torch.randn_like(x) * float(std)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class _LocalMamba3Fallback(nn.Identity):
|
| 93 |
+
"""Shape-preserving local fallback used only when mamba_ssm is absent."""
|
| 94 |
+
pass
|
| 95 |
+
|
| 96 |
+
|
| 97 |
class PostSemClawModel(nn.Module):
|
| 98 |
"""Full Post-SEM-Claw model assembly.
|
| 99 |
|
|
|
|
| 164 |
n_heads=config.n_heads,
|
| 165 |
)
|
| 166 |
if Mamba3 is None:
|
| 167 |
+
return _LocalMamba3Fallback()
|
|
|
|
|
|
|
|
|
|
| 168 |
block = Mamba3(
|
| 169 |
d_model=config.d_model,
|
| 170 |
d_state=config.d_state,
|
|
|
|
| 209 |
n_columns=config.engram_n_columns,
|
| 210 |
max_ngram=3,
|
| 211 |
)
|
| 212 |
+
self.reality_bridge = None
|
| 213 |
+
self.cantor = None
|
| 214 |
+
if os.environ.get("HYDRA_REALITY_BRIDGE", "0") == "1":
|
| 215 |
+
d_reality = int(os.environ.get("HYDRA_REALITY_D", "133"))
|
| 216 |
+
self.reality_bridge = RealityPoincareBridge(
|
| 217 |
+
d_model=config.d_model,
|
| 218 |
+
d_reality=d_reality,
|
| 219 |
+
l0_k=int(os.environ.get("HYDRA_REALITY_L0_K", "64")),
|
| 220 |
+
)
|
| 221 |
+
if os.environ.get("HYDRA_CANTOR_DISABLE", "0") != "1":
|
| 222 |
+
self.cantor = CantorRouter(
|
| 223 |
+
depth=int(os.environ.get("HYDRA_CANTOR_DEPTH", "7")),
|
| 224 |
+
d_query=d_reality,
|
| 225 |
+
seed=int(os.environ.get("HYDRA_CANTOR_SEED", "42")),
|
| 226 |
+
device=self.wte.weight.device,
|
| 227 |
+
)
|
| 228 |
self.engram_layer_idx = config.engram_layer_idx
|
| 229 |
|
| 230 |
# Manifold-Constrained Hyper-Connections (one per Mamba-3 block).
|
|
|
|
| 444 |
|
| 445 |
nn.init.normal_(self.htm_proj.weight, mean=0.0, std=s)
|
| 446 |
|
| 447 |
+
if hasattr(self.engram, "memory"):
|
| 448 |
+
nn.init.normal_(self.engram.memory, mean=0.0, std=0.01)
|
| 449 |
+
if hasattr(self.engram, "gate"):
|
| 450 |
+
nn.init.zeros_(self.engram.gate.weight)
|
| 451 |
+
nn.init.zeros_(self.engram.gate.bias)
|
| 452 |
+
if self.reality_bridge is not None:
|
| 453 |
+
nn.init.normal_(self.reality_bridge.to_reality.weight, mean=0.0, std=0.02)
|
| 454 |
+
nn.init.normal_(self.reality_bridge.to_tangent2.weight, mean=0.0, std=0.02)
|
| 455 |
+
if self.cantor is not None and hasattr(self.cantor, "branch"):
|
| 456 |
+
bound = (3.0 / float(self.cantor.d_query)) ** 0.5
|
| 457 |
+
nn.init.uniform_(self.cantor.branch, -bound, bound)
|
| 458 |
+
|
| 459 |
# Cast to bf16 to match Mamba3 dtype; Muon groups by shape so mixed
|
| 460 |
# dtypes in the same shape group would break lerp_ dtype checks.
|
| 461 |
self.wte.to(dtype=torch.bfloat16)
|
| 462 |
self.blocks.to(dtype=torch.bfloat16)
|
| 463 |
self.htm_proj.to(dtype=torch.bfloat16)
|
| 464 |
self.engram.to(dtype=torch.bfloat16)
|
| 465 |
+
if self.reality_bridge is not None:
|
| 466 |
+
self.reality_bridge.to(dtype=torch.bfloat16)
|
| 467 |
+
if self.cantor is not None:
|
| 468 |
+
self.cantor.to(dtype=torch.bfloat16)
|
| 469 |
|
| 470 |
def set_bos_token_id(self, bos_id: int) -> None:
|
| 471 |
"""Inform the model of the tokenizer's BOS id so doc-separator
|
|
|
|
| 817 |
# HYDRA_HTM_SUBSAMPLE=N (default 8). Set =1 for every-microbatch HTM.
|
| 818 |
_htm_sub = int(os.environ.get("HYDRA_HTM_SUBSAMPLE", "8"))
|
| 819 |
if not hasattr(self, '_htm_call_idx'):
|
| 820 |
+
self._htm_call_idx = int(os.environ.get("HYDRA_HTM_INITIAL_OFFSET", "0"))
|
| 821 |
|
| 822 |
_run_htm = (self._htm_call_idx % _htm_sub == 0)
|
| 823 |
self._htm_call_idx += 1
|
| 824 |
|
| 825 |
if _run_htm:
|
| 826 |
+
htm_handle = self.htm.forward_async(sdr_binary, output_dtype=self.wte.weight.dtype)
|
| 827 |
else:
|
| 828 |
htm_handle = None
|
| 829 |
|
| 830 |
if _profile: _t_htm_async = _ev()
|
| 831 |
|
| 832 |
dense_emb = self.wte(idx) # (B, T, d_model) bf16
|
| 833 |
+
dense_emb = semantic_gaussian_mollify(
|
| 834 |
+
dense_emb,
|
| 835 |
+
std=float(os.environ.get("HYDRA_SEMANTIC_SMOOTH_STD", "0.0")),
|
| 836 |
+
training=self.training,
|
| 837 |
+
eval_enabled=os.environ.get("HYDRA_SEMANTIC_SMOOTH_EVAL", "0") == "1",
|
| 838 |
+
)
|
| 839 |
|
| 840 |
if _profile: _t_wte = _ev()
|
| 841 |
|
|
|
|
| 872 |
and htm_cache_matches(self._htm_cache_key, sdr_binary.nonzero())
|
| 873 |
):
|
| 874 |
htm_out = self._htm_cache
|
| 875 |
+
elif (
|
| 876 |
+
os.environ.get("HYDRA_HTM_ZERO_CACHE_ON_MISS", "0") == "1"
|
| 877 |
+
and self.training
|
| 878 |
+
and not self._mdlm_active
|
| 879 |
+
):
|
| 880 |
+
htm_out = torch.zeros((B, T, self.config.htm_n_columns + 1), device=dense_emb.device, dtype=dense_emb.dtype)
|
| 881 |
+
self._htm_cache = htm_out.detach()
|
| 882 |
+
self._htm_cache_key = None
|
| 883 |
+
self._htm_cache_shape = (B, T)
|
| 884 |
else:
|
| 885 |
# Very first call with subsample > 1, OR MDLM is on, OR the SDR
|
| 886 |
# pattern has changed from the cached one under exact mode: run HTM.
|
| 887 |
+
htm_handle = self.htm.forward_async(sdr_binary, output_dtype=self.wte.weight.dtype)
|
| 888 |
htm_out = self.htm.forward_await(htm_handle)
|
| 889 |
self._htm_cache = htm_out.detach()
|
| 890 |
self._htm_cache_key = htm_cache_key(sdr_binary.nonzero())
|
|
|
|
| 957 |
# tensor of shape (n_streams, B, T, d_model) — see
|
| 958 |
# subsystems/mhc_mini.ManifoldHyperConnection.
|
| 959 |
x_mid = mhc_layer.merge_streams(streams)
|
| 960 |
+
if self.reality_bridge is not None and self.cantor is not None:
|
| 961 |
+
rb = self.reality_bridge(x_mid)
|
| 962 |
+
cantor_leaf_ids, _ = self.cantor(rb.reality, return_scores=False)
|
| 963 |
+
x_after_engram, hit_rate = self.engram(
|
| 964 |
+
x_mid,
|
| 965 |
+
idx,
|
| 966 |
+
sdr_active_indices=rb.l0_indices,
|
| 967 |
+
cantor_leaf_ids=cantor_leaf_ids,
|
| 968 |
+
cantor_n_leaves=self.cantor.n_leaves,
|
| 969 |
+
)
|
| 970 |
+
else:
|
| 971 |
+
x_after_engram, hit_rate = self.engram(x_mid, idx)
|
| 972 |
if os.environ.get("HYDRA_ENGRAM_RESET_STREAMS", "0") == "1":
|
| 973 |
streams = mhc_layer.init_streams(x_after_engram)
|
| 974 |
else:
|
overlay/hydra/optimizer.py
CHANGED
|
@@ -144,62 +144,117 @@ class MuonAdamW(torch.optim.Optimizer):
|
|
| 144 |
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 145 |
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 146 |
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
def _step_adamw(self, group):
|
| 149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
for p in group['params']:
|
| 151 |
if p.grad is None:
|
| 152 |
continue
|
| 153 |
-
state = self.
|
| 154 |
-
if not state:
|
| 155 |
-
state['step'] = 0
|
| 156 |
-
state['exp_avg'] = torch.zeros_like(p)
|
| 157 |
-
state['exp_avg_sq'] = torch.zeros_like(p)
|
| 158 |
-
if 'step_t' not in state:
|
| 159 |
-
# _fused_adamw_ wants a per-param float step tensor on-device.
|
| 160 |
-
state['step_t'] = torch.tensor(
|
| 161 |
-
float(state['step']), dtype=torch.float32, device=p.device
|
| 162 |
-
)
|
| 163 |
state['step'] += 1
|
|
|
|
|
|
|
| 164 |
params.append(p)
|
| 165 |
grads.append(p.grad.to(p.dtype) if p.grad.dtype != p.dtype else p.grad)
|
| 166 |
exp_avgs.append(state['exp_avg'])
|
| 167 |
exp_avg_sqs.append(state['exp_avg_sq'])
|
| 168 |
-
state_steps.append(state['step_t'])
|
| 169 |
|
| 170 |
if not params:
|
| 171 |
return
|
| 172 |
|
| 173 |
-
if _HYDRA_FUSED_ADAMW and _HAS_FUSED_ADAMW and params[0].is_cuda:
|
| 174 |
-
# _fused_adamw_ needs uniform (device, dtype) within a call, so
|
| 175 |
-
# group by (device, dtype) — same pattern as PyTorch's own
|
| 176 |
-
# AdamW(fused=True) path (_group_tensors_by_device_and_dtype).
|
| 177 |
-
buckets = {}
|
| 178 |
-
for p, g, ea, es, st in zip(params, grads, exp_avgs, exp_avg_sqs, state_steps):
|
| 179 |
-
key = (p.device, p.dtype)
|
| 180 |
-
buckets.setdefault(key, ([], [], [], [], []))
|
| 181 |
-
b_p, b_g, b_ea, b_es, b_st = buckets[key]
|
| 182 |
-
b_p.append(p); b_g.append(g); b_ea.append(ea); b_es.append(es); b_st.append(st)
|
| 183 |
-
|
| 184 |
-
lr_f = float(group['lr'])
|
| 185 |
-
b1_f = float(group['betas'][0])
|
| 186 |
-
b2_f = float(group['betas'][1])
|
| 187 |
-
wd_f = float(group['weight_decay'])
|
| 188 |
-
eps_f = float(group['eps'])
|
| 189 |
-
for (_dev, _dt), (b_p, b_g, b_ea, b_es, b_st) in buckets.items():
|
| 190 |
-
torch._foreach_add_(b_st, 1.0)
|
| 191 |
-
torch._fused_adamw_(
|
| 192 |
-
b_p, b_g, b_ea, b_es,
|
| 193 |
-
[], # max_exp_avg_sqs unused (amsgrad=False)
|
| 194 |
-
b_st,
|
| 195 |
-
amsgrad=False,
|
| 196 |
-
lr=lr_f, beta1=b1_f, beta2=b2_f,
|
| 197 |
-
weight_decay=wd_f, eps=eps_f,
|
| 198 |
-
maximize=False,
|
| 199 |
-
grad_scale=None, found_inf=None,
|
| 200 |
-
)
|
| 201 |
-
return
|
| 202 |
-
|
| 203 |
# Fallback per-param path.
|
| 204 |
self._adamw_lr_t.fill_(group['lr'])
|
| 205 |
self._adamw_beta1_t.fill_(group['betas'][0])
|
|
@@ -213,15 +268,34 @@ class MuonAdamW(torch.optim.Optimizer):
|
|
| 213 |
self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t)
|
| 214 |
|
| 215 |
def _step_muon(self, group):
|
| 216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
if not params:
|
| 218 |
return
|
| 219 |
p = params[0]
|
| 220 |
state = self.state[p]
|
| 221 |
num_params = len(params)
|
| 222 |
shape, device, dtype = p.shape, p.device, p.dtype
|
| 223 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
|
|
|
|
| 225 |
red_dim = -1 if shape[-2] >= shape[-1] else -2
|
| 226 |
if "second_momentum_buffer" not in state:
|
| 227 |
# Shape must match v_mean = stacked_grads.square().mean(dim=red_dim, keepdim=True)
|
|
|
|
| 144 |
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 145 |
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 146 |
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 147 |
+
self._adamw_bucket_caches = {}
|
| 148 |
+
self._muon_params_caches = {}
|
| 149 |
+
|
| 150 |
+
def state_dict(self):
|
| 151 |
+
sd = super().state_dict()
|
| 152 |
+
# Transient fused-step caches and device step_t tensors must not enter
|
| 153 |
+
# checkpoints. step_t is recreated from scalar state['step'] lazily.
|
| 154 |
+
for st in sd.get("state", {}).values():
|
| 155 |
+
st.pop("step_t", None)
|
| 156 |
+
for group in sd.get("param_groups", []):
|
| 157 |
+
group.pop("_adamw_bucket_cache", None)
|
| 158 |
+
group.pop("_muon_params_cache", None)
|
| 159 |
+
return sd
|
| 160 |
+
|
| 161 |
+
def load_state_dict(self, state_dict):
|
| 162 |
+
for st in state_dict.get("state", {}).values():
|
| 163 |
+
st.pop("step_t", None)
|
| 164 |
+
for group in state_dict.get("param_groups", []):
|
| 165 |
+
group.pop("_adamw_bucket_cache", None)
|
| 166 |
+
group.pop("_muon_params_cache", None)
|
| 167 |
+
self._adamw_bucket_caches.clear()
|
| 168 |
+
self._muon_params_caches.clear()
|
| 169 |
+
return super().load_state_dict(state_dict)
|
| 170 |
+
|
| 171 |
+
def _ensure_adamw_state(self, p):
|
| 172 |
+
state = self.state[p]
|
| 173 |
+
if not state:
|
| 174 |
+
state['step'] = 0
|
| 175 |
+
state['exp_avg'] = torch.zeros_like(p)
|
| 176 |
+
state['exp_avg_sq'] = torch.zeros_like(p)
|
| 177 |
+
if 'step_t' not in state:
|
| 178 |
+
# _fused_adamw_ wants a per-param float step tensor on-device.
|
| 179 |
+
state['step_t'] = torch.tensor(
|
| 180 |
+
float(state['step']), dtype=torch.float32, device=p.device
|
| 181 |
+
)
|
| 182 |
+
return state
|
| 183 |
+
|
| 184 |
+
def _adamw_cached_buckets(self, group):
|
| 185 |
+
"""Return stable (device,dtype) param buckets for fused AdamW.
|
| 186 |
+
|
| 187 |
+
Cache topology only. Optimizer state remains lazy for grad-bearing
|
| 188 |
+
params so unused/frozen tensors do not bloat checkpoints.
|
| 189 |
+
"""
|
| 190 |
+
params_tuple = tuple(group['params'])
|
| 191 |
+
cache = self._adamw_bucket_caches.get(id(group))
|
| 192 |
+
if cache is not None and cache.get('params_tuple') == params_tuple:
|
| 193 |
+
return cache['buckets']
|
| 194 |
+
|
| 195 |
+
buckets = {}
|
| 196 |
+
for p in params_tuple:
|
| 197 |
+
key = (p.device, p.dtype)
|
| 198 |
+
buckets.setdefault(key, {'params': []})
|
| 199 |
+
buckets[key]['params'].append(p)
|
| 200 |
+
self._adamw_bucket_caches[id(group)] = {'params_tuple': params_tuple, 'buckets': buckets}
|
| 201 |
+
return buckets
|
| 202 |
|
| 203 |
def _step_adamw(self, group):
|
| 204 |
+
if _HYDRA_FUSED_ADAMW and _HAS_FUSED_ADAMW:
|
| 205 |
+
# Mixed CPU/CUDA groups are unusual in Feather but skipping CPU
|
| 206 |
+
# grads would be a correctness bug; disable fused path in that case.
|
| 207 |
+
if not any(p.grad is not None and not p.is_cuda for p in group['params']):
|
| 208 |
+
buckets = self._adamw_cached_buckets(group)
|
| 209 |
+
lr_f = float(group['lr'])
|
| 210 |
+
b1_f = float(group['betas'][0])
|
| 211 |
+
b2_f = float(group['betas'][1])
|
| 212 |
+
wd_f = float(group['weight_decay'])
|
| 213 |
+
eps_f = float(group['eps'])
|
| 214 |
+
launched = False
|
| 215 |
+
for (_dev, _dt), bucket in buckets.items():
|
| 216 |
+
b_p = [p for p in bucket['params'] if p.grad is not None]
|
| 217 |
+
if not b_p or not b_p[0].is_cuda:
|
| 218 |
+
continue
|
| 219 |
+
b_g = [p.grad.to(p.dtype) if p.grad.dtype != p.dtype else p.grad for p in b_p]
|
| 220 |
+
b_ea, b_es, b_st = [], [], []
|
| 221 |
+
for p in b_p:
|
| 222 |
+
state = self._ensure_adamw_state(p)
|
| 223 |
+
state['step'] += 1
|
| 224 |
+
b_ea.append(state['exp_avg'])
|
| 225 |
+
b_es.append(state['exp_avg_sq'])
|
| 226 |
+
b_st.append(state['step_t'])
|
| 227 |
+
torch._foreach_add_(b_st, 1.0)
|
| 228 |
+
torch._fused_adamw_(
|
| 229 |
+
b_p, b_g, b_ea, b_es,
|
| 230 |
+
[], # max_exp_avg_sqs unused (amsgrad=False)
|
| 231 |
+
b_st,
|
| 232 |
+
amsgrad=False,
|
| 233 |
+
lr=lr_f, beta1=b1_f, beta2=b2_f,
|
| 234 |
+
weight_decay=wd_f, eps=eps_f,
|
| 235 |
+
maximize=False,
|
| 236 |
+
grad_scale=None, found_inf=None,
|
| 237 |
+
)
|
| 238 |
+
launched = True
|
| 239 |
+
if launched:
|
| 240 |
+
return
|
| 241 |
+
|
| 242 |
+
params, grads, exp_avgs, exp_avg_sqs = [], [], [], []
|
| 243 |
for p in group['params']:
|
| 244 |
if p.grad is None:
|
| 245 |
continue
|
| 246 |
+
state = self._ensure_adamw_state(p)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
state['step'] += 1
|
| 248 |
+
if 'step_t' in state:
|
| 249 |
+
state['step_t'].fill_(float(state['step']))
|
| 250 |
params.append(p)
|
| 251 |
grads.append(p.grad.to(p.dtype) if p.grad.dtype != p.dtype else p.grad)
|
| 252 |
exp_avgs.append(state['exp_avg'])
|
| 253 |
exp_avg_sqs.append(state['exp_avg_sq'])
|
|
|
|
| 254 |
|
| 255 |
if not params:
|
| 256 |
return
|
| 257 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
# Fallback per-param path.
|
| 259 |
self._adamw_lr_t.fill_(group['lr'])
|
| 260 |
self._adamw_beta1_t.fill_(group['betas'][0])
|
|
|
|
| 268 |
self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t)
|
| 269 |
|
| 270 |
def _step_muon(self, group):
|
| 271 |
+
params_tuple = tuple(group['params'])
|
| 272 |
+
cache = self._muon_params_caches.get(id(group))
|
| 273 |
+
if cache is None or cache.get('params_tuple') != params_tuple:
|
| 274 |
+
cache = {'params_tuple': params_tuple, 'params': list(params_tuple)}
|
| 275 |
+
self._muon_params_caches[id(group)] = cache
|
| 276 |
+
params_all = cache['params']
|
| 277 |
+
# Common Feather path: all Muon matrix params receive grads every step.
|
| 278 |
+
# Preserve sparse/None-grad correctness by filtering only when needed.
|
| 279 |
+
if all(p.grad is not None for p in params_all):
|
| 280 |
+
params = params_all
|
| 281 |
+
else:
|
| 282 |
+
params = [p for p in params_all if p.grad is not None]
|
| 283 |
if not params:
|
| 284 |
return
|
| 285 |
p = params[0]
|
| 286 |
state = self.state[p]
|
| 287 |
num_params = len(params)
|
| 288 |
shape, device, dtype = p.shape, p.device, p.dtype
|
| 289 |
+
if (
|
| 290 |
+
"momentum_buffer" not in state
|
| 291 |
+
or state["momentum_buffer"].shape[0] != num_params
|
| 292 |
+
or tuple(state["momentum_buffer"].shape[1:]) != tuple(shape)
|
| 293 |
+
):
|
| 294 |
+
# If grad-bearing Muon params change (rare; usually all matrix params
|
| 295 |
+
# have grads), resize instead of crashing compiled Muon on a stale
|
| 296 |
+
# leading dimension. This preserves skip-None-grad semantics.
|
| 297 |
state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
|
| 298 |
+
state.pop("second_momentum_buffer", None)
|
| 299 |
red_dim = -1 if shape[-2] >= shape[-1] else -2
|
| 300 |
if "second_momentum_buffer" not in state:
|
| 301 |
# Shape must match v_mean = stacked_grads.square().mean(dim=red_dim, keepdim=True)
|
overlay/hydra/training.py
CHANGED
|
@@ -9,7 +9,7 @@ import os
|
|
| 9 |
import sys
|
| 10 |
import threading
|
| 11 |
import time
|
| 12 |
-
from dataclasses import asdict
|
| 13 |
from pathlib import Path
|
| 14 |
|
| 15 |
import torch
|
|
@@ -103,6 +103,22 @@ _CONTRASTIVE_CTX_LEN = int(os.environ.get("HYDRA_CONTRASTIVE_CTX_LEN", "8"))
|
|
| 103 |
_CONTRASTIVE_N_PAIRS = int(os.environ.get("HYDRA_CONTRASTIVE_N_PAIRS", "256"))
|
| 104 |
|
| 105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
# ---------------------------------------------------------------------------
|
| 107 |
# Schedules
|
| 108 |
# ---------------------------------------------------------------------------
|
|
@@ -136,6 +152,7 @@ def save_ckpt(
|
|
| 136 |
*,
|
| 137 |
val_bpb: float | None = None,
|
| 138 |
) -> None:
|
|
|
|
| 139 |
try:
|
| 140 |
CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
| 141 |
payload = {
|
|
@@ -289,7 +306,22 @@ def maybe_resume_ckpt(
|
|
| 289 |
def main() -> None:
|
| 290 |
t_start = time.time()
|
| 291 |
torch.manual_seed(SEED)
|
| 292 |
-
torch.cuda.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
# Precision / kernel-selection knobs for peak throughput on Ampere.
|
| 294 |
# - high : matmul uses TF32 (Ampere's 10-bit mantissa accum) for fp32 ops
|
| 295 |
# - allow_tf32 : explicit for both matmul + cudnn paths
|
|
@@ -299,12 +331,6 @@ def main() -> None:
|
|
| 299 |
# over the first ~100 steps. Observed 2026-04-22 and confirmed by
|
| 300 |
# differential profiling. Default is now FALSE; set =1 only if you
|
| 301 |
# see a specific workload where benchmark helps sustained tps.
|
| 302 |
-
torch.set_float32_matmul_precision("high")
|
| 303 |
-
torch.backends.cuda.matmul.allow_tf32 = True
|
| 304 |
-
torch.backends.cudnn.allow_tf32 = True
|
| 305 |
-
torch.backends.cudnn.benchmark = os.environ.get("HYDRA_CUDNN_BENCHMARK", "0") == "1"
|
| 306 |
-
device = torch.device("cuda")
|
| 307 |
-
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
| 308 |
|
| 309 |
# Streaming path skips prepare.py (which normally trains the tokenizer
|
| 310 |
# and builds the retina), so we must materialize both before model init.
|
|
@@ -435,7 +461,7 @@ def main() -> None:
|
|
| 435 |
)
|
| 436 |
_train_phase("dataloader_prefetch_start")
|
| 437 |
train_loader = make_dataloader(tokenizer, DEVICE_BATCH_SIZE, _current_seq_len, "train")
|
| 438 |
-
if step > 0 and os.environ.get("HYDRA_RESUME_SKIP_DATALOADER", "1") =
|
| 439 |
_skip_micro_batches = step * grad_accum_steps
|
| 440 |
print(f"[resume] fast-forwarding train stream micro_batches={_skip_micro_batches} step={step} grad_accum={grad_accum_steps}", flush=True)
|
| 441 |
for _skip_i in range(_skip_micro_batches):
|
|
@@ -469,13 +495,11 @@ def main() -> None:
|
|
| 469 |
_ASYNC_POSTPROCESS = os.environ.get("HYDRA_ASYNC_POSTPROCESS", "1") == "1"
|
| 470 |
_som_thread: threading.Thread | None = None
|
| 471 |
_hestia_thread: threading.Thread | None = None
|
| 472 |
-
_hestia_stream
|
| 473 |
-
torch.cuda.Stream() if _ASYNC_POSTPROCESS else None
|
| 474 |
-
)
|
| 475 |
|
| 476 |
# Hebbian retina mode — per-step on-GPU update, mutually exclusive with SOM.
|
| 477 |
# Activated by env HYDRA_HEBBIAN_RETINA=1 (default off).
|
| 478 |
-
_HEBBIAN_RETINA = os.environ.get("HYDRA_HEBBIAN_RETINA", "0") == "1"
|
| 479 |
_HEBBIAN_ALPHA = float(os.environ.get("HYDRA_HEBBIAN_ALPHA", "0.001"))
|
| 480 |
_prof = os.environ.get("HYDRA_PROFILE_FORWARD", "0") == "1"
|
| 481 |
if _HEBBIAN_RETINA:
|
|
@@ -514,6 +538,32 @@ def main() -> None:
|
|
| 514 |
# default cadence) instead of every step.
|
| 515 |
nan_flag = torch.zeros((), device=device, dtype=torch.bool)
|
| 516 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 517 |
_first_step_marker_emitted = False
|
| 518 |
while True:
|
| 519 |
if not _first_step_marker_emitted:
|
|
@@ -608,18 +658,9 @@ def main() -> None:
|
|
| 608 |
|
| 609 |
# A10G Hyena fallback can produce finite forward loss but non-finite
|
| 610 |
# gradients through the guarded residual path on the next optimizer
|
| 611 |
-
# step.
|
| 612 |
-
#
|
| 613 |
-
|
| 614 |
-
# Fast GPU-native grad guard
|
| 615 |
-
if os.environ.get("HYDRA_GRAD_FINITE_GUARD", "1") == "1":
|
| 616 |
-
with torch.no_grad():
|
| 617 |
-
for p in model.parameters():
|
| 618 |
-
if p.grad is not None:
|
| 619 |
-
p.grad.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
|
| 620 |
-
|
| 621 |
-
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
| 622 |
-
optimizer.step()
|
| 623 |
if _prof:
|
| 624 |
torch.cuda.synchronize(); _t_opt = time.time()
|
| 625 |
|
|
|
|
| 9 |
import sys
|
| 10 |
import threading
|
| 11 |
import time
|
| 12 |
+
from dataclasses import asdict, fields
|
| 13 |
from pathlib import Path
|
| 14 |
|
| 15 |
import torch
|
|
|
|
| 103 |
_CONTRASTIVE_N_PAIRS = int(os.environ.get("HYDRA_CONTRASTIVE_N_PAIRS", "256"))
|
| 104 |
|
| 105 |
|
| 106 |
+
def config_from_dict(payload: dict) -> PostSemClawConfig:
|
| 107 |
+
"""Rebuild PostSemClawConfig from a checkpoint payload dict.
|
| 108 |
+
|
| 109 |
+
Checkpoints can contain older configs without newer dataclass fields, or
|
| 110 |
+
future configs with unknown fields. Keep loading permissive, but normalize
|
| 111 |
+
tuple-backed topology fields so Hyena/GDN layer selections survive JSON or
|
| 112 |
+
pickle paths that turn tuples into lists.
|
| 113 |
+
"""
|
| 114 |
+
field_names = {field.name for field in fields(PostSemClawConfig)}
|
| 115 |
+
kwargs = {key: value for key, value in payload.items() if key in field_names}
|
| 116 |
+
for tuple_key in ("hyena_layers", "gdn_layers"):
|
| 117 |
+
if tuple_key in kwargs and kwargs[tuple_key] is not None:
|
| 118 |
+
kwargs[tuple_key] = tuple(kwargs[tuple_key])
|
| 119 |
+
return PostSemClawConfig(**kwargs)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
# ---------------------------------------------------------------------------
|
| 123 |
# Schedules
|
| 124 |
# ---------------------------------------------------------------------------
|
|
|
|
| 152 |
*,
|
| 153 |
val_bpb: float | None = None,
|
| 154 |
) -> None:
|
| 155 |
+
global _CKPT_WORKER_THREAD
|
| 156 |
try:
|
| 157 |
CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
| 158 |
payload = {
|
|
|
|
| 306 |
def main() -> None:
|
| 307 |
t_start = time.time()
|
| 308 |
torch.manual_seed(SEED)
|
| 309 |
+
device_str = "cuda" if torch.cuda.is_available() else "cpu"
|
| 310 |
+
device = torch.device(device_str)
|
| 311 |
+
if device_str == "cuda":
|
| 312 |
+
torch.cuda.manual_seed(SEED)
|
| 313 |
+
torch.set_float32_matmul_precision("high")
|
| 314 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 315 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 316 |
+
torch.backends.cudnn.benchmark = os.environ.get("HYDRA_CUDNN_BENCHMARK", "0") == "1"
|
| 317 |
+
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
| 318 |
+
else:
|
| 319 |
+
# CPU path: limit BLAS threads to avoid oversubscription with data workers.
|
| 320 |
+
_cpu_threads = int(os.environ.get("HYDRA_CPU_THREADS", str(min(os.cpu_count() or 4, 8))))
|
| 321 |
+
torch.set_num_threads(_cpu_threads)
|
| 322 |
+
print(f"[CPU] torch.set_num_threads={_cpu_threads}")
|
| 323 |
+
autocast_ctx = torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=False)
|
| 324 |
+
|
| 325 |
# Precision / kernel-selection knobs for peak throughput on Ampere.
|
| 326 |
# - high : matmul uses TF32 (Ampere's 10-bit mantissa accum) for fp32 ops
|
| 327 |
# - allow_tf32 : explicit for both matmul + cudnn paths
|
|
|
|
| 331 |
# over the first ~100 steps. Observed 2026-04-22 and confirmed by
|
| 332 |
# differential profiling. Default is now FALSE; set =1 only if you
|
| 333 |
# see a specific workload where benchmark helps sustained tps.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
|
| 335 |
# Streaming path skips prepare.py (which normally trains the tokenizer
|
| 336 |
# and builds the retina), so we must materialize both before model init.
|
|
|
|
| 461 |
)
|
| 462 |
_train_phase("dataloader_prefetch_start")
|
| 463 |
train_loader = make_dataloader(tokenizer, DEVICE_BATCH_SIZE, _current_seq_len, "train")
|
| 464 |
+
if step > 0 and os.environ.get("HYDRA_RESUME_SKIP_DATALOADER", "1") != "1":
|
| 465 |
_skip_micro_batches = step * grad_accum_steps
|
| 466 |
print(f"[resume] fast-forwarding train stream micro_batches={_skip_micro_batches} step={step} grad_accum={grad_accum_steps}", flush=True)
|
| 467 |
for _skip_i in range(_skip_micro_batches):
|
|
|
|
| 495 |
_ASYNC_POSTPROCESS = os.environ.get("HYDRA_ASYNC_POSTPROCESS", "1") == "1"
|
| 496 |
_som_thread: threading.Thread | None = None
|
| 497 |
_hestia_thread: threading.Thread | None = None
|
| 498 |
+
_hestia_stream = torch.cuda.Stream() if (_ASYNC_POSTPROCESS and device.type == "cuda") else None
|
|
|
|
|
|
|
| 499 |
|
| 500 |
# Hebbian retina mode — per-step on-GPU update, mutually exclusive with SOM.
|
| 501 |
# Activated by env HYDRA_HEBBIAN_RETINA=1 (default off).
|
| 502 |
+
_HEBBIAN_RETINA = device.type == "cuda" and os.environ.get("HYDRA_HEBBIAN_RETINA", "0") == "1"
|
| 503 |
_HEBBIAN_ALPHA = float(os.environ.get("HYDRA_HEBBIAN_ALPHA", "0.001"))
|
| 504 |
_prof = os.environ.get("HYDRA_PROFILE_FORWARD", "0") == "1"
|
| 505 |
if _HEBBIAN_RETINA:
|
|
|
|
| 538 |
# default cadence) instead of every step.
|
| 539 |
nan_flag = torch.zeros((), device=device, dtype=torch.bool)
|
| 540 |
|
| 541 |
+
# Device-step fusion surface: cache the parameter walk once and keep the
|
| 542 |
+
# finite-grad guard + clipping + optimizer launch in one compact boundary.
|
| 543 |
+
# This avoids re-materializing model.parameters() twice per optimizer step
|
| 544 |
+
# and gives the A10G path a single toggleable fused-step block without
|
| 545 |
+
# pulling dataloader/checkpoint/logging CPU control flow into Dynamo.
|
| 546 |
+
_HYDRA_FUSED_DEVICE_STEP = os.environ.get("HYDRA_FUSED_DEVICE_STEP", "1") == "1"
|
| 547 |
+
_trainable_params = tuple(model.parameters())
|
| 548 |
+
|
| 549 |
+
def _finish_device_step():
|
| 550 |
+
if _HYDRA_FUSED_DEVICE_STEP:
|
| 551 |
+
if os.environ.get("HYDRA_GRAD_FINITE_GUARD", "1") == "1":
|
| 552 |
+
with torch.no_grad():
|
| 553 |
+
for _p in _trainable_params:
|
| 554 |
+
if _p.grad is not None:
|
| 555 |
+
_p.grad.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
|
| 556 |
+
torch.nn.utils.clip_grad_norm_(_trainable_params, max_norm=1.0)
|
| 557 |
+
optimizer.step()
|
| 558 |
+
return
|
| 559 |
+
if os.environ.get("HYDRA_GRAD_FINITE_GUARD", "1") == "1":
|
| 560 |
+
with torch.no_grad():
|
| 561 |
+
for _p in model.parameters():
|
| 562 |
+
if _p.grad is not None:
|
| 563 |
+
_p.grad.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
|
| 564 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
| 565 |
+
optimizer.step()
|
| 566 |
+
|
| 567 |
_first_step_marker_emitted = False
|
| 568 |
while True:
|
| 569 |
if not _first_step_marker_emitted:
|
|
|
|
| 658 |
|
| 659 |
# A10G Hyena fallback can produce finite forward loss but non-finite
|
| 660 |
# gradients through the guarded residual path on the next optimizer
|
| 661 |
+
# step. The fused device-step boundary scrubs, clips, and launches the
|
| 662 |
+
# optimizer without re-walking model.parameters() on every substage.
|
| 663 |
+
_finish_device_step()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 664 |
if _prof:
|
| 665 |
torch.cuda.synchronize(); _t_opt = time.time()
|
| 666 |
|
overlay/kernels/__init__.py
ADDED
|
File without changes
|
overlay/kernels/cuda/decode_kernels.cu
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* CuTe DSL decode kernels for Mamba-3 autoregressive generation.
|
| 3 |
+
*
|
| 4 |
+
* Phase 2: Optimized single-token SSM step for inference.
|
| 5 |
+
* Phase 1: Not needed (training only, no generation).
|
| 6 |
+
*
|
| 7 |
+
* Fuses: input_proj + conv_step + ssm_step + output_proj
|
| 8 |
+
* into a single kernel launch for minimal latency.
|
| 9 |
+
*/
|
| 10 |
+
// Stub: Phase 2 implementation
|
overlay/kernels/cuda/flashfftconv/LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
overlay/kernels/cuda/flashfftconv/README.md
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flashfftconv (vendored)
|
| 2 |
+
|
| 3 |
+
Vendored from https://github.com/HazyResearch/flash-fft-conv (Apache 2.0 license).
|
| 4 |
+
|
| 5 |
+
**Upstream commit:** see `UPSTREAM_COMMIT`.
|
| 6 |
+
|
| 7 |
+
## What this is
|
| 8 |
+
|
| 9 |
+
HazyResearch's Monarch-matrix-decomposition FFT convolution CUDA kernel. Provides a
|
| 10 |
+
drop-in replacement for `torch.fft.rfft + complex-mult + irfft` that runs ~2-3x
|
| 11 |
+
faster than cuFFT for the specific power-of-two lengths it supports (256, 512,
|
| 12 |
+
1024, 2048, 4096, 8192, ..., up to 4M).
|
| 13 |
+
|
| 14 |
+
In HYDRA, we use it to accelerate `subsystems/hyena_pure.fftconv_ref`. The
|
| 15 |
+
accelerated path is opt-in via `HYDRA_HYENA_FLASH_FFT=1`; default behavior is
|
| 16 |
+
unchanged (pure PyTorch fallback).
|
| 17 |
+
|
| 18 |
+
## How to build
|
| 19 |
+
|
| 20 |
+
The vendored tree contains:
|
| 21 |
+
- `flashfftconv/` — pure-Python wrappers (imports `monarch_cuda` CUDA extension)
|
| 22 |
+
- `csrc/` — CUDA source files and setup.py for the native extension
|
| 23 |
+
|
| 24 |
+
Build instructions:
|
| 25 |
+
|
| 26 |
+
```bash
|
| 27 |
+
cd /home/mikeb/work/feather/kernels/cuda/flashfftconv/csrc
|
| 28 |
+
|
| 29 |
+
# Edit `csrc/setup.py` first: change the cc_flag line to match your GPU arch
|
| 30 |
+
# (RTX 3060 = 8.6, A100 = 8.0, H100 = 9.0). Example for RTX 3060:
|
| 31 |
+
# cc_flag = ['--generate-code=arch=compute_86,code=compute_86']
|
| 32 |
+
|
| 33 |
+
# Build with the local CUDA toolchain (must match your torch.version.cuda):
|
| 34 |
+
CUDA_HOME=/usr/local/cuda-12.1 .venv/bin/pip install -e .
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
Then install the Python wrappers:
|
| 38 |
+
|
| 39 |
+
```bash
|
| 40 |
+
cd /home/mikeb/work/feather/kernels/cuda/flashfftconv
|
| 41 |
+
.venv/bin/pip install -e .
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
## Runtime usage
|
| 45 |
+
|
| 46 |
+
Once installed, set `HYDRA_HYENA_FLASH_FFT=1` and training will use it.
|
| 47 |
+
`subsystems/hyena_pure.fftconv_ref` auto-detects via `try: import flashfftconv`
|
| 48 |
+
and falls back to pure PyTorch on import failure.
|
| 49 |
+
|
| 50 |
+
## Known caveats
|
| 51 |
+
|
| 52 |
+
- Seqlen must be a power of 2 AND in the supported set: {256, 512, 1024, 2048,
|
| 53 |
+
4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304}.
|
| 54 |
+
For HYDRA, `fft_size = 2 * seq_len` → seq_len in {128, 256, 512, 1024, 2048, ...}.
|
| 55 |
+
- dtype must be fp16 or bf16 (fp32 not supported).
|
| 56 |
+
- GPU arch must be compiled into the extension (see setup.py cc_flag).
|
| 57 |
+
- CUDA toolchain major.minor should match `torch.version.cuda` major (12.x ↔ 12.x).
|
overlay/kernels/cuda/flashfftconv/UPSTREAM_COMMIT
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
b8771028717f46d5b22cbb8e12833f35033d621b
|
overlay/kernels/cuda/flashfftconv/csrc/.gitignore
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.npy
|
| 2 |
+
*.json
|
| 3 |
+
*.png
|
| 4 |
+
|
| 5 |
+
*/*.npy
|
| 6 |
+
*/*.json
|
| 7 |
+
*/*.png
|
| 8 |
+
|
| 9 |
+
*.DS_Store
|
| 10 |
+
*/*.DS_Store
|
overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly.h
ADDED
|
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) 2023 Dan Fu, Hermann Kumbong
|
| 2 |
+
|
| 3 |
+
#include <torch/extension.h>
|
| 4 |
+
|
| 5 |
+
#include <vector>
|
| 6 |
+
|
| 7 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
| 8 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
| 9 |
+
#define CHECK_IS_HALF_OR_BFLOAT(x) TORCH_CHECK(x.dtype() == torch::kFloat16 || x.dtype() == torch::kBFloat16, #x " must be float16 or bfloat16")
|
| 10 |
+
#define CHECK_INPUT(x) \
|
| 11 |
+
CHECK_CUDA(x); \
|
| 12 |
+
CHECK_CONTIGUOUS(x); \
|
| 13 |
+
CHECK_IS_HALF_OR_BFLOAT(x)
|
| 14 |
+
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
std::vector<torch::Tensor> butterfly_cuda(
|
| 18 |
+
torch::Tensor x,
|
| 19 |
+
torch::Tensor d_f_T,
|
| 20 |
+
torch::Tensor twiddle_factors_real,
|
| 21 |
+
torch::Tensor twiddle_factors_imag,
|
| 22 |
+
std::optional<at::Tensor> x_gate = std::nullopt
|
| 23 |
+
);
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
std::vector<torch::Tensor> butterfly_bf16_cuda(
|
| 27 |
+
torch::Tensor x,
|
| 28 |
+
torch::Tensor d_f_T_real,
|
| 29 |
+
torch::Tensor d_f_T_imag,
|
| 30 |
+
torch::Tensor twiddle_factors_real,
|
| 31 |
+
torch::Tensor twiddle_factors_imag,
|
| 32 |
+
std::optional<at::Tensor> out_gate = std::nullopt
|
| 33 |
+
);
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
std::vector<torch::Tensor> butterfly_padded_cuda(
|
| 37 |
+
torch::Tensor x,
|
| 38 |
+
torch::Tensor d_f_T,
|
| 39 |
+
torch::Tensor twiddle_factors_real,
|
| 40 |
+
torch::Tensor twiddle_factors_imag,
|
| 41 |
+
int M,
|
| 42 |
+
std::optional<at::Tensor> x_gate = std::nullopt
|
| 43 |
+
);
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
std::vector<torch::Tensor> butterfly_padded_bf16_cuda(
|
| 47 |
+
torch::Tensor x,
|
| 48 |
+
torch::Tensor d_f_T_real,
|
| 49 |
+
torch::Tensor d_f_T_imag,
|
| 50 |
+
torch::Tensor twiddle_factors_real,
|
| 51 |
+
torch::Tensor twiddle_factors_imag,
|
| 52 |
+
int M,
|
| 53 |
+
std::optional<at::Tensor> x_gate = std::nullopt
|
| 54 |
+
);
|
| 55 |
+
|
| 56 |
+
torch::Tensor butterfly_ifft_cuda(
|
| 57 |
+
torch::Tensor x_real,
|
| 58 |
+
torch::Tensor x_imag,
|
| 59 |
+
torch::Tensor d_f_T,
|
| 60 |
+
torch::Tensor twiddle_factors_real,
|
| 61 |
+
torch::Tensor twiddle_factors_imag,
|
| 62 |
+
std::optional<at::Tensor> out_gate = std::nullopt
|
| 63 |
+
);
|
| 64 |
+
|
| 65 |
+
torch::Tensor butterfly_ifft_bf16_cuda(
|
| 66 |
+
torch::Tensor x_real,
|
| 67 |
+
torch::Tensor x_imag,
|
| 68 |
+
torch::Tensor d_f_real,
|
| 69 |
+
torch::Tensor d_f_imag,
|
| 70 |
+
torch::Tensor twiddle_factors_real,
|
| 71 |
+
torch::Tensor twiddle_factors_imag,
|
| 72 |
+
std::optional<at::Tensor> x_gate = std::nullopt
|
| 73 |
+
);
|
| 74 |
+
|
| 75 |
+
torch::Tensor butterfly_ifft_padded_cuda(
|
| 76 |
+
torch::Tensor x_real,
|
| 77 |
+
torch::Tensor x_imag,
|
| 78 |
+
torch::Tensor d_f,
|
| 79 |
+
torch::Tensor twiddle_factors_real,
|
| 80 |
+
torch::Tensor twiddle_factors_imag,
|
| 81 |
+
int N,
|
| 82 |
+
std::optional<at::Tensor> out_gate = std::nullopt
|
| 83 |
+
);
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
torch::Tensor butterfly_ifft_padded_bf16_cuda(
|
| 87 |
+
torch::Tensor x_real,
|
| 88 |
+
torch::Tensor x_imag,
|
| 89 |
+
torch::Tensor d_f_real,
|
| 90 |
+
torch::Tensor d_f_imag,
|
| 91 |
+
torch::Tensor twiddle_factors_real,
|
| 92 |
+
torch::Tensor twiddle_factors_imag,
|
| 93 |
+
int N,
|
| 94 |
+
std::optional<at::Tensor> out_gate = std::nullopt
|
| 95 |
+
);
|
| 96 |
+
|
| 97 |
+
std::vector<torch::Tensor> butterfly(
|
| 98 |
+
torch::Tensor x,
|
| 99 |
+
torch::Tensor d_f_T,
|
| 100 |
+
torch::Tensor twiddle_factors_real,
|
| 101 |
+
torch::Tensor twiddle_factors_imag
|
| 102 |
+
){
|
| 103 |
+
CHECK_INPUT(x);
|
| 104 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 105 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
return butterfly_cuda(x, d_f_T, twiddle_factors_real, twiddle_factors_imag);
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
std::vector<torch::Tensor> butterfly_gated(
|
| 112 |
+
torch::Tensor x,
|
| 113 |
+
torch::Tensor d_f_T,
|
| 114 |
+
torch::Tensor twiddle_factors_real,
|
| 115 |
+
torch::Tensor twiddle_factors_imag,
|
| 116 |
+
torch::Tensor x_gate
|
| 117 |
+
){
|
| 118 |
+
CHECK_INPUT(x);
|
| 119 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 120 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 121 |
+
|
| 122 |
+
CHECK_INPUT(x_gate);
|
| 123 |
+
|
| 124 |
+
return butterfly_cuda(x, d_f_T, twiddle_factors_real, twiddle_factors_imag, x_gate);
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
std::vector<torch::Tensor> butterfly_bf16(
|
| 128 |
+
torch::Tensor x,
|
| 129 |
+
torch::Tensor d_f_T_real,
|
| 130 |
+
torch::Tensor d_f_T_imag,
|
| 131 |
+
torch::Tensor twiddle_factors_real,
|
| 132 |
+
torch::Tensor twiddle_factors_imag
|
| 133 |
+
){
|
| 134 |
+
CHECK_INPUT(x);
|
| 135 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 136 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 137 |
+
CHECK_INPUT(d_f_T_real);
|
| 138 |
+
CHECK_INPUT(d_f_T_imag);
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
return butterfly_bf16_cuda(x, d_f_T_real, d_f_T_imag, twiddle_factors_real, twiddle_factors_imag);
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
std::vector<torch::Tensor> butterfly_gated_bf16(
|
| 145 |
+
torch::Tensor x,
|
| 146 |
+
torch::Tensor d_f_T_real,
|
| 147 |
+
torch::Tensor d_f_T_imag,
|
| 148 |
+
torch::Tensor twiddle_factors_real,
|
| 149 |
+
torch::Tensor twiddle_factors_imag,
|
| 150 |
+
torch::Tensor x_gate
|
| 151 |
+
){
|
| 152 |
+
CHECK_INPUT(x);
|
| 153 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 154 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 155 |
+
CHECK_INPUT(d_f_T_real);
|
| 156 |
+
CHECK_INPUT(d_f_T_imag);
|
| 157 |
+
CHECK_INPUT(x_gate);
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
return butterfly_bf16_cuda(x, d_f_T_real, d_f_T_imag, twiddle_factors_real, twiddle_factors_imag, x_gate);
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
torch::Tensor butterfly_ifft(
|
| 164 |
+
torch::Tensor x_real,
|
| 165 |
+
torch::Tensor x_imag,
|
| 166 |
+
torch::Tensor d_f_T,
|
| 167 |
+
torch::Tensor twiddle_factors_real,
|
| 168 |
+
torch::Tensor twiddle_factors_imag
|
| 169 |
+
){
|
| 170 |
+
CHECK_INPUT(x_real);
|
| 171 |
+
CHECK_INPUT(x_imag);
|
| 172 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 173 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 174 |
+
|
| 175 |
+
return butterfly_ifft_cuda(x_real, x_imag, d_f_T, twiddle_factors_real, twiddle_factors_imag);
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
torch::Tensor butterfly_ifft_gated(
|
| 180 |
+
torch::Tensor x_real,
|
| 181 |
+
torch::Tensor x_imag,
|
| 182 |
+
torch::Tensor d_f_T,
|
| 183 |
+
torch::Tensor twiddle_factors_real,
|
| 184 |
+
torch::Tensor twiddle_factors_imag,
|
| 185 |
+
torch::Tensor out_gate
|
| 186 |
+
){
|
| 187 |
+
CHECK_INPUT(x_real);
|
| 188 |
+
CHECK_INPUT(x_imag);
|
| 189 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 190 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 191 |
+
CHECK_INPUT(out_gate);
|
| 192 |
+
|
| 193 |
+
return butterfly_ifft_cuda(x_real, x_imag, d_f_T, twiddle_factors_real, twiddle_factors_imag, out_gate);
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
torch::Tensor butterfly_ifft_bf16(
|
| 197 |
+
torch::Tensor x_real,
|
| 198 |
+
torch::Tensor x_imag,
|
| 199 |
+
torch::Tensor d_f_real,
|
| 200 |
+
torch::Tensor d_f_imag,
|
| 201 |
+
torch::Tensor twiddle_factors_real,
|
| 202 |
+
torch::Tensor twiddle_factors_imag
|
| 203 |
+
){
|
| 204 |
+
CHECK_INPUT(x_real);
|
| 205 |
+
CHECK_INPUT(x_imag);
|
| 206 |
+
CHECK_INPUT(d_f_real);
|
| 207 |
+
CHECK_INPUT(d_f_imag);
|
| 208 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 209 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
return butterfly_ifft_bf16_cuda(x_real, x_imag, d_f_real, d_f_imag, twiddle_factors_real, twiddle_factors_imag);
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
torch::Tensor butterfly_ifft_gated_bf16(
|
| 217 |
+
torch::Tensor x_real,
|
| 218 |
+
torch::Tensor x_imag,
|
| 219 |
+
torch::Tensor d_f_real,
|
| 220 |
+
torch::Tensor d_f_imag,
|
| 221 |
+
torch::Tensor twiddle_factors_real,
|
| 222 |
+
torch::Tensor twiddle_factors_imag,
|
| 223 |
+
torch::Tensor out_gate
|
| 224 |
+
){
|
| 225 |
+
CHECK_INPUT(x_real);
|
| 226 |
+
CHECK_INPUT(x_imag);
|
| 227 |
+
CHECK_INPUT(d_f_real);
|
| 228 |
+
CHECK_INPUT(d_f_imag);
|
| 229 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 230 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 231 |
+
CHECK_INPUT(out_gate);
|
| 232 |
+
|
| 233 |
+
return butterfly_ifft_bf16_cuda(x_real, x_imag, d_f_real, d_f_imag, twiddle_factors_real, twiddle_factors_imag, out_gate);
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
std::vector<torch::Tensor> butterfly_padded(
|
| 237 |
+
torch::Tensor x,
|
| 238 |
+
torch::Tensor d_f_T,
|
| 239 |
+
torch::Tensor twiddle_factors_real,
|
| 240 |
+
torch::Tensor twiddle_factors_imag,
|
| 241 |
+
int M
|
| 242 |
+
){
|
| 243 |
+
CHECK_INPUT(x);
|
| 244 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 245 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
return butterfly_padded_cuda(x, d_f_T, twiddle_factors_real, twiddle_factors_imag, M);
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
std::vector<torch::Tensor> butterfly_padded_bf16(
|
| 252 |
+
torch::Tensor x,
|
| 253 |
+
torch::Tensor d_f_T_real,
|
| 254 |
+
torch::Tensor d_f_T_imag,
|
| 255 |
+
torch::Tensor twiddle_factors_real,
|
| 256 |
+
torch::Tensor twiddle_factors_imag,
|
| 257 |
+
int M
|
| 258 |
+
){
|
| 259 |
+
CHECK_INPUT(x);
|
| 260 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 261 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
return butterfly_padded_bf16_cuda(x, d_f_T_real, d_f_T_imag, twiddle_factors_real, twiddle_factors_imag, M);
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
std::vector<torch::Tensor> butterfly_padded_gated(
|
| 269 |
+
torch::Tensor x,
|
| 270 |
+
torch::Tensor d_f_T,
|
| 271 |
+
torch::Tensor twiddle_factors_real,
|
| 272 |
+
torch::Tensor twiddle_factors_imag,
|
| 273 |
+
int M,
|
| 274 |
+
torch::Tensor x_gate
|
| 275 |
+
){
|
| 276 |
+
CHECK_INPUT(x);
|
| 277 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 278 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
return butterfly_padded_cuda(x, d_f_T, twiddle_factors_real, twiddle_factors_imag, M, x_gate);
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
std::vector<torch::Tensor> butterfly_padded_gated_bf16(
|
| 285 |
+
torch::Tensor x,
|
| 286 |
+
torch::Tensor d_f_T_real,
|
| 287 |
+
torch::Tensor d_f_T_imag,
|
| 288 |
+
torch::Tensor twiddle_factors_real,
|
| 289 |
+
torch::Tensor twiddle_factors_imag,
|
| 290 |
+
int M,
|
| 291 |
+
torch::Tensor x_gate
|
| 292 |
+
){
|
| 293 |
+
CHECK_INPUT(x);
|
| 294 |
+
CHECK_INPUT(d_f_T_real);
|
| 295 |
+
CHECK_INPUT(d_f_T_imag);
|
| 296 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 297 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
return butterfly_padded_bf16_cuda(x, d_f_T_real, d_f_T_imag, twiddle_factors_real, twiddle_factors_imag, M, x_gate);
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
torch::Tensor butterfly_ifft_padded(
|
| 304 |
+
torch::Tensor x_real,
|
| 305 |
+
torch::Tensor x_imag,
|
| 306 |
+
torch::Tensor d_f,
|
| 307 |
+
torch::Tensor twiddle_factors_real,
|
| 308 |
+
torch::Tensor twiddle_factors_imag,
|
| 309 |
+
int N
|
| 310 |
+
){
|
| 311 |
+
CHECK_INPUT(x_real);
|
| 312 |
+
CHECK_INPUT(x_imag);
|
| 313 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 314 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 315 |
+
|
| 316 |
+
return butterfly_ifft_padded_cuda(x_real, x_imag, d_f, twiddle_factors_real, twiddle_factors_imag, N);
|
| 317 |
+
}
|
| 318 |
+
|
| 319 |
+
torch::Tensor butterfly_ifft_padded_gated(
|
| 320 |
+
torch::Tensor x_real,
|
| 321 |
+
torch::Tensor x_imag,
|
| 322 |
+
torch::Tensor d_f,
|
| 323 |
+
torch::Tensor twiddle_factors_real,
|
| 324 |
+
torch::Tensor twiddle_factors_imag,
|
| 325 |
+
int N,
|
| 326 |
+
torch::Tensor out_gate
|
| 327 |
+
){
|
| 328 |
+
CHECK_INPUT(x_real);
|
| 329 |
+
CHECK_INPUT(x_imag);
|
| 330 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 331 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 332 |
+
|
| 333 |
+
return butterfly_ifft_padded_cuda(x_real, x_imag, d_f, twiddle_factors_real, twiddle_factors_imag, N, out_gate);
|
| 334 |
+
}
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
torch::Tensor butterfly_ifft_padded_bf16(
|
| 338 |
+
torch::Tensor x_real,
|
| 339 |
+
torch::Tensor x_imag,
|
| 340 |
+
torch::Tensor d_f_real,
|
| 341 |
+
torch::Tensor d_f_imag,
|
| 342 |
+
torch::Tensor twiddle_factors_real,
|
| 343 |
+
torch::Tensor twiddle_factors_imag,
|
| 344 |
+
int N
|
| 345 |
+
){
|
| 346 |
+
CHECK_INPUT(x_real);
|
| 347 |
+
CHECK_INPUT(x_imag);
|
| 348 |
+
CHECK_INPUT(d_f_real);
|
| 349 |
+
CHECK_INPUT(d_f_imag);
|
| 350 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 351 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 352 |
+
|
| 353 |
+
return butterfly_ifft_padded_bf16_cuda(x_real, x_imag, d_f_real, d_f_imag, twiddle_factors_real, twiddle_factors_imag, N);
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
torch::Tensor butterfly_ifft_padded_gated_bf16(
|
| 357 |
+
torch::Tensor x_real,
|
| 358 |
+
torch::Tensor x_imag,
|
| 359 |
+
torch::Tensor d_f_real,
|
| 360 |
+
torch::Tensor d_f_imag,
|
| 361 |
+
torch::Tensor twiddle_factors_real,
|
| 362 |
+
torch::Tensor twiddle_factors_imag,
|
| 363 |
+
int N,
|
| 364 |
+
torch::Tensor out_gate
|
| 365 |
+
){
|
| 366 |
+
CHECK_INPUT(x_real);
|
| 367 |
+
CHECK_INPUT(x_imag);
|
| 368 |
+
CHECK_INPUT(d_f_real);
|
| 369 |
+
CHECK_INPUT(d_f_imag);
|
| 370 |
+
CHECK_INPUT(twiddle_factors_real);
|
| 371 |
+
CHECK_INPUT(twiddle_factors_imag);
|
| 372 |
+
|
| 373 |
+
return butterfly_ifft_padded_bf16_cuda(x_real, x_imag, d_f_real, d_f_imag, twiddle_factors_real, twiddle_factors_imag, N, out_gate);
|
| 374 |
+
}
|
overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_cuda.cu
ADDED
|
@@ -0,0 +1,699 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) 2023 Dan Fu, Hermann Kumbong
|
| 2 |
+
|
| 3 |
+
#include <torch/extension.h>
|
| 4 |
+
|
| 5 |
+
#include <vector>
|
| 6 |
+
#include <stdio.h>
|
| 7 |
+
#include <mma.h>
|
| 8 |
+
#include <cuda_fp16.h>
|
| 9 |
+
#include <cuda_bf16.h>
|
| 10 |
+
#include "shared.h"
|
| 11 |
+
|
| 12 |
+
using namespace nvcuda;
|
| 13 |
+
|
| 14 |
+
__global__ void butterfly_cuda_kernel_64(
|
| 15 |
+
const __half2 *__restrict__ x,
|
| 16 |
+
const __half2 *__restrict__ x_gate,
|
| 17 |
+
const complex_half_t *__restrict__ d_f,
|
| 18 |
+
const __half2 *__restrict__ twiddle_factors_real,
|
| 19 |
+
const __half2 *__restrict__ twiddle_factors_imag,
|
| 20 |
+
__half2 *__restrict__ out_real,
|
| 21 |
+
__half2 *__restrict__ out_imag,
|
| 22 |
+
uint B,
|
| 23 |
+
uint H,
|
| 24 |
+
int N)
|
| 25 |
+
{
|
| 26 |
+
const int offset = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * 16 * 64 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 27 |
+
const int tw_offset = blockIdx.x * 32 + threadIdx.x;
|
| 28 |
+
int idx;
|
| 29 |
+
int shared_offset;
|
| 30 |
+
const int B_Y = blockDim.y;
|
| 31 |
+
const int n = N / B_Y;
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
extern __shared__ half x_shared[];
|
| 35 |
+
half *d_f_real = &x_shared[N * N];
|
| 36 |
+
half *d_f_imag = &d_f_real[N * N];
|
| 37 |
+
half *twiddles_real_shared = &d_f_imag[N * N];
|
| 38 |
+
half *twiddles_imag_shared = &twiddles_real_shared[N * N];
|
| 39 |
+
half *out_real_shared = &twiddles_imag_shared[N * N];
|
| 40 |
+
half *out_imag_shared = &out_real_shared[N * N];
|
| 41 |
+
|
| 42 |
+
// #pragma unroll
|
| 43 |
+
for (int i = 0; i < n; i++)
|
| 44 |
+
{
|
| 45 |
+
idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 46 |
+
shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
|
| 47 |
+
reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
|
| 48 |
+
reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
|
| 49 |
+
|
| 50 |
+
// #pragma unroll
|
| 51 |
+
shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x;
|
| 52 |
+
d_f_real[shared_offset] = d_f[shared_offset].real();
|
| 53 |
+
d_f_imag[shared_offset] = d_f[shared_offset].imag();
|
| 54 |
+
|
| 55 |
+
d_f_real[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].real();
|
| 56 |
+
d_f_imag[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].imag();
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
__half2 tmp_real, tmp_imag;
|
| 60 |
+
|
| 61 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[4];
|
| 62 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real[4];
|
| 63 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag[4];
|
| 64 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[4];
|
| 65 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag[4][4];
|
| 66 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[4];
|
| 67 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag[4];
|
| 68 |
+
|
| 69 |
+
__syncthreads();
|
| 70 |
+
|
| 71 |
+
for (int i = 0; i < 4; i++)
|
| 72 |
+
{
|
| 73 |
+
wmma::load_matrix_sync(a_frag_real[i], d_f_real + i * N * 16 + threadIdx.y * 16, N);
|
| 74 |
+
wmma::load_matrix_sync(a_frag_imag[i], d_f_imag + i * N * 16 + threadIdx.y * 16, N);
|
| 75 |
+
wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + threadIdx.y * N * 16 + i * 16, N);
|
| 76 |
+
wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + threadIdx.y * N * 16 + i * 16, N);
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
for (int t = 0; t < 16; t++)
|
| 80 |
+
{
|
| 81 |
+
|
| 82 |
+
for (int i = 0; i < n; i++)
|
| 83 |
+
{
|
| 84 |
+
idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
|
| 85 |
+
shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
|
| 86 |
+
if(x_gate != nullptr){
|
| 87 |
+
reinterpret_cast<__half2 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
|
| 88 |
+
}else{
|
| 89 |
+
reinterpret_cast<__half2 *>(x_shared)[shared_offset] = x[idx + offset];
|
| 90 |
+
}
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
__syncthreads();
|
| 94 |
+
|
| 95 |
+
for (int i = 0; i < 4; i++)
|
| 96 |
+
{
|
| 97 |
+
for (int j = 0; j < 4; j++)
|
| 98 |
+
{
|
| 99 |
+
wmma::load_matrix_sync(b_frag[i][j], x_shared + i * N * 16 + j * 16, N);
|
| 100 |
+
}
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
#pragma unroll
|
| 104 |
+
for (int j = 0; j < 4; j++)
|
| 105 |
+
{
|
| 106 |
+
wmma::fill_fragment(acc_frag_real[j], __float2half(0.0f));
|
| 107 |
+
|
| 108 |
+
for (int k = 0; k < 4; k++)
|
| 109 |
+
{
|
| 110 |
+
wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
|
| 111 |
+
}
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
#pragma unroll
|
| 115 |
+
|
| 116 |
+
for (int j = 0; j < 4; j++)
|
| 117 |
+
{
|
| 118 |
+
wmma::fill_fragment(acc_frag_imag[j], __float2half(0.0f));
|
| 119 |
+
|
| 120 |
+
for (int k = 0; k < 4; k++)
|
| 121 |
+
{
|
| 122 |
+
wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
|
| 123 |
+
}
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
#pragma unroll
|
| 127 |
+
for (int j = 0; j < 4; j++)
|
| 128 |
+
{
|
| 129 |
+
for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
|
| 130 |
+
{
|
| 131 |
+
tmp_real = reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k];
|
| 132 |
+
tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k];
|
| 133 |
+
reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]));
|
| 134 |
+
reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]));
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
wmma::store_matrix_sync(out_real_shared + threadIdx.y * N * 16 + j * 16, acc_frag_real[j], N, wmma::mem_row_major);
|
| 138 |
+
wmma::store_matrix_sync(out_imag_shared + threadIdx.y * N * 16 + j * 16, acc_frag_imag[j], N, wmma::mem_row_major);
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
__syncthreads();
|
| 142 |
+
|
| 143 |
+
#pragma unroll
|
| 144 |
+
for (int i = 0; i < n; i++)
|
| 145 |
+
{
|
| 146 |
+
idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
|
| 147 |
+
out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
|
| 148 |
+
out_imag[idx] = reinterpret_cast<__half2 *>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
__syncthreads();
|
| 152 |
+
}
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
__global__ void butterfly_cuda_kernel_32(
|
| 156 |
+
const __half2 *__restrict__ x,
|
| 157 |
+
const __half2 *__restrict__ x_gate,
|
| 158 |
+
const complex_half_t *__restrict__ d_f,
|
| 159 |
+
const __half2 *__restrict__ twiddle_factors_real,
|
| 160 |
+
const __half2 *__restrict__ twiddle_factors_imag,
|
| 161 |
+
__half2 *__restrict__ out_real,
|
| 162 |
+
__half2 *__restrict__ out_imag,
|
| 163 |
+
uint B,
|
| 164 |
+
uint H,
|
| 165 |
+
int N)
|
| 166 |
+
{
|
| 167 |
+
const int offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 168 |
+
const int tw_offset = blockIdx.x * 32 + threadIdx.x;
|
| 169 |
+
int idx;
|
| 170 |
+
|
| 171 |
+
int shared_offset;
|
| 172 |
+
const int B_Y = blockDim.y;
|
| 173 |
+
const int n = N / B_Y;
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
__shared__ half x_shared[32 * 64];
|
| 177 |
+
__shared__ half d_f_real[32 * 32];
|
| 178 |
+
__shared__ half d_f_imag[32 * 32];
|
| 179 |
+
__shared__ half twiddles_real_shared[32 * 64];
|
| 180 |
+
__shared__ half twiddles_imag_shared[32 * 64];
|
| 181 |
+
__shared__ half out_real_shared[32 * 64];
|
| 182 |
+
__shared__ half out_imag_shared[32 * 64];
|
| 183 |
+
|
| 184 |
+
// #pragma unroll
|
| 185 |
+
for (int i = 0; i < n; i++)
|
| 186 |
+
{
|
| 187 |
+
idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 188 |
+
shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
|
| 189 |
+
if(x_gate == nullptr){
|
| 190 |
+
reinterpret_cast<__half2 *>(x_shared)[shared_offset] = x[idx + offset];
|
| 191 |
+
}else{
|
| 192 |
+
reinterpret_cast<__half2 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
|
| 193 |
+
}
|
| 194 |
+
reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
|
| 195 |
+
reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
|
| 196 |
+
|
| 197 |
+
// #pragma unroll
|
| 198 |
+
d_f_real[shared_offset] = d_f[shared_offset].real();
|
| 199 |
+
d_f_imag[shared_offset] = d_f[shared_offset].imag();
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
__syncthreads();
|
| 203 |
+
|
| 204 |
+
if (threadIdx.y < N / 16)
|
| 205 |
+
{
|
| 206 |
+
__half2 tmp_real, tmp_imag;
|
| 207 |
+
|
| 208 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[2][2];
|
| 209 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real[2][2];
|
| 210 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag[2][2];
|
| 211 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[2][2];
|
| 212 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag[2][2];
|
| 213 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[2][2];
|
| 214 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag[2][2];
|
| 215 |
+
|
| 216 |
+
int t = threadIdx.y * 32;
|
| 217 |
+
|
| 218 |
+
for (int i = 0; i < 2; i++)
|
| 219 |
+
{
|
| 220 |
+
for (int j = 0; j < 2; j++)
|
| 221 |
+
{
|
| 222 |
+
wmma::load_matrix_sync(a_frag_real[i][j], d_f_real + j * N * 16 + i * 16, N);
|
| 223 |
+
wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag + j * N * 16 + i * 16, N);
|
| 224 |
+
wmma::load_matrix_sync(b_frag[i][j], x_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 225 |
+
wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 226 |
+
wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 227 |
+
}
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
#pragma unroll
|
| 231 |
+
for (int i = 0; i < 2; i++)
|
| 232 |
+
{
|
| 233 |
+
for (int j = 0; j < 2; j++)
|
| 234 |
+
{
|
| 235 |
+
wmma::fill_fragment(acc_frag_real[i][j], __float2half(0.0f));
|
| 236 |
+
|
| 237 |
+
for (int k = 0; k < 2; k++)
|
| 238 |
+
{
|
| 239 |
+
wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag[k][j], acc_frag_real[i][j]);
|
| 240 |
+
}
|
| 241 |
+
}
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
#pragma unroll
|
| 245 |
+
for (int i = 0; i < 2; i++)
|
| 246 |
+
{
|
| 247 |
+
for (int j = 0; j < 2; j++)
|
| 248 |
+
{
|
| 249 |
+
wmma::fill_fragment(acc_frag_imag[i][j], __float2half(0.0f));
|
| 250 |
+
|
| 251 |
+
for (int k = 0; k < 2; k++)
|
| 252 |
+
{
|
| 253 |
+
wmma::mma_sync(acc_frag_imag[i][j], a_frag_imag[i][k], b_frag[k][j], acc_frag_imag[i][j]);
|
| 254 |
+
}
|
| 255 |
+
}
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
#pragma unroll
|
| 259 |
+
for (int i = 0; i < 2; i++)
|
| 260 |
+
{
|
| 261 |
+
for (int j = 0; j < 2; j++)
|
| 262 |
+
{
|
| 263 |
+
for (int k = 0; k < acc_frag_real[i][j].num_elements / 2; k++)
|
| 264 |
+
{
|
| 265 |
+
tmp_real = reinterpret_cast<__half2 *>(acc_frag_real[i][j].x)[k];
|
| 266 |
+
tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag[i][j].x)[k];
|
| 267 |
+
reinterpret_cast<__half2 *>(acc_frag_real[i][j].x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real[i][j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag[i][j].x)[k]));
|
| 268 |
+
reinterpret_cast<__half2 *>(acc_frag_imag[i][j].x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag[i][j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real[i][j].x)[k]));
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
|
| 272 |
+
wmma::store_matrix_sync(out_imag_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_imag[i][j], 2 * N, wmma::mem_row_major);
|
| 273 |
+
}
|
| 274 |
+
}
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
__syncthreads();
|
| 278 |
+
|
| 279 |
+
#pragma unroll
|
| 280 |
+
for (int i = 0; i < n; i++)
|
| 281 |
+
{
|
| 282 |
+
idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 283 |
+
out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
|
| 284 |
+
out_imag[idx] = reinterpret_cast<__half2 *>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
|
| 285 |
+
}
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
__global__ void butterfly_cuda_kernel_128(
|
| 289 |
+
const __half2 *__restrict__ x,
|
| 290 |
+
const __half2 *__restrict__ x_gate,
|
| 291 |
+
const complex_half_t *__restrict__ d_f,
|
| 292 |
+
const __half2 *__restrict__ twiddle_factors_real,
|
| 293 |
+
const __half2 *__restrict__ twiddle_factors_imag,
|
| 294 |
+
__half2 *__restrict__ out_real,
|
| 295 |
+
__half2 *__restrict__ out_imag,
|
| 296 |
+
uint B,
|
| 297 |
+
uint H,
|
| 298 |
+
int N)
|
| 299 |
+
{
|
| 300 |
+
const int offset = blockIdx.y * H * 128 * 32 * gridDim.x * 2 + blockIdx.z * 16 * 128 * 32 * gridDim.x * 2 + blockIdx.x * 64 + threadIdx.x;
|
| 301 |
+
const int tw_offset = blockIdx.x * 64 + threadIdx.x;
|
| 302 |
+
int idx;
|
| 303 |
+
|
| 304 |
+
int shared_offset;
|
| 305 |
+
const int B_Y = blockDim.y;
|
| 306 |
+
const int n = N / B_Y;
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
extern __shared__ half shared_real[];
|
| 310 |
+
half *shared_imag = &shared_real[128 * 128];
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[8];
|
| 314 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real[8];
|
| 315 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag[8];
|
| 316 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[8];
|
| 317 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag[8][8];
|
| 318 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[8];
|
| 319 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag[8];
|
| 320 |
+
|
| 321 |
+
for (int i = 0; i < n; i++)
|
| 322 |
+
{
|
| 323 |
+
for(int j=0; j< 4; j++){
|
| 324 |
+
shared_offset = (threadIdx.y + i * B_Y) * 128 + threadIdx.x + j * blockDim.x;
|
| 325 |
+
shared_real[shared_offset] = d_f[shared_offset].real();
|
| 326 |
+
shared_imag[shared_offset] = d_f[shared_offset].imag();
|
| 327 |
+
}
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
__syncthreads();
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
for (int i = 0; i < 8; i++){
|
| 334 |
+
wmma::load_matrix_sync(a_frag_real[i], shared_real + i * 128 * 16 + threadIdx.y * 16, 128);
|
| 335 |
+
wmma::load_matrix_sync(a_frag_imag[i], shared_imag + i * 128 * 16 + threadIdx.y * 16, 128);
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
__syncthreads();
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
for (int i = 0; i < n; i++)
|
| 344 |
+
{
|
| 345 |
+
for(int j=0; j< 2; j++){
|
| 346 |
+
idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x;
|
| 347 |
+
shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
|
| 348 |
+
reinterpret_cast<__half2*>(shared_real)[shared_offset] = twiddle_factors_real[tw_offset + idx];
|
| 349 |
+
reinterpret_cast<__half2*>(shared_imag)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
|
| 350 |
+
}
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
__syncthreads();
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
for (int i = 0; i < 8; i++){
|
| 357 |
+
wmma::load_matrix_sync(tw_frag_real[i], shared_real + threadIdx.y * 128 * 16 + i * 16, 128);
|
| 358 |
+
wmma::load_matrix_sync(tw_frag_imag[i], shared_imag + threadIdx.y * 128 * 16 + i * 16, 128);
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
__syncthreads();
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
for(int t=0; t< 16; t++){
|
| 365 |
+
for (int i = 0; i < n; i++)
|
| 366 |
+
{
|
| 367 |
+
for(int j=0; j< 2; j++){
|
| 368 |
+
idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
|
| 369 |
+
shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
|
| 370 |
+
if(x_gate != nullptr){
|
| 371 |
+
reinterpret_cast<__half2*>(shared_real)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
|
| 372 |
+
}else{
|
| 373 |
+
reinterpret_cast<__half2*>(shared_real)[shared_offset] = x[offset + idx];
|
| 374 |
+
}
|
| 375 |
+
|
| 376 |
+
}
|
| 377 |
+
}
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
__syncthreads();
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
for (int i = 0; i < 8; i++)
|
| 384 |
+
{
|
| 385 |
+
for (int j = 0; j < 8; j++)
|
| 386 |
+
{
|
| 387 |
+
wmma::load_matrix_sync(b_frag[i][j], shared_real + i * 128 * 16 + j * 16, 128);
|
| 388 |
+
}
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
__syncthreads();
|
| 392 |
+
|
| 393 |
+
#pragma unroll
|
| 394 |
+
for (int j = 0; j < 8; j++)
|
| 395 |
+
{
|
| 396 |
+
wmma::fill_fragment(acc_frag_real[j], __float2half(0.0f));
|
| 397 |
+
|
| 398 |
+
for (int k = 0; k < 8; k++)
|
| 399 |
+
{
|
| 400 |
+
wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
|
| 401 |
+
}
|
| 402 |
+
}
|
| 403 |
+
|
| 404 |
+
#pragma unroll
|
| 405 |
+
|
| 406 |
+
for (int j = 0; j < 8; j++)
|
| 407 |
+
{
|
| 408 |
+
wmma::fill_fragment(acc_frag_imag[j], __float2half(0.0f));
|
| 409 |
+
|
| 410 |
+
for (int k = 0; k < 8; k++)
|
| 411 |
+
{
|
| 412 |
+
wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
|
| 413 |
+
}
|
| 414 |
+
}
|
| 415 |
+
|
| 416 |
+
__half2 tmp_real, tmp_imag;
|
| 417 |
+
#pragma unroll
|
| 418 |
+
for (int j = 0; j < 8; j++)
|
| 419 |
+
{
|
| 420 |
+
for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
|
| 421 |
+
{
|
| 422 |
+
tmp_real = reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k];
|
| 423 |
+
tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k];
|
| 424 |
+
reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]));
|
| 425 |
+
reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]));
|
| 426 |
+
}
|
| 427 |
+
|
| 428 |
+
wmma::store_matrix_sync(shared_real + threadIdx.y * 128 * 16 + j * 16, acc_frag_real[j], 128, wmma::mem_row_major);
|
| 429 |
+
wmma::store_matrix_sync(shared_imag + threadIdx.y * 128 * 16 + j * 16, acc_frag_imag[j], 128, wmma::mem_row_major);
|
| 430 |
+
}
|
| 431 |
+
|
| 432 |
+
__syncthreads();
|
| 433 |
+
|
| 434 |
+
#pragma unroll
|
| 435 |
+
for (int i = 0; i < n; i++)
|
| 436 |
+
{
|
| 437 |
+
for(int j=0; j< 2; j++){
|
| 438 |
+
idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
|
| 439 |
+
shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
|
| 440 |
+
out_real[offset + idx] = reinterpret_cast<__half2*>(shared_real)[shared_offset];
|
| 441 |
+
out_imag[offset + idx] = reinterpret_cast<__half2*>(shared_imag)[shared_offset];
|
| 442 |
+
}
|
| 443 |
+
}
|
| 444 |
+
|
| 445 |
+
__syncthreads();
|
| 446 |
+
}
|
| 447 |
+
}
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
__global__ void butterfly_cuda_kernel_16(
|
| 451 |
+
const __half2 *__restrict__ x,
|
| 452 |
+
const __half2 *__restrict__ x_gate,
|
| 453 |
+
const complex_half_t *__restrict__ d_f,
|
| 454 |
+
const __half2 *__restrict__ twiddle_factors_real,
|
| 455 |
+
const __half2 *__restrict__ twiddle_factors_imag,
|
| 456 |
+
__half2 *__restrict__ out_real,
|
| 457 |
+
__half2 *__restrict__ out_imag,
|
| 458 |
+
uint B,
|
| 459 |
+
uint H,
|
| 460 |
+
int N)
|
| 461 |
+
{
|
| 462 |
+
const int offset = blockIdx.y * H * 16 * 32 * gridDim.x + blockIdx.z * 16 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 463 |
+
const int tw_offset = blockIdx.x * 32 + threadIdx.x;
|
| 464 |
+
int idx;
|
| 465 |
+
|
| 466 |
+
int shared_offset;
|
| 467 |
+
const int B_Y = blockDim.y;
|
| 468 |
+
const int n = N / B_Y;
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
__shared__ half x_shared[16 * 64];
|
| 472 |
+
__shared__ half d_f_real[16 * 16];
|
| 473 |
+
__shared__ half d_f_imag[16 * 16];
|
| 474 |
+
__shared__ half twiddles_real_shared[16 * 64];
|
| 475 |
+
__shared__ half twiddles_imag_shared[16 * 64];
|
| 476 |
+
__shared__ half out_real_shared[16 * 64];
|
| 477 |
+
__shared__ half out_imag_shared[16 * 64];
|
| 478 |
+
|
| 479 |
+
// #pragma unroll
|
| 480 |
+
for (int i = 0; i < n; i++)
|
| 481 |
+
{
|
| 482 |
+
idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 483 |
+
shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
|
| 484 |
+
|
| 485 |
+
if(x_gate != NULL)
|
| 486 |
+
reinterpret_cast<__half2 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
|
| 487 |
+
else
|
| 488 |
+
reinterpret_cast<__half2 *>(x_shared)[shared_offset] = x[idx + offset];
|
| 489 |
+
reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
|
| 490 |
+
reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
|
| 491 |
+
|
| 492 |
+
// #pragma unroll
|
| 493 |
+
|
| 494 |
+
if(threadIdx.x < 16 ){
|
| 495 |
+
shared_offset = (threadIdx.y + i * B_Y) * 16 + threadIdx.x;
|
| 496 |
+
d_f_real[shared_offset] = d_f[shared_offset].real();
|
| 497 |
+
d_f_imag[shared_offset] = d_f[shared_offset].imag();
|
| 498 |
+
}
|
| 499 |
+
}
|
| 500 |
+
|
| 501 |
+
__syncthreads();
|
| 502 |
+
|
| 503 |
+
if (threadIdx.y < 4)
|
| 504 |
+
{
|
| 505 |
+
__half2 tmp_real, tmp_imag;
|
| 506 |
+
|
| 507 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real;
|
| 508 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real;
|
| 509 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag;
|
| 510 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag;
|
| 511 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
|
| 512 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real;
|
| 513 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag;
|
| 514 |
+
|
| 515 |
+
wmma::load_matrix_sync(a_frag_real, d_f_real, N);
|
| 516 |
+
wmma::load_matrix_sync(a_frag_imag, d_f_imag, N);
|
| 517 |
+
wmma::load_matrix_sync(b_frag, x_shared + threadIdx.y * 16, 64);
|
| 518 |
+
wmma::load_matrix_sync(tw_frag_real, twiddles_real_shared + threadIdx.y * 16, 64);
|
| 519 |
+
wmma::load_matrix_sync(tw_frag_imag, twiddles_imag_shared + threadIdx.y * 16, 64);
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
wmma::fill_fragment(acc_frag_real, __float2half(0.0f));
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
wmma::mma_sync(acc_frag_real, a_frag_real, b_frag, acc_frag_real);
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
wmma::fill_fragment(acc_frag_imag, __float2half(0.0f));
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
wmma::mma_sync(acc_frag_imag, a_frag_imag, b_frag, acc_frag_imag);
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
for (int k = 0; k < acc_frag_real.num_elements / 2; k++)
|
| 536 |
+
{
|
| 537 |
+
tmp_real = reinterpret_cast<__half2 *>(acc_frag_real.x)[k];
|
| 538 |
+
tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag.x)[k];
|
| 539 |
+
reinterpret_cast<__half2 *>(acc_frag_real.x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real.x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag.x)[k]));
|
| 540 |
+
reinterpret_cast<__half2 *>(acc_frag_imag.x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag.x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real.x)[k]));
|
| 541 |
+
}
|
| 542 |
+
|
| 543 |
+
wmma::store_matrix_sync(out_real_shared + threadIdx.y * 16, acc_frag_real, 64, wmma::mem_row_major);
|
| 544 |
+
wmma::store_matrix_sync(out_imag_shared + threadIdx.y * 16, acc_frag_imag, 64, wmma::mem_row_major);
|
| 545 |
+
}
|
| 546 |
+
|
| 547 |
+
__syncthreads();
|
| 548 |
+
|
| 549 |
+
#pragma unroll
|
| 550 |
+
for (int i = 0; i < n; i++)
|
| 551 |
+
{
|
| 552 |
+
idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 553 |
+
out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
|
| 554 |
+
out_imag[idx] = reinterpret_cast<__half2 *>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
|
| 555 |
+
}
|
| 556 |
+
}
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
std::vector<torch::Tensor> butterfly_cuda(
|
| 560 |
+
torch::Tensor x,
|
| 561 |
+
torch::Tensor d_f,
|
| 562 |
+
torch::Tensor twiddle_factors_real,
|
| 563 |
+
torch::Tensor twiddle_factors_imag,
|
| 564 |
+
std::optional<at::Tensor> x_gate = std::nullopt)
|
| 565 |
+
{
|
| 566 |
+
|
| 567 |
+
uint B = x.size(0);
|
| 568 |
+
uint H = x.size(1);
|
| 569 |
+
// uint m = x.size(1);
|
| 570 |
+
|
| 571 |
+
// const int TILE_SIZE = 16;
|
| 572 |
+
uint N = x.size(2);
|
| 573 |
+
uint M = x.size(3);
|
| 574 |
+
dim3 gridDim;
|
| 575 |
+
dim3 blockDim;
|
| 576 |
+
|
| 577 |
+
gridDim.y = B;
|
| 578 |
+
gridDim.z = H;
|
| 579 |
+
|
| 580 |
+
torch::Tensor out_real = torch::empty({B, H, N, M}, x.options());
|
| 581 |
+
torch::Tensor out_imag = torch::empty({B, H, N, M}, x.options());
|
| 582 |
+
|
| 583 |
+
//set blockDims
|
| 584 |
+
switch(N){
|
| 585 |
+
case 128:
|
| 586 |
+
blockDim.x = 32;
|
| 587 |
+
blockDim.y = 8;
|
| 588 |
+
break;
|
| 589 |
+
default:
|
| 590 |
+
blockDim.x = 32;
|
| 591 |
+
blockDim.y = 4;
|
| 592 |
+
break;
|
| 593 |
+
}
|
| 594 |
+
|
| 595 |
+
//set gridDim.x
|
| 596 |
+
switch(N){
|
| 597 |
+
case 128:
|
| 598 |
+
switch (M){
|
| 599 |
+
case 16384:
|
| 600 |
+
gridDim.x = 128;
|
| 601 |
+
break;
|
| 602 |
+
case 8192:
|
| 603 |
+
gridDim.x = 64;
|
| 604 |
+
break;
|
| 605 |
+
case 4096:
|
| 606 |
+
gridDim.x = 32;
|
| 607 |
+
break;
|
| 608 |
+
default:
|
| 609 |
+
gridDim.x = 256;
|
| 610 |
+
break;
|
| 611 |
+
}
|
| 612 |
+
break;
|
| 613 |
+
default:
|
| 614 |
+
switch (M){
|
| 615 |
+
case 16384:
|
| 616 |
+
gridDim.x = 256;
|
| 617 |
+
break;
|
| 618 |
+
case 8192:
|
| 619 |
+
gridDim.x = 128;
|
| 620 |
+
break;
|
| 621 |
+
case 4096:
|
| 622 |
+
gridDim.x = 64;
|
| 623 |
+
break;
|
| 624 |
+
default:
|
| 625 |
+
gridDim.x = 512;
|
| 626 |
+
break;
|
| 627 |
+
}
|
| 628 |
+
break;
|
| 629 |
+
}
|
| 630 |
+
|
| 631 |
+
switch (N)
|
| 632 |
+
{
|
| 633 |
+
case 16:
|
| 634 |
+
butterfly_cuda_kernel_16<<<gridDim, blockDim>>>(
|
| 635 |
+
static_cast<__half2 *>(x.data_ptr()),
|
| 636 |
+
x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
|
| 637 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 638 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 639 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 640 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 641 |
+
static_cast<__half2 *>(out_imag.data_ptr()),
|
| 642 |
+
B,
|
| 643 |
+
H,
|
| 644 |
+
N);
|
| 645 |
+
break;
|
| 646 |
+
case 32:
|
| 647 |
+
butterfly_cuda_kernel_32<<<gridDim, blockDim>>>(
|
| 648 |
+
static_cast<__half2 *>(x.data_ptr()),
|
| 649 |
+
x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
|
| 650 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 651 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 652 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 653 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 654 |
+
static_cast<__half2 *>(out_imag.data_ptr()),
|
| 655 |
+
B,
|
| 656 |
+
H,
|
| 657 |
+
N);
|
| 658 |
+
break;
|
| 659 |
+
|
| 660 |
+
case 64:
|
| 661 |
+
gridDim.z = H / 16;
|
| 662 |
+
cudaFuncSetAttribute(&butterfly_cuda_kernel_64, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 663 |
+
|
| 664 |
+
butterfly_cuda_kernel_64<<<gridDim, blockDim, 57344>>>(
|
| 665 |
+
static_cast<__half2 *>(x.data_ptr()),
|
| 666 |
+
x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
|
| 667 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 668 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 669 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 670 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 671 |
+
static_cast<__half2 *>(out_imag.data_ptr()),
|
| 672 |
+
B,
|
| 673 |
+
H,
|
| 674 |
+
N);
|
| 675 |
+
break;
|
| 676 |
+
case 128:
|
| 677 |
+
gridDim.z = H / 16;
|
| 678 |
+
cudaFuncSetAttribute(&butterfly_cuda_kernel_128, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 679 |
+
|
| 680 |
+
butterfly_cuda_kernel_128<<<gridDim, blockDim, 65536>>>(
|
| 681 |
+
static_cast<__half2 *>(x.data_ptr()),
|
| 682 |
+
x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
|
| 683 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 684 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 685 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 686 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 687 |
+
static_cast<__half2 *>(out_imag.data_ptr()),
|
| 688 |
+
B,
|
| 689 |
+
H,
|
| 690 |
+
N);
|
| 691 |
+
break;
|
| 692 |
+
|
| 693 |
+
default:
|
| 694 |
+
printf("Not yet implemented \n");
|
| 695 |
+
break;
|
| 696 |
+
}
|
| 697 |
+
|
| 698 |
+
return {out_real, out_imag};
|
| 699 |
+
}
|
overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_cuda_bf16.cu
ADDED
|
@@ -0,0 +1,725 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) 2023 Dan Fu, Hermann Kumbong
|
| 2 |
+
|
| 3 |
+
#include <torch/extension.h>
|
| 4 |
+
|
| 5 |
+
#include <vector>
|
| 6 |
+
#include <stdio.h>
|
| 7 |
+
#include <mma.h>
|
| 8 |
+
#include <cuda_runtime.h>
|
| 9 |
+
#include <cuda_fp16.h>
|
| 10 |
+
#include <cuda_bf16.h>
|
| 11 |
+
#include "shared.h"
|
| 12 |
+
|
| 13 |
+
using namespace nvcuda;
|
| 14 |
+
|
| 15 |
+
__global__ void butterfly_cuda_kernel_64(
|
| 16 |
+
const __nv_bfloat162 *__restrict__ x,
|
| 17 |
+
const __nv_bfloat162 *__restrict__ x_gate,
|
| 18 |
+
const __nv_bfloat162 *__restrict__ d_f_real,
|
| 19 |
+
const __nv_bfloat162 *__restrict__ d_f_imag,
|
| 20 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_real,
|
| 21 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
|
| 22 |
+
__nv_bfloat162 *__restrict__ out_real,
|
| 23 |
+
__nv_bfloat162 *__restrict__ out_imag,
|
| 24 |
+
uint B,
|
| 25 |
+
uint H,
|
| 26 |
+
int N)
|
| 27 |
+
{
|
| 28 |
+
const int offset = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * 16 * 64 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 29 |
+
const int tw_offset = blockIdx.x * 32 + threadIdx.x;
|
| 30 |
+
int idx;
|
| 31 |
+
int shared_offset;
|
| 32 |
+
const int B_Y = blockDim.y;
|
| 33 |
+
const int n = N / B_Y;
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
extern __shared__ __nv_bfloat16 x_shared[];
|
| 37 |
+
__nv_bfloat16 *d_f_real_shared = &x_shared[N * N];
|
| 38 |
+
__nv_bfloat16 *d_f_imag_shared = &d_f_real_shared[N * N];
|
| 39 |
+
__nv_bfloat16 *twiddles_real_shared = &d_f_imag_shared[N * N];
|
| 40 |
+
__nv_bfloat16 *twiddles_imag_shared = &twiddles_real_shared[N * N];
|
| 41 |
+
float *out_real_shared = reinterpret_cast<float*>(&twiddles_imag_shared[N * N]);
|
| 42 |
+
float *out_imag_shared = &out_real_shared[N * N];
|
| 43 |
+
|
| 44 |
+
// #pragma unroll
|
| 45 |
+
for (int i = 0; i < n; i++)
|
| 46 |
+
{
|
| 47 |
+
idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 48 |
+
shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
|
| 49 |
+
reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
|
| 50 |
+
reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
|
| 51 |
+
|
| 52 |
+
// #pragma unroll
|
| 53 |
+
shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
|
| 54 |
+
reinterpret_cast<__nv_bfloat162 *>(d_f_real_shared)[shared_offset] = d_f_real[shared_offset];
|
| 55 |
+
reinterpret_cast<__nv_bfloat162 *>(d_f_imag_shared)[shared_offset] = d_f_imag[shared_offset];
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
float2 tmp_real, tmp_imag;
|
| 59 |
+
|
| 60 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[4];
|
| 61 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[4];
|
| 62 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[4];
|
| 63 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[4];
|
| 64 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag[4][4];
|
| 65 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[4];
|
| 66 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_imag[4];
|
| 67 |
+
|
| 68 |
+
__syncthreads();
|
| 69 |
+
|
| 70 |
+
for (int i = 0; i < 4; i++)
|
| 71 |
+
{
|
| 72 |
+
wmma::load_matrix_sync(a_frag_real[i], d_f_real_shared + i * N * 16 + threadIdx.y * 16, N);
|
| 73 |
+
wmma::load_matrix_sync(a_frag_imag[i], d_f_imag_shared + i * N * 16 + threadIdx.y * 16, N);
|
| 74 |
+
wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + threadIdx.y * N * 16 + i * 16, N);
|
| 75 |
+
wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + threadIdx.y * N * 16 + i * 16, N);
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
for (int t = 0; t < 16; t++)
|
| 79 |
+
{
|
| 80 |
+
|
| 81 |
+
for (int i = 0; i < n; i++)
|
| 82 |
+
{
|
| 83 |
+
idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
|
| 84 |
+
shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
|
| 85 |
+
if(x_gate != nullptr){
|
| 86 |
+
reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
|
| 87 |
+
}else{
|
| 88 |
+
reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = x[idx + offset];
|
| 89 |
+
}
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
__syncthreads();
|
| 93 |
+
|
| 94 |
+
for (int i = 0; i < 4; i++)
|
| 95 |
+
{
|
| 96 |
+
for (int j = 0; j < 4; j++)
|
| 97 |
+
{
|
| 98 |
+
wmma::load_matrix_sync(b_frag[i][j], x_shared + i * N * 16 + j * 16, N);
|
| 99 |
+
}
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
#pragma unroll
|
| 103 |
+
for (int j = 0; j < 4; j++)
|
| 104 |
+
{
|
| 105 |
+
wmma::fill_fragment(acc_frag_real[j], 0.0f);
|
| 106 |
+
|
| 107 |
+
for (int k = 0; k < 4; k++)
|
| 108 |
+
{
|
| 109 |
+
wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
|
| 110 |
+
}
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
#pragma unroll
|
| 114 |
+
|
| 115 |
+
for (int j = 0; j < 4; j++)
|
| 116 |
+
{
|
| 117 |
+
wmma::fill_fragment(acc_frag_imag[j], 0.0f);
|
| 118 |
+
|
| 119 |
+
for (int k = 0; k < 4; k++)
|
| 120 |
+
{
|
| 121 |
+
wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
|
| 122 |
+
}
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
#pragma unroll
|
| 126 |
+
for (int j = 0; j < 4; j++)
|
| 127 |
+
{
|
| 128 |
+
for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
|
| 129 |
+
{
|
| 130 |
+
tmp_real = reinterpret_cast<float2 *>(acc_frag_real[j].x)[k];
|
| 131 |
+
tmp_imag = reinterpret_cast<float2 *>(acc_frag_imag[j].x)[k];
|
| 132 |
+
|
| 133 |
+
reinterpret_cast<float2 *>(acc_frag_real[j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[j].x)[k]) - tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[j].x)[k]);
|
| 134 |
+
reinterpret_cast<float2 *>(acc_frag_imag[j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[j].x)[k]) + tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[j].x)[k]);
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
wmma::store_matrix_sync(out_real_shared + threadIdx.y * N * 16 + j * 16, acc_frag_real[j], N, wmma::mem_row_major);
|
| 138 |
+
wmma::store_matrix_sync(out_imag_shared + threadIdx.y * N * 16 + j * 16, acc_frag_imag[j], N, wmma::mem_row_major);
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
__syncthreads();
|
| 142 |
+
|
| 143 |
+
#pragma unroll
|
| 144 |
+
for (int i = 0; i < n; i++)
|
| 145 |
+
{
|
| 146 |
+
idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
|
| 147 |
+
out_real[idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
|
| 148 |
+
out_imag[idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
__syncthreads();
|
| 152 |
+
}
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
__global__ void butterfly_cuda_kernel_32(
|
| 156 |
+
const __nv_bfloat162 *__restrict__ x,
|
| 157 |
+
const __nv_bfloat162 *__restrict__ x_gate,
|
| 158 |
+
const __nv_bfloat16 *__restrict__ d_f_real,
|
| 159 |
+
const __nv_bfloat16 *__restrict__ d_f_imag,
|
| 160 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_real,
|
| 161 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
|
| 162 |
+
__nv_bfloat162 *__restrict__ out_real,
|
| 163 |
+
__nv_bfloat162 *__restrict__ out_imag,
|
| 164 |
+
uint B,
|
| 165 |
+
uint H,
|
| 166 |
+
int N)
|
| 167 |
+
{
|
| 168 |
+
const int offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 169 |
+
const int tw_offset = blockIdx.x * 32 + threadIdx.x;
|
| 170 |
+
int idx;
|
| 171 |
+
|
| 172 |
+
int shared_offset;
|
| 173 |
+
const int B_Y = blockDim.y;
|
| 174 |
+
const int n = N / B_Y;
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
__shared__ __nv_bfloat16 x_shared[32 * 64];
|
| 178 |
+
__shared__ __nv_bfloat16 d_f_real_shared[32 * 32];
|
| 179 |
+
__shared__ __nv_bfloat16 d_f_imag_shared[32 * 32];
|
| 180 |
+
__shared__ __nv_bfloat16 twiddles_real_shared[32 * 64];
|
| 181 |
+
__shared__ __nv_bfloat16 twiddles_imag_shared[32 * 64];
|
| 182 |
+
__shared__ float out_real_shared[32 * 64];
|
| 183 |
+
__shared__ float out_imag_shared[32 * 64];
|
| 184 |
+
|
| 185 |
+
// #pragma unroll
|
| 186 |
+
for (int i = 0; i < n; i++)
|
| 187 |
+
{
|
| 188 |
+
idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 189 |
+
shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
|
| 190 |
+
if(x_gate != nullptr){
|
| 191 |
+
reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
|
| 192 |
+
}else{
|
| 193 |
+
reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = x[idx + offset];
|
| 194 |
+
}
|
| 195 |
+
reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
|
| 196 |
+
reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
|
| 197 |
+
|
| 198 |
+
// #pragma unroll
|
| 199 |
+
d_f_real_shared[shared_offset] = d_f_real[shared_offset];
|
| 200 |
+
d_f_imag_shared[shared_offset] = d_f_imag[shared_offset];
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
__syncthreads();
|
| 204 |
+
|
| 205 |
+
if (threadIdx.y < N / 16)
|
| 206 |
+
{
|
| 207 |
+
float2 tmp_real, tmp_imag;
|
| 208 |
+
|
| 209 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[2][2];
|
| 210 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[2][2];
|
| 211 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[2][2];
|
| 212 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[2][2];
|
| 213 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag[2][2];
|
| 214 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[2][2];
|
| 215 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_imag[2][2];
|
| 216 |
+
|
| 217 |
+
int t = threadIdx.y * 32;
|
| 218 |
+
|
| 219 |
+
for (int i = 0; i < 2; i++)
|
| 220 |
+
{
|
| 221 |
+
for (int j = 0; j < 2; j++)
|
| 222 |
+
{
|
| 223 |
+
wmma::load_matrix_sync(a_frag_real[i][j], d_f_real_shared + j * N * 16 + i * 16, N);
|
| 224 |
+
wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag_shared + j * N * 16 + i * 16, N);
|
| 225 |
+
wmma::load_matrix_sync(b_frag[i][j], x_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 226 |
+
wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 227 |
+
wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 228 |
+
}
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
#pragma unroll
|
| 232 |
+
for (int i = 0; i < 2; i++)
|
| 233 |
+
{
|
| 234 |
+
for (int j = 0; j < 2; j++)
|
| 235 |
+
{
|
| 236 |
+
wmma::fill_fragment(acc_frag_real[i][j], 0.0f);
|
| 237 |
+
|
| 238 |
+
for (int k = 0; k < 2; k++)
|
| 239 |
+
{
|
| 240 |
+
wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag[k][j], acc_frag_real[i][j]);
|
| 241 |
+
}
|
| 242 |
+
}
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
#pragma unroll
|
| 246 |
+
for (int i = 0; i < 2; i++)
|
| 247 |
+
{
|
| 248 |
+
for (int j = 0; j < 2; j++)
|
| 249 |
+
{
|
| 250 |
+
wmma::fill_fragment(acc_frag_imag[i][j], 0.0f);
|
| 251 |
+
|
| 252 |
+
for (int k = 0; k < 2; k++)
|
| 253 |
+
{
|
| 254 |
+
wmma::mma_sync(acc_frag_imag[i][j], a_frag_imag[i][k], b_frag[k][j], acc_frag_imag[i][j]);
|
| 255 |
+
}
|
| 256 |
+
}
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
#pragma unroll
|
| 260 |
+
for (int i = 0; i < 2; i++)
|
| 261 |
+
{
|
| 262 |
+
for (int j = 0; j < 2; j++)
|
| 263 |
+
{
|
| 264 |
+
for (int k = 0; k < acc_frag_real[i][j].num_elements / 2; k++)
|
| 265 |
+
{
|
| 266 |
+
tmp_real = reinterpret_cast<float2 *>(acc_frag_real[i][j].x)[k];
|
| 267 |
+
tmp_imag = reinterpret_cast<float2 *>(acc_frag_imag[i][j].x)[k];
|
| 268 |
+
reinterpret_cast<float2 *>(acc_frag_real[i][j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[i][j].x)[k]) - tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[i][j].x)[k]);
|
| 269 |
+
reinterpret_cast<float2 *>(acc_frag_imag[i][j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[i][j].x)[k]) + tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[i][j].x)[k]);
|
| 270 |
+
}
|
| 271 |
+
wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
|
| 272 |
+
wmma::store_matrix_sync(out_imag_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_imag[i][j], 2 * N, wmma::mem_row_major);
|
| 273 |
+
}
|
| 274 |
+
}
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
__syncthreads();
|
| 278 |
+
|
| 279 |
+
#pragma unroll
|
| 280 |
+
for (int i = 0; i < n; i++)
|
| 281 |
+
{
|
| 282 |
+
idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 283 |
+
out_real[idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
|
| 284 |
+
out_imag[idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
|
| 285 |
+
}
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
__global__ void butterfly_cuda_kernel_128(
|
| 289 |
+
const __nv_bfloat162 *__restrict__ x,
|
| 290 |
+
const __nv_bfloat162 *__restrict__ x_gate,
|
| 291 |
+
const __nv_bfloat162 *__restrict__ d_f_real,
|
| 292 |
+
const __nv_bfloat162 *__restrict__ d_f_imag,
|
| 293 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_real,
|
| 294 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
|
| 295 |
+
__nv_bfloat162 *__restrict__ out_real,
|
| 296 |
+
__nv_bfloat162 *__restrict__ out_imag,
|
| 297 |
+
uint B,
|
| 298 |
+
uint H,
|
| 299 |
+
int N)
|
| 300 |
+
{
|
| 301 |
+
const int offset = blockIdx.y * H * 128 * 32 * 2 * gridDim.x + blockIdx.z * 16 * 128 * 32 * 2 * gridDim.x + blockIdx.x * 64 + threadIdx.x;
|
| 302 |
+
const int tw_offset = blockIdx.x * 64 + threadIdx.x;
|
| 303 |
+
int idx;
|
| 304 |
+
|
| 305 |
+
int shared_offset;
|
| 306 |
+
const int B_Y = blockDim.y;
|
| 307 |
+
const int n = N / B_Y;
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
extern __shared__ __nv_bfloat16 shared_real[];
|
| 311 |
+
__nv_bfloat16 *shared_imag = &shared_real[128 * 128];
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[8];
|
| 315 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[8];
|
| 316 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[8];
|
| 317 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[8];
|
| 318 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag[8][8];
|
| 319 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[8];
|
| 320 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_imag[8];
|
| 321 |
+
|
| 322 |
+
for (int i = 0; i < n; i++)
|
| 323 |
+
{
|
| 324 |
+
for(int j=0; j< 2; j++){
|
| 325 |
+
shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
|
| 326 |
+
reinterpret_cast<__nv_bfloat162 *>(shared_real)[shared_offset] = d_f_real[shared_offset];
|
| 327 |
+
reinterpret_cast<__nv_bfloat162 *>(shared_imag)[shared_offset] = d_f_imag[shared_offset];
|
| 328 |
+
}
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
__syncthreads();
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
for (int i = 0; i < 8; i++){
|
| 335 |
+
wmma::load_matrix_sync(a_frag_real[i], shared_real + i * 128 * 16 + threadIdx.y * 16, 128);
|
| 336 |
+
wmma::load_matrix_sync(a_frag_imag[i], shared_imag + i * 128 * 16 + threadIdx.y * 16, 128);
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
__syncthreads();
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
for (int i = 0; i < n; i++)
|
| 345 |
+
{
|
| 346 |
+
for(int j=0; j< 2; j++){
|
| 347 |
+
idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x;
|
| 348 |
+
shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
|
| 349 |
+
reinterpret_cast<__nv_bfloat162*>(shared_real)[shared_offset] = twiddle_factors_real[tw_offset + idx];
|
| 350 |
+
reinterpret_cast<__nv_bfloat162*>(shared_imag)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
|
| 351 |
+
}
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
__syncthreads();
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
for (int i = 0; i < 8; i++){
|
| 358 |
+
wmma::load_matrix_sync(tw_frag_real[i], shared_real + threadIdx.y * 128 * 16 + i * 16, 128);
|
| 359 |
+
wmma::load_matrix_sync(tw_frag_imag[i], shared_imag + threadIdx.y * 128 * 16 + i * 16, 128);
|
| 360 |
+
}
|
| 361 |
+
|
| 362 |
+
__syncthreads();
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
for(int t=0; t< 16; t++){
|
| 366 |
+
for (int i = 0; i < n; i++)
|
| 367 |
+
{
|
| 368 |
+
for(int j=0; j< 2; j++){
|
| 369 |
+
idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
|
| 370 |
+
shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
|
| 371 |
+
if(x_gate != nullptr){
|
| 372 |
+
reinterpret_cast<__nv_bfloat162*>(shared_real)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
|
| 373 |
+
}else{
|
| 374 |
+
reinterpret_cast<__nv_bfloat162*>(shared_real)[shared_offset] = x[offset + idx];
|
| 375 |
+
}
|
| 376 |
+
}
|
| 377 |
+
}
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
__syncthreads();
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
for (int i = 0; i < 8; i++)
|
| 384 |
+
{
|
| 385 |
+
for (int j = 0; j < 8; j++)
|
| 386 |
+
{
|
| 387 |
+
wmma::load_matrix_sync(b_frag[i][j], shared_real + i * 128 * 16 + j * 16, 128);
|
| 388 |
+
}
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
__syncthreads();
|
| 392 |
+
|
| 393 |
+
#pragma unroll
|
| 394 |
+
for (int j = 0; j < 8; j++)
|
| 395 |
+
{
|
| 396 |
+
wmma::fill_fragment(acc_frag_real[j], 0.0f);
|
| 397 |
+
|
| 398 |
+
for (int k = 0; k < 8; k++)
|
| 399 |
+
{
|
| 400 |
+
wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
|
| 401 |
+
}
|
| 402 |
+
}
|
| 403 |
+
|
| 404 |
+
#pragma unroll
|
| 405 |
+
|
| 406 |
+
for (int j = 0; j < 8; j++)
|
| 407 |
+
{
|
| 408 |
+
wmma::fill_fragment(acc_frag_imag[j], 0.0f);
|
| 409 |
+
|
| 410 |
+
for (int k = 0; k < 8; k++)
|
| 411 |
+
{
|
| 412 |
+
wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
|
| 413 |
+
}
|
| 414 |
+
}
|
| 415 |
+
|
| 416 |
+
float2 tmp_real, tmp_imag;
|
| 417 |
+
#pragma unroll
|
| 418 |
+
for (int j = 0; j < 8; j++)
|
| 419 |
+
{
|
| 420 |
+
for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
|
| 421 |
+
{
|
| 422 |
+
tmp_real = reinterpret_cast<float2 *>(acc_frag_real[j].x)[k];
|
| 423 |
+
tmp_imag = reinterpret_cast<float2 *>(acc_frag_imag[j].x)[k];
|
| 424 |
+
|
| 425 |
+
reinterpret_cast<float2 *>(acc_frag_real[j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[j].x)[k]) - tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[j].x)[k]);
|
| 426 |
+
reinterpret_cast<float2 *>(acc_frag_imag[j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[j].x)[k]) + tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[j].x)[k]);
|
| 427 |
+
}
|
| 428 |
+
}
|
| 429 |
+
|
| 430 |
+
for (int j = 0; j < 8; j++)
|
| 431 |
+
{
|
| 432 |
+
wmma::store_matrix_sync(reinterpret_cast<float*>(shared_real) + threadIdx.y * 128 * 16 + j * 16, acc_frag_real[j], 128, wmma::mem_row_major);
|
| 433 |
+
}
|
| 434 |
+
|
| 435 |
+
__syncthreads();
|
| 436 |
+
|
| 437 |
+
#pragma unroll
|
| 438 |
+
for (int i = 0; i < n; i++)
|
| 439 |
+
{
|
| 440 |
+
for(int j=0; j< 2; j++){
|
| 441 |
+
idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
|
| 442 |
+
shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
|
| 443 |
+
out_real[offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(shared_real)[shared_offset]);
|
| 444 |
+
}
|
| 445 |
+
}
|
| 446 |
+
|
| 447 |
+
__syncthreads();
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
for (int j = 0; j < 8; j++)
|
| 451 |
+
{
|
| 452 |
+
wmma::store_matrix_sync(reinterpret_cast<float*>(shared_real) + threadIdx.y * 128 * 16 + j * 16, acc_frag_imag[j], 128, wmma::mem_row_major);
|
| 453 |
+
}
|
| 454 |
+
|
| 455 |
+
__syncthreads();
|
| 456 |
+
|
| 457 |
+
#pragma unroll
|
| 458 |
+
for (int i = 0; i < n; i++)
|
| 459 |
+
{
|
| 460 |
+
for(int j=0; j< 2; j++){
|
| 461 |
+
idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
|
| 462 |
+
shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
|
| 463 |
+
out_imag[offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(shared_real)[shared_offset]);
|
| 464 |
+
}
|
| 465 |
+
}
|
| 466 |
+
}
|
| 467 |
+
}
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
__global__ void butterfly_cuda_kernel_16(
|
| 471 |
+
const __nv_bfloat162 *__restrict__ x,
|
| 472 |
+
const __nv_bfloat162 *__restrict__ x_gate,
|
| 473 |
+
const __nv_bfloat16 *__restrict__ d_f_real,
|
| 474 |
+
const __nv_bfloat16 *__restrict__ d_f_imag,
|
| 475 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_real,
|
| 476 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
|
| 477 |
+
__nv_bfloat162 *__restrict__ out_real,
|
| 478 |
+
__nv_bfloat162 *__restrict__ out_imag,
|
| 479 |
+
uint B,
|
| 480 |
+
uint H,
|
| 481 |
+
int N)
|
| 482 |
+
{
|
| 483 |
+
const int offset = blockIdx.y * H * 16 * 32 * gridDim.x + blockIdx.z * 16 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 484 |
+
const int tw_offset = blockIdx.x * 32 + threadIdx.x;
|
| 485 |
+
int idx;
|
| 486 |
+
|
| 487 |
+
int shared_offset;
|
| 488 |
+
const int B_Y = blockDim.y;
|
| 489 |
+
const int n = N / B_Y;
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
__shared__ __nv_bfloat16 x_shared[16 * 64];
|
| 493 |
+
__shared__ __nv_bfloat16 d_f_real_shared[16 * 16];
|
| 494 |
+
__shared__ __nv_bfloat16 d_f_imag_shared[16 * 16];
|
| 495 |
+
__shared__ __nv_bfloat16 twiddles_real_shared[16 * 64];
|
| 496 |
+
__shared__ __nv_bfloat16 twiddles_imag_shared[16 * 64];
|
| 497 |
+
__shared__ float out_real_shared[16 * 64];
|
| 498 |
+
__shared__ float out_imag_shared[16 * 64];
|
| 499 |
+
|
| 500 |
+
// #pragma unroll
|
| 501 |
+
for (int i = 0; i < n; i++)
|
| 502 |
+
{
|
| 503 |
+
idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 504 |
+
shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
|
| 505 |
+
if(x_gate != nullptr){
|
| 506 |
+
reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = __hmul2(x[idx + offset], x_gate[idx + offset]);
|
| 507 |
+
}else{
|
| 508 |
+
reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = x[idx + offset];
|
| 509 |
+
}
|
| 510 |
+
reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
|
| 511 |
+
reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
|
| 512 |
+
|
| 513 |
+
// #pragma unroll
|
| 514 |
+
if(threadIdx.x < 16 ){
|
| 515 |
+
shared_offset = (threadIdx.y + i * B_Y) * 16 + threadIdx.x;
|
| 516 |
+
d_f_real_shared[shared_offset] = d_f_real[shared_offset];
|
| 517 |
+
d_f_imag_shared[shared_offset] = d_f_imag[shared_offset];
|
| 518 |
+
}
|
| 519 |
+
}
|
| 520 |
+
|
| 521 |
+
__syncthreads();
|
| 522 |
+
|
| 523 |
+
if (threadIdx.y < 4)
|
| 524 |
+
{
|
| 525 |
+
float2 tmp_real, tmp_imag;
|
| 526 |
+
|
| 527 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real;
|
| 528 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real;
|
| 529 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag;
|
| 530 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag;
|
| 531 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag;
|
| 532 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real;
|
| 533 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_imag;
|
| 534 |
+
|
| 535 |
+
wmma::load_matrix_sync(a_frag_real, d_f_real_shared, N);
|
| 536 |
+
wmma::load_matrix_sync(a_frag_imag, d_f_imag_shared, N);
|
| 537 |
+
wmma::load_matrix_sync(b_frag, x_shared + threadIdx.y * 16, 64);
|
| 538 |
+
wmma::load_matrix_sync(tw_frag_real, twiddles_real_shared + threadIdx.y * 16, 64);
|
| 539 |
+
wmma::load_matrix_sync(tw_frag_imag, twiddles_imag_shared + threadIdx.y * 16, 64);
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
wmma::fill_fragment(acc_frag_real, 0.0f);
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
wmma::mma_sync(acc_frag_real, a_frag_real, b_frag, acc_frag_real);
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
wmma::fill_fragment(acc_frag_imag, 0.0f);
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
wmma::mma_sync(acc_frag_imag, a_frag_imag, b_frag, acc_frag_imag);
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
#pragma unroll
|
| 557 |
+
for (int k = 0; k < acc_frag_real.num_elements / 2; k++)
|
| 558 |
+
{
|
| 559 |
+
tmp_real = reinterpret_cast<float2 *>(acc_frag_real.x)[k];
|
| 560 |
+
tmp_imag = reinterpret_cast<float2 *>(acc_frag_imag.x)[k];
|
| 561 |
+
reinterpret_cast<float2 *>(acc_frag_real.x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real.x)[k]) - tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag.x)[k]);
|
| 562 |
+
reinterpret_cast<float2 *>(acc_frag_imag.x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag.x)[k]) + tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real.x)[k]);
|
| 563 |
+
}
|
| 564 |
+
wmma::store_matrix_sync(out_real_shared + threadIdx.y * 16, acc_frag_real, 64, wmma::mem_row_major);
|
| 565 |
+
wmma::store_matrix_sync(out_imag_shared + threadIdx.y * 16, acc_frag_imag, 64, wmma::mem_row_major);
|
| 566 |
+
|
| 567 |
+
}
|
| 568 |
+
__syncthreads();
|
| 569 |
+
|
| 570 |
+
#pragma unroll
|
| 571 |
+
for (int i = 0; i < n; i++)
|
| 572 |
+
{
|
| 573 |
+
idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 574 |
+
out_real[idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
|
| 575 |
+
out_imag[idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_imag_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
|
| 576 |
+
}
|
| 577 |
+
}
|
| 578 |
+
|
| 579 |
+
std::vector<torch::Tensor> butterfly_bf16_cuda(
|
| 580 |
+
torch::Tensor x,
|
| 581 |
+
torch::Tensor d_f_real,
|
| 582 |
+
torch::Tensor d_f_imag,
|
| 583 |
+
torch::Tensor twiddle_factors_real,
|
| 584 |
+
torch::Tensor twiddle_factors_imag,
|
| 585 |
+
std::optional<at::Tensor> x_gate = std::nullopt
|
| 586 |
+
)
|
| 587 |
+
{
|
| 588 |
+
|
| 589 |
+
uint B = x.size(0);
|
| 590 |
+
uint H = x.size(1);
|
| 591 |
+
// uint m = x.size(1);
|
| 592 |
+
|
| 593 |
+
// const int TILE_SIZE = 16;
|
| 594 |
+
uint N = x.size(2);
|
| 595 |
+
uint M = x.size(3);
|
| 596 |
+
dim3 gridDim;
|
| 597 |
+
dim3 blockDim;
|
| 598 |
+
|
| 599 |
+
gridDim.y = B;
|
| 600 |
+
gridDim.z = H;
|
| 601 |
+
|
| 602 |
+
torch::Tensor out_real = torch::empty({B, H, N, M}, x.options());
|
| 603 |
+
torch::Tensor out_imag = torch::empty({B, H, N, M}, x.options());
|
| 604 |
+
|
| 605 |
+
//set blockDims
|
| 606 |
+
switch(N){
|
| 607 |
+
case 128:
|
| 608 |
+
blockDim.x = 32;
|
| 609 |
+
blockDim.y = 8;
|
| 610 |
+
break;
|
| 611 |
+
default:
|
| 612 |
+
blockDim.x = 32;
|
| 613 |
+
blockDim.y = 4;
|
| 614 |
+
break;
|
| 615 |
+
}
|
| 616 |
+
|
| 617 |
+
//set gridDim.x
|
| 618 |
+
switch(N){
|
| 619 |
+
case 128:
|
| 620 |
+
switch (M){
|
| 621 |
+
case 16384:
|
| 622 |
+
gridDim.x = 128;
|
| 623 |
+
break;
|
| 624 |
+
case 8192:
|
| 625 |
+
gridDim.x = 64;
|
| 626 |
+
break;
|
| 627 |
+
case 4096:
|
| 628 |
+
gridDim.x = 32;
|
| 629 |
+
break;
|
| 630 |
+
default:
|
| 631 |
+
gridDim.x = 256;
|
| 632 |
+
break;
|
| 633 |
+
}
|
| 634 |
+
break;
|
| 635 |
+
default:
|
| 636 |
+
switch (M){
|
| 637 |
+
case 16384:
|
| 638 |
+
gridDim.x = 256;
|
| 639 |
+
break;
|
| 640 |
+
case 8192:
|
| 641 |
+
gridDim.x = 128;
|
| 642 |
+
break;
|
| 643 |
+
case 4096:
|
| 644 |
+
gridDim.x = 64;
|
| 645 |
+
break;
|
| 646 |
+
default:
|
| 647 |
+
gridDim.x = 512;
|
| 648 |
+
break;
|
| 649 |
+
}
|
| 650 |
+
break;
|
| 651 |
+
}
|
| 652 |
+
|
| 653 |
+
switch (N)
|
| 654 |
+
{
|
| 655 |
+
case 16:
|
| 656 |
+
butterfly_cuda_kernel_16<<<gridDim, blockDim>>>(
|
| 657 |
+
static_cast<__nv_bfloat162 *>(x.data_ptr()),
|
| 658 |
+
x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
|
| 659 |
+
static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
|
| 660 |
+
static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
|
| 661 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 662 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 663 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 664 |
+
static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
|
| 665 |
+
B,
|
| 666 |
+
H,
|
| 667 |
+
N);
|
| 668 |
+
break;
|
| 669 |
+
case 32:
|
| 670 |
+
butterfly_cuda_kernel_32<<<gridDim, blockDim>>>(
|
| 671 |
+
static_cast<__nv_bfloat162 *>(x.data_ptr()),
|
| 672 |
+
x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
|
| 673 |
+
static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
|
| 674 |
+
static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
|
| 675 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 676 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 677 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 678 |
+
static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
|
| 679 |
+
B,
|
| 680 |
+
H,
|
| 681 |
+
N);
|
| 682 |
+
break;
|
| 683 |
+
|
| 684 |
+
case 64:
|
| 685 |
+
gridDim.z = H / 16;
|
| 686 |
+
cudaFuncSetAttribute(&butterfly_cuda_kernel_64, cudaFuncAttributeMaxDynamicSharedMemorySize, 78000);
|
| 687 |
+
|
| 688 |
+
butterfly_cuda_kernel_64<<<gridDim, blockDim, 78000>>>(
|
| 689 |
+
static_cast<__nv_bfloat162 *>(x.data_ptr()),
|
| 690 |
+
x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
|
| 691 |
+
static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
|
| 692 |
+
static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
|
| 693 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 694 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 695 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 696 |
+
static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
|
| 697 |
+
B,
|
| 698 |
+
H,
|
| 699 |
+
N);
|
| 700 |
+
break;
|
| 701 |
+
case 128:
|
| 702 |
+
gridDim.z = H / 16;
|
| 703 |
+
cudaFuncSetAttribute(&butterfly_cuda_kernel_128, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 704 |
+
|
| 705 |
+
butterfly_cuda_kernel_128<<<gridDim, blockDim, 65536>>>(
|
| 706 |
+
static_cast<__nv_bfloat162 *>(x.data_ptr()),
|
| 707 |
+
x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
|
| 708 |
+
static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
|
| 709 |
+
static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
|
| 710 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 711 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 712 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 713 |
+
static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
|
| 714 |
+
B,
|
| 715 |
+
H,
|
| 716 |
+
N);
|
| 717 |
+
break;
|
| 718 |
+
|
| 719 |
+
default:
|
| 720 |
+
printf("Not yet implemented \n");
|
| 721 |
+
break;
|
| 722 |
+
}
|
| 723 |
+
|
| 724 |
+
return {out_real, out_imag};
|
| 725 |
+
}
|
overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_ifft_cuda.cu
ADDED
|
@@ -0,0 +1,723 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) 2023 Dan Fu, Hermann Kumbong
|
| 2 |
+
|
| 3 |
+
#include <torch/extension.h>
|
| 4 |
+
|
| 5 |
+
#include <vector>
|
| 6 |
+
#include <stdio.h>
|
| 7 |
+
#include <mma.h>
|
| 8 |
+
#include <cuda_fp16.h>
|
| 9 |
+
#include <cuda_bf16.h>
|
| 10 |
+
#include "shared.h"
|
| 11 |
+
|
| 12 |
+
using namespace nvcuda;
|
| 13 |
+
|
| 14 |
+
__global__ void butterfly_ifft_cuda_kernel_64(
|
| 15 |
+
const __half2 *__restrict__ x_real,
|
| 16 |
+
const __half2 *__restrict__ x_imag,
|
| 17 |
+
const complex_half_t *__restrict__ d_f,
|
| 18 |
+
const __half2 *__restrict__ twiddle_factors_real,
|
| 19 |
+
const __half2 *__restrict__ twiddle_factors_imag,
|
| 20 |
+
__half2 *__restrict__ out_real,
|
| 21 |
+
__half2 *__restrict__ out_gate,
|
| 22 |
+
uint B,
|
| 23 |
+
uint H,
|
| 24 |
+
int N)
|
| 25 |
+
{
|
| 26 |
+
const int offset = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * 16 * 64 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 27 |
+
const int tw_offset = blockIdx.x * 32 + threadIdx.x;
|
| 28 |
+
int idx;
|
| 29 |
+
int shared_offset;
|
| 30 |
+
const int B_Y = blockDim.y;
|
| 31 |
+
const int n = N / B_Y;
|
| 32 |
+
|
| 33 |
+
extern __shared__ half x_real_shared[];
|
| 34 |
+
half *x_imag_shared = &x_real_shared[N * N];
|
| 35 |
+
half *d_f_real = &x_imag_shared[N * N];
|
| 36 |
+
half *d_f_imag = &d_f_real[N * N];
|
| 37 |
+
half *twiddles_real_shared = &d_f_imag[N * N];
|
| 38 |
+
half *twiddles_imag_shared = &twiddles_real_shared[N * N];
|
| 39 |
+
half *out_real_shared = &twiddles_imag_shared[N * N];
|
| 40 |
+
|
| 41 |
+
half tmp_real, tmp_imag;
|
| 42 |
+
|
| 43 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[4][4];
|
| 44 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[4][4];
|
| 45 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real[4];
|
| 46 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag[4];
|
| 47 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real[4];
|
| 48 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag[4];
|
| 49 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[4];
|
| 50 |
+
|
| 51 |
+
// #pragma unroll
|
| 52 |
+
for (int i = 0; i < n; i++)
|
| 53 |
+
{
|
| 54 |
+
idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 55 |
+
shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
|
| 56 |
+
reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
|
| 57 |
+
reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
|
| 58 |
+
|
| 59 |
+
// #pragma unroll
|
| 60 |
+
shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x;
|
| 61 |
+
d_f_real[shared_offset] = d_f[shared_offset].real();
|
| 62 |
+
d_f_imag[shared_offset] = d_f[shared_offset].imag();
|
| 63 |
+
|
| 64 |
+
d_f_real[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].real();
|
| 65 |
+
d_f_imag[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].imag();
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
__syncthreads();
|
| 69 |
+
|
| 70 |
+
for (int i = 0; i < 4; i++)
|
| 71 |
+
{
|
| 72 |
+
#pragma unroll
|
| 73 |
+
for (int j = 0; j < 4; j++)
|
| 74 |
+
{
|
| 75 |
+
wmma::load_matrix_sync(a_frag_real[i][j], d_f_real + j * N * 16 + i * 16, N);
|
| 76 |
+
wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag + j * N * 16 + i * 16, N);
|
| 77 |
+
}
|
| 78 |
+
wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + i * N * 16 + threadIdx.y * 16, N);
|
| 79 |
+
wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + i * N * 16 + threadIdx.y * 16, N);
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
for (int t = 0; t < 16; t++)
|
| 83 |
+
{
|
| 84 |
+
|
| 85 |
+
for (int i = 0; i < n; i++)
|
| 86 |
+
{
|
| 87 |
+
idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
|
| 88 |
+
shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
|
| 89 |
+
reinterpret_cast<__half2 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
|
| 90 |
+
reinterpret_cast<__half2 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
__syncthreads();
|
| 94 |
+
|
| 95 |
+
for (int i = 0; i < 4; i++)
|
| 96 |
+
{
|
| 97 |
+
wmma::load_matrix_sync(b_frag_real[i], x_real_shared + i * N * 16 + threadIdx.y * 16, N);
|
| 98 |
+
wmma::load_matrix_sync(b_frag_imag[i], x_imag_shared + i * N * 16 + threadIdx.y * 16, N);
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
for (int j = 0; j < 4; j++)
|
| 102 |
+
{
|
| 103 |
+
for (int k = 0; k < tw_frag_real[j].num_elements; k++)
|
| 104 |
+
{
|
| 105 |
+
tmp_real = __hsub(__hmul(tw_frag_real[j].x[k], b_frag_real[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_imag[j].x[k]));
|
| 106 |
+
tmp_imag = __hadd(__hmul(tw_frag_real[j].x[k], b_frag_imag[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_real[j].x[k]));
|
| 107 |
+
b_frag_real[j].x[k] = tmp_real;
|
| 108 |
+
b_frag_imag[j].x[k] = tmp_imag;
|
| 109 |
+
}
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
for (int i = 0; i < 4; i++)
|
| 113 |
+
{
|
| 114 |
+
wmma::fill_fragment(acc_frag_real[i], __float2half(0.0f));
|
| 115 |
+
|
| 116 |
+
// bd
|
| 117 |
+
#pragma unroll
|
| 118 |
+
for (int k = 0; k < 4; k++)
|
| 119 |
+
{
|
| 120 |
+
wmma::mma_sync(acc_frag_real[i], a_frag_imag[i][k], b_frag_imag[k], acc_frag_real[i]);
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
for (int k = 0; k < acc_frag_real[i].num_elements; k++)
|
| 124 |
+
{
|
| 125 |
+
acc_frag_real[i].x[k] = __hneg(acc_frag_real[i].x[k]);
|
| 126 |
+
}
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
for (int i = 0; i < 4; i++)
|
| 130 |
+
{
|
| 131 |
+
// ac - bd
|
| 132 |
+
#pragma unroll
|
| 133 |
+
for (int k = 0; k < 4; k++)
|
| 134 |
+
{
|
| 135 |
+
wmma::mma_sync(acc_frag_real[i], a_frag_real[i][k], b_frag_real[k], acc_frag_real[i]);
|
| 136 |
+
}
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
#pragma unroll
|
| 140 |
+
for (int i = 0; i < 4; i++)
|
| 141 |
+
{
|
| 142 |
+
wmma::store_matrix_sync(out_real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
__syncthreads();
|
| 146 |
+
|
| 147 |
+
#pragma unroll
|
| 148 |
+
for (int i = 0; i < n; i++)
|
| 149 |
+
{
|
| 150 |
+
idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
|
| 151 |
+
if(out_gate != nullptr){
|
| 152 |
+
out_real[idx] = __hmul2(reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x], out_gate[idx]);
|
| 153 |
+
}
|
| 154 |
+
else{
|
| 155 |
+
out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
|
| 156 |
+
}
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
__syncthreads();
|
| 160 |
+
}
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
__global__ void butterfly_ifft_cuda_kernel_32(
|
| 164 |
+
const __half2 *__restrict__ x_real,
|
| 165 |
+
const __half2 *__restrict__ x_imag,
|
| 166 |
+
const complex_half_t *__restrict__ d_f,
|
| 167 |
+
const __half2 *__restrict__ twiddle_factors_real,
|
| 168 |
+
const __half2 *__restrict__ twiddle_factors_imag,
|
| 169 |
+
__half2 *__restrict__ out_real,
|
| 170 |
+
__half2 *__restrict__ out_gate,
|
| 171 |
+
uint B,
|
| 172 |
+
uint H,
|
| 173 |
+
int N)
|
| 174 |
+
{
|
| 175 |
+
const int offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 176 |
+
const int tw_offset = blockIdx.x * 32 + threadIdx.x;
|
| 177 |
+
int idx;
|
| 178 |
+
int shared_offset;
|
| 179 |
+
const int B_Y = blockDim.y;
|
| 180 |
+
const int n = N / B_Y;
|
| 181 |
+
|
| 182 |
+
__shared__ half x_real_shared[32 * 64];
|
| 183 |
+
__shared__ half x_imag_shared[32 * 64];
|
| 184 |
+
__shared__ half d_f_real[32 * 32];
|
| 185 |
+
__shared__ half d_f_imag[32 * 32];
|
| 186 |
+
__shared__ half twiddles_real_shared[32 * 64];
|
| 187 |
+
__shared__ half twiddles_imag_shared[32 * 64];
|
| 188 |
+
__shared__ half out_real_shared[32 * 64];
|
| 189 |
+
|
| 190 |
+
// #pragma unroll
|
| 191 |
+
for (int i = 0; i < n; i++)
|
| 192 |
+
{
|
| 193 |
+
idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 194 |
+
shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
|
| 195 |
+
reinterpret_cast<__half2 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
|
| 196 |
+
reinterpret_cast<__half2 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
|
| 197 |
+
reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
|
| 198 |
+
reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
|
| 199 |
+
|
| 200 |
+
// #pragma unroll
|
| 201 |
+
d_f_real[shared_offset] = d_f[shared_offset].real();
|
| 202 |
+
d_f_imag[shared_offset] = d_f[shared_offset].imag();
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
__syncthreads();
|
| 206 |
+
|
| 207 |
+
if (threadIdx.y < N / 16)
|
| 208 |
+
{
|
| 209 |
+
half tmp_real, tmp_imag;
|
| 210 |
+
|
| 211 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[2][2];
|
| 212 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[2][2];
|
| 213 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real[2][2];
|
| 214 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag[2][2];
|
| 215 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real[2][2];
|
| 216 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag[2][2];
|
| 217 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[2][2];
|
| 218 |
+
|
| 219 |
+
int t = threadIdx.y * 32;
|
| 220 |
+
|
| 221 |
+
for (int i = 0; i < 2; i++)
|
| 222 |
+
{
|
| 223 |
+
for (int j = 0; j < 2; j++)
|
| 224 |
+
{
|
| 225 |
+
wmma::load_matrix_sync(a_frag_real[i][j], d_f_real + j * N * 16 + i * 16, N);
|
| 226 |
+
wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag + j * N * 16 + i * 16, N);
|
| 227 |
+
wmma::load_matrix_sync(b_frag_real[i][j], x_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 228 |
+
wmma::load_matrix_sync(b_frag_imag[i][j], x_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 229 |
+
wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 230 |
+
wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 231 |
+
}
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
for (int i = 0; i < 2; i++)
|
| 235 |
+
{
|
| 236 |
+
for (int j = 0; j < 2; j++)
|
| 237 |
+
{
|
| 238 |
+
for (int k = 0; k < tw_frag_real[i][j].num_elements; k++)
|
| 239 |
+
{
|
| 240 |
+
tmp_real = __hsub(__hmul(tw_frag_real[i][j].x[k], b_frag_real[i][j].x[k]), __hmul(tw_frag_imag[i][j].x[k], b_frag_imag[i][j].x[k]));
|
| 241 |
+
tmp_imag = __hadd(__hmul(tw_frag_real[i][j].x[k], b_frag_imag[i][j].x[k]), __hmul(tw_frag_imag[i][j].x[k], b_frag_real[i][j].x[k]));
|
| 242 |
+
b_frag_real[i][j].x[k] = tmp_real;
|
| 243 |
+
b_frag_imag[i][j].x[k] = tmp_imag;
|
| 244 |
+
}
|
| 245 |
+
}
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
for (int i = 0; i < 2; i++)
|
| 249 |
+
{
|
| 250 |
+
for (int j = 0; j < 2; j++)
|
| 251 |
+
{
|
| 252 |
+
wmma::fill_fragment(acc_frag_real[i][j], __float2half(0.0f));
|
| 253 |
+
|
| 254 |
+
// bd
|
| 255 |
+
for (int k = 0; k < 2; k++)
|
| 256 |
+
{
|
| 257 |
+
wmma::mma_sync(acc_frag_real[i][j], a_frag_imag[i][k], b_frag_imag[k][j], acc_frag_real[i][j]);
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
for (int k = 0; k < acc_frag_real[i][j].num_elements; k++)
|
| 261 |
+
{
|
| 262 |
+
acc_frag_real[i][j].x[k] = __hneg(acc_frag_real[i][j].x[k]);
|
| 263 |
+
}
|
| 264 |
+
}
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
for (int i = 0; i < 2; i++)
|
| 268 |
+
{
|
| 269 |
+
for (int j = 0; j < 2; j++)
|
| 270 |
+
{
|
| 271 |
+
// ac - bd
|
| 272 |
+
for (int k = 0; k < 2; k++)
|
| 273 |
+
{
|
| 274 |
+
wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag_real[k][j], acc_frag_real[i][j]);
|
| 275 |
+
}
|
| 276 |
+
}
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
for (int i = 0; i < 2; i++)
|
| 280 |
+
{
|
| 281 |
+
for (int j = 0; j < 2; j++)
|
| 282 |
+
{
|
| 283 |
+
wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
|
| 284 |
+
}
|
| 285 |
+
}
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
__syncthreads();
|
| 289 |
+
|
| 290 |
+
#pragma unroll
|
| 291 |
+
for (int i = 0; i < n; i++)
|
| 292 |
+
{
|
| 293 |
+
idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 294 |
+
if(out_gate != nullptr){
|
| 295 |
+
out_real[idx] = __hmul2(reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x], out_gate[idx]);
|
| 296 |
+
}
|
| 297 |
+
else{
|
| 298 |
+
out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
|
| 299 |
+
}
|
| 300 |
+
}
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
__global__ void butterfly_ifft_cuda_kernel_128(
|
| 305 |
+
const __half2 *__restrict__ x_real,
|
| 306 |
+
const __half2 *__restrict__ x_imag,
|
| 307 |
+
const complex_half_t *__restrict__ d_f,
|
| 308 |
+
const __half2 *__restrict__ twiddle_factors_real,
|
| 309 |
+
const __half2 *__restrict__ twiddle_factors_imag,
|
| 310 |
+
__half2 *__restrict__ out_real,
|
| 311 |
+
__half2 *__restrict__ out_gate,
|
| 312 |
+
uint B,
|
| 313 |
+
uint H,
|
| 314 |
+
int N)
|
| 315 |
+
{
|
| 316 |
+
const int offset = blockIdx.y * H * 128 * 32 * 2 * gridDim.x + blockIdx.z * 16 * 128 * 32 * 2 * gridDim.x + blockIdx.x * 64 + threadIdx.x;
|
| 317 |
+
const int tw_offset = blockIdx.x * 64 + threadIdx.x;
|
| 318 |
+
int idx;
|
| 319 |
+
int shared_offset;
|
| 320 |
+
|
| 321 |
+
const int B_Y = 8;
|
| 322 |
+
const int n = 16;
|
| 323 |
+
|
| 324 |
+
extern __shared__ half real_shared[];
|
| 325 |
+
half *imag_shared = &real_shared[128 * 128];
|
| 326 |
+
half *real_shared_2 = &imag_shared[128 * 128];
|
| 327 |
+
half *imag_shared_2 = &real_shared_2[128 * 128];
|
| 328 |
+
|
| 329 |
+
__half2 tmp_real, tmp_imag;
|
| 330 |
+
|
| 331 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag[8][8];
|
| 332 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real[8];
|
| 333 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag[8];
|
| 334 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real[8];
|
| 335 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag[8];
|
| 336 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[8];
|
| 337 |
+
|
| 338 |
+
for (int i = 0; i < n; i++)
|
| 339 |
+
{
|
| 340 |
+
for(int j=0; j< 4; j++){
|
| 341 |
+
shared_offset = (threadIdx.y + i * B_Y) * 128 + threadIdx.x + j * blockDim.x;
|
| 342 |
+
real_shared_2[shared_offset] = d_f[shared_offset].real();
|
| 343 |
+
imag_shared_2[shared_offset] = d_f[shared_offset].imag();
|
| 344 |
+
}
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
__syncthreads();
|
| 349 |
+
|
| 350 |
+
for (int i = 0; i < n; i++)
|
| 351 |
+
{
|
| 352 |
+
for(int j=0; j< 2; j++){
|
| 353 |
+
idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x;
|
| 354 |
+
shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
|
| 355 |
+
reinterpret_cast<__half2*>(real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
|
| 356 |
+
reinterpret_cast<__half2*>(imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
|
| 357 |
+
}
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
__syncthreads();
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
for (int i = 0; i < 8; i++){
|
| 364 |
+
wmma::load_matrix_sync(tw_frag_real[i], real_shared + i * 128 * 16 + threadIdx.y * 16, 128);
|
| 365 |
+
wmma::load_matrix_sync(tw_frag_imag[i], imag_shared + i * 128 * 16 + threadIdx.y * 16, 128);
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
__syncthreads();
|
| 369 |
+
|
| 370 |
+
for (int t = 0; t < 16; t++)
|
| 371 |
+
{
|
| 372 |
+
|
| 373 |
+
for (int i = 0; i < n; i++)
|
| 374 |
+
{
|
| 375 |
+
for(int j=0; j< 2; j++){
|
| 376 |
+
idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
|
| 377 |
+
shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
|
| 378 |
+
reinterpret_cast<__half2*>(real_shared)[shared_offset] = x_real[offset + idx];
|
| 379 |
+
reinterpret_cast<__half2*>(imag_shared)[shared_offset] = x_imag[offset + idx];
|
| 380 |
+
}
|
| 381 |
+
}
|
| 382 |
+
|
| 383 |
+
__syncthreads();
|
| 384 |
+
|
| 385 |
+
for (int i = 0; i < 8; i++)
|
| 386 |
+
{
|
| 387 |
+
wmma::load_matrix_sync(b_frag_real[i], real_shared + i * N * 16 + threadIdx.y * 16, N);
|
| 388 |
+
wmma::load_matrix_sync(b_frag_imag[i], imag_shared + i * N * 16 + threadIdx.y * 16, N);
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
for (int j = 0; j < 8; j++)
|
| 393 |
+
{
|
| 394 |
+
for (int k = 0; k < tw_frag_real[j].num_elements/2; k++)
|
| 395 |
+
{
|
| 396 |
+
tmp_real = __hsub2(__hmul2(reinterpret_cast<__half2*>(tw_frag_real[j].x)[k], reinterpret_cast<__half2*>(b_frag_real[j].x)[k]),
|
| 397 |
+
__hmul2(reinterpret_cast<__half2*>(tw_frag_imag[j].x)[k], reinterpret_cast<__half2*>(b_frag_imag[j].x)[k]));
|
| 398 |
+
tmp_imag = __hadd2(__hmul2(reinterpret_cast<__half2*>(tw_frag_real[j].x)[k], reinterpret_cast<__half2*>(b_frag_imag[j].x)[k]),
|
| 399 |
+
__hmul2(reinterpret_cast<__half2*>(tw_frag_imag[j].x)[k], reinterpret_cast<__half2*>(b_frag_real[j].x)[k]));
|
| 400 |
+
reinterpret_cast<__half2*>(b_frag_real[j].x)[k] = tmp_real;
|
| 401 |
+
reinterpret_cast<__half2*>(b_frag_imag[j].x)[k] = tmp_imag;
|
| 402 |
+
}
|
| 403 |
+
}
|
| 404 |
+
|
| 405 |
+
for (int i = 0; i < 8; i++){
|
| 406 |
+
for (int j = 0; j < 8; j++){
|
| 407 |
+
wmma::load_matrix_sync(a_frag[i][j], imag_shared_2 + j * 128 * 16 + i * 16, 128);
|
| 408 |
+
}
|
| 409 |
+
}
|
| 410 |
+
|
| 411 |
+
__syncthreads();
|
| 412 |
+
|
| 413 |
+
for (int i = 0; i < 8; i++)
|
| 414 |
+
{
|
| 415 |
+
wmma::fill_fragment(acc_frag_real[i], __float2half(0.0f));
|
| 416 |
+
|
| 417 |
+
// bd
|
| 418 |
+
#pragma unroll
|
| 419 |
+
for (int k = 0; k < 8; k++)
|
| 420 |
+
{
|
| 421 |
+
wmma::mma_sync(acc_frag_real[i], a_frag[i][k], b_frag_imag[k], acc_frag_real[i]);
|
| 422 |
+
}
|
| 423 |
+
|
| 424 |
+
for (int k = 0; k < acc_frag_real[i].num_elements; k++)
|
| 425 |
+
{
|
| 426 |
+
acc_frag_real[i].x[k] = __hneg(acc_frag_real[i].x[k]);
|
| 427 |
+
}
|
| 428 |
+
}
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
for (int i = 0; i < 8; i++){
|
| 432 |
+
for (int j = 0; j < 8; j++){
|
| 433 |
+
wmma::load_matrix_sync(a_frag[i][j], real_shared_2 + j * 128 * 16 + i * 16, 128);
|
| 434 |
+
}
|
| 435 |
+
}
|
| 436 |
+
|
| 437 |
+
__syncthreads();
|
| 438 |
+
|
| 439 |
+
for (int i = 0; i < 8; i++)
|
| 440 |
+
{
|
| 441 |
+
// ac - bd
|
| 442 |
+
#pragma unroll
|
| 443 |
+
for (int k = 0; k < 8; k++)
|
| 444 |
+
{
|
| 445 |
+
wmma::mma_sync(acc_frag_real[i], a_frag[i][k], b_frag_real[k], acc_frag_real[i]);
|
| 446 |
+
}
|
| 447 |
+
}
|
| 448 |
+
|
| 449 |
+
#pragma unroll
|
| 450 |
+
for (int i = 0; i < 8; i++)
|
| 451 |
+
{
|
| 452 |
+
wmma::store_matrix_sync(real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
|
| 453 |
+
}
|
| 454 |
+
|
| 455 |
+
__syncthreads();
|
| 456 |
+
|
| 457 |
+
#pragma unroll
|
| 458 |
+
for (int i = 0; i < n; i++)
|
| 459 |
+
{
|
| 460 |
+
for(int j=0; j< 2; j++){
|
| 461 |
+
idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
|
| 462 |
+
shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
|
| 463 |
+
if(out_gate != nullptr){
|
| 464 |
+
out_real[offset + idx] = __hmul2(reinterpret_cast<__half2*>(real_shared)[shared_offset], out_gate[offset + idx]);
|
| 465 |
+
}
|
| 466 |
+
else{
|
| 467 |
+
out_real[offset + idx] = reinterpret_cast<__half2*>(real_shared)[shared_offset];
|
| 468 |
+
}
|
| 469 |
+
}
|
| 470 |
+
}
|
| 471 |
+
|
| 472 |
+
__syncthreads();
|
| 473 |
+
}
|
| 474 |
+
}
|
| 475 |
+
|
| 476 |
+
__global__ void butterfly_ifft_cuda_kernel_16(
|
| 477 |
+
const __half2 *__restrict__ x_real,
|
| 478 |
+
const __half2 *__restrict__ x_imag,
|
| 479 |
+
const complex_half_t *__restrict__ d_f,
|
| 480 |
+
const __half2 *__restrict__ twiddle_factors_real,
|
| 481 |
+
const __half2 *__restrict__ twiddle_factors_imag,
|
| 482 |
+
__half2 *__restrict__ out_real,
|
| 483 |
+
__half2 *__restrict__ out_gate,
|
| 484 |
+
uint B,
|
| 485 |
+
uint H,
|
| 486 |
+
int N)
|
| 487 |
+
{
|
| 488 |
+
const int offset = blockIdx.y * H * 16 * 32 * gridDim.x + blockIdx.z * 16 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 489 |
+
const int tw_offset = blockIdx.x * 32 + threadIdx.x;
|
| 490 |
+
int idx;
|
| 491 |
+
int shared_offset;
|
| 492 |
+
const int B_Y = blockDim.y;
|
| 493 |
+
const int n = N / B_Y;
|
| 494 |
+
|
| 495 |
+
__shared__ half x_real_shared[16 * 64];
|
| 496 |
+
__shared__ half x_imag_shared[16 * 64];
|
| 497 |
+
__shared__ half d_f_real[16 * 16];
|
| 498 |
+
__shared__ half d_f_imag[16 * 16];
|
| 499 |
+
__shared__ half twiddles_real_shared[16 * 64];
|
| 500 |
+
__shared__ half twiddles_imag_shared[16 * 64];
|
| 501 |
+
__shared__ half out_real_shared[16 * 64];
|
| 502 |
+
|
| 503 |
+
// #pragma unroll
|
| 504 |
+
for (int i = 0; i < n; i++)
|
| 505 |
+
{
|
| 506 |
+
idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 507 |
+
shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
|
| 508 |
+
reinterpret_cast<__half2 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
|
| 509 |
+
reinterpret_cast<__half2 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
|
| 510 |
+
reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
|
| 511 |
+
reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
|
| 512 |
+
|
| 513 |
+
if(threadIdx.x < 16 ){
|
| 514 |
+
shared_offset = (threadIdx.y + i * B_Y) * 16 + threadIdx.x;
|
| 515 |
+
d_f_real[shared_offset] = d_f[shared_offset].real();
|
| 516 |
+
d_f_imag[shared_offset] = d_f[shared_offset].imag();
|
| 517 |
+
}
|
| 518 |
+
}
|
| 519 |
+
|
| 520 |
+
__syncthreads();
|
| 521 |
+
|
| 522 |
+
//check if it is better to have one warp do all the multiplication or split between warps
|
| 523 |
+
if (threadIdx.y < 4)
|
| 524 |
+
{
|
| 525 |
+
half tmp_real, tmp_imag;
|
| 526 |
+
|
| 527 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real;
|
| 528 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag;
|
| 529 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real;
|
| 530 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag;
|
| 531 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real;
|
| 532 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag;
|
| 533 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real;
|
| 534 |
+
|
| 535 |
+
wmma::load_matrix_sync(a_frag_real, d_f_real, N);
|
| 536 |
+
wmma::load_matrix_sync(a_frag_imag, d_f_imag, N);
|
| 537 |
+
wmma::load_matrix_sync(b_frag_real, x_real_shared + threadIdx.y * 16, 64);
|
| 538 |
+
wmma::load_matrix_sync(b_frag_imag, x_imag_shared + threadIdx.y * 16, 64);
|
| 539 |
+
wmma::load_matrix_sync(tw_frag_real, twiddles_real_shared + threadIdx.y * 16, 64);
|
| 540 |
+
wmma::load_matrix_sync(tw_frag_imag, twiddles_imag_shared + threadIdx.y * 16, 64);
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
for (int k = 0; k < tw_frag_real.num_elements; k++)
|
| 545 |
+
{
|
| 546 |
+
tmp_real = __hsub(__hmul(tw_frag_real.x[k], b_frag_real.x[k]), __hmul(tw_frag_imag.x[k], b_frag_imag.x[k]));
|
| 547 |
+
tmp_imag = __hadd(__hmul(tw_frag_real.x[k], b_frag_imag.x[k]), __hmul(tw_frag_imag.x[k], b_frag_real.x[k]));
|
| 548 |
+
b_frag_real.x[k] = tmp_real;
|
| 549 |
+
b_frag_imag.x[k] = tmp_imag;
|
| 550 |
+
}
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
wmma::fill_fragment(acc_frag_real, __float2half(0.0f));
|
| 554 |
+
|
| 555 |
+
wmma::mma_sync(acc_frag_real, a_frag_imag, b_frag_imag, acc_frag_real);
|
| 556 |
+
|
| 557 |
+
for(int k=0; k< acc_frag_real.num_elements; k++){
|
| 558 |
+
acc_frag_real.x[k] = __hneg(acc_frag_real.x[k]);
|
| 559 |
+
}
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
wmma::mma_sync(acc_frag_real, a_frag_real, b_frag_real, acc_frag_real);
|
| 563 |
+
|
| 564 |
+
wmma::store_matrix_sync(out_real_shared + threadIdx.y * 16, acc_frag_real, 64, wmma::mem_row_major);
|
| 565 |
+
|
| 566 |
+
}
|
| 567 |
+
|
| 568 |
+
__syncthreads();
|
| 569 |
+
|
| 570 |
+
#pragma unroll
|
| 571 |
+
for (int i = 0; i < n; i++)
|
| 572 |
+
{
|
| 573 |
+
idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 574 |
+
if(out_gate != nullptr){
|
| 575 |
+
out_real[idx] = __hmul2(reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x], out_gate[idx]);
|
| 576 |
+
}
|
| 577 |
+
else{
|
| 578 |
+
out_real[idx] = reinterpret_cast<__half2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x];
|
| 579 |
+
}
|
| 580 |
+
}
|
| 581 |
+
}
|
| 582 |
+
|
| 583 |
+
torch::Tensor butterfly_ifft_cuda(
|
| 584 |
+
torch::Tensor x_real,
|
| 585 |
+
torch::Tensor x_imag,
|
| 586 |
+
torch::Tensor d_f,
|
| 587 |
+
torch::Tensor twiddle_factors_real,
|
| 588 |
+
torch::Tensor twiddle_factors_imag,
|
| 589 |
+
std::optional<at::Tensor> out_gate = std::nullopt)
|
| 590 |
+
{
|
| 591 |
+
|
| 592 |
+
uint B = x_real.size(0);
|
| 593 |
+
uint H = x_real.size(1);
|
| 594 |
+
// uint m = x.size(1);
|
| 595 |
+
|
| 596 |
+
// const int TILE_SIZE = 16;
|
| 597 |
+
|
| 598 |
+
dim3 gridDim;
|
| 599 |
+
dim3 blockDim;
|
| 600 |
+
|
| 601 |
+
uint N = x_real.size(2);
|
| 602 |
+
uint M = x_real.size(3);
|
| 603 |
+
gridDim.y = B;
|
| 604 |
+
|
| 605 |
+
blockDim.x = 32;
|
| 606 |
+
blockDim.y = 4;
|
| 607 |
+
|
| 608 |
+
torch::Tensor out = torch::empty({B, H, N, M}, x_real.options());
|
| 609 |
+
gridDim.z = H;
|
| 610 |
+
|
| 611 |
+
//set blockDims
|
| 612 |
+
switch(N){
|
| 613 |
+
case 128:
|
| 614 |
+
blockDim.x = 32;
|
| 615 |
+
blockDim.y = 8;
|
| 616 |
+
break;
|
| 617 |
+
default:
|
| 618 |
+
blockDim.x = 32;
|
| 619 |
+
blockDim.y = 4;
|
| 620 |
+
break;
|
| 621 |
+
}
|
| 622 |
+
|
| 623 |
+
//set gridDim.x
|
| 624 |
+
switch(N){
|
| 625 |
+
case 128:
|
| 626 |
+
switch (M){
|
| 627 |
+
case 16384:
|
| 628 |
+
gridDim.x = 128;
|
| 629 |
+
break;
|
| 630 |
+
case 8192:
|
| 631 |
+
gridDim.x = 64;
|
| 632 |
+
break;
|
| 633 |
+
case 4096:
|
| 634 |
+
gridDim.x = 32;
|
| 635 |
+
break;
|
| 636 |
+
default:
|
| 637 |
+
gridDim.x = 256;
|
| 638 |
+
break;
|
| 639 |
+
}
|
| 640 |
+
break;
|
| 641 |
+
default:
|
| 642 |
+
switch (M){
|
| 643 |
+
case 16384:
|
| 644 |
+
gridDim.x = 256;
|
| 645 |
+
break;
|
| 646 |
+
case 8192:
|
| 647 |
+
gridDim.x = 128;
|
| 648 |
+
break;
|
| 649 |
+
case 4096:
|
| 650 |
+
gridDim.x = 64;
|
| 651 |
+
break;
|
| 652 |
+
default:
|
| 653 |
+
gridDim.x = 512;
|
| 654 |
+
break;
|
| 655 |
+
}
|
| 656 |
+
break;
|
| 657 |
+
}
|
| 658 |
+
|
| 659 |
+
switch (N)
|
| 660 |
+
{
|
| 661 |
+
case 16:
|
| 662 |
+
butterfly_ifft_cuda_kernel_16<<<gridDim, blockDim>>>(
|
| 663 |
+
static_cast<__half2 *>(x_real.data_ptr()),
|
| 664 |
+
static_cast<__half2 *>(x_imag.data_ptr()),
|
| 665 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 666 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 667 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 668 |
+
static_cast<__half2 *>(out.data_ptr()),
|
| 669 |
+
out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
|
| 670 |
+
B,
|
| 671 |
+
H,
|
| 672 |
+
N);
|
| 673 |
+
break;
|
| 674 |
+
case 32:
|
| 675 |
+
butterfly_ifft_cuda_kernel_32<<<gridDim, blockDim>>>(
|
| 676 |
+
static_cast<__half2 *>(x_real.data_ptr()),
|
| 677 |
+
static_cast<__half2 *>(x_imag.data_ptr()),
|
| 678 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 679 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 680 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 681 |
+
static_cast<__half2 *>(out.data_ptr()),
|
| 682 |
+
out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
|
| 683 |
+
B,
|
| 684 |
+
H,
|
| 685 |
+
N);
|
| 686 |
+
break;
|
| 687 |
+
case 64:
|
| 688 |
+
gridDim.z = H / 16;
|
| 689 |
+
cudaFuncSetAttribute(&butterfly_ifft_cuda_kernel_64, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 690 |
+
butterfly_ifft_cuda_kernel_64<<<gridDim, blockDim, 8 * N * N * sizeof(half)>>>(
|
| 691 |
+
static_cast<__half2 *>(x_real.data_ptr()),
|
| 692 |
+
static_cast<__half2 *>(x_imag.data_ptr()),
|
| 693 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 694 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 695 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 696 |
+
static_cast<__half2 *>(out.data_ptr()),
|
| 697 |
+
out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
|
| 698 |
+
B,
|
| 699 |
+
H,
|
| 700 |
+
N);
|
| 701 |
+
break;
|
| 702 |
+
|
| 703 |
+
case 128:
|
| 704 |
+
gridDim.z = H / 16;
|
| 705 |
+
cudaFuncSetAttribute(&butterfly_ifft_cuda_kernel_128, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536*2);
|
| 706 |
+
butterfly_ifft_cuda_kernel_128<<<gridDim, blockDim, 65536*2>>>(
|
| 707 |
+
static_cast<__half2 *>(x_real.data_ptr()),
|
| 708 |
+
static_cast<__half2 *>(x_imag.data_ptr()),
|
| 709 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 710 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 711 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 712 |
+
static_cast<__half2 *>(out.data_ptr()),
|
| 713 |
+
out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
|
| 714 |
+
B,
|
| 715 |
+
H,
|
| 716 |
+
N);
|
| 717 |
+
break;
|
| 718 |
+
default:
|
| 719 |
+
printf("Not implemented\n");
|
| 720 |
+
}
|
| 721 |
+
|
| 722 |
+
return out;
|
| 723 |
+
}
|
overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_ifft_cuda_bf16.cu
ADDED
|
@@ -0,0 +1,705 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) 2023 Dan Fu, Hermann Kumbong
|
| 2 |
+
|
| 3 |
+
#include <torch/extension.h>
|
| 4 |
+
|
| 5 |
+
#include <vector>
|
| 6 |
+
#include <stdio.h>
|
| 7 |
+
#include <mma.h>
|
| 8 |
+
#include <cuda_fp16.h>
|
| 9 |
+
#include <cuda_bf16.h>
|
| 10 |
+
#include <cuda_runtime.h>
|
| 11 |
+
#include "shared.h"
|
| 12 |
+
|
| 13 |
+
using namespace nvcuda;
|
| 14 |
+
|
| 15 |
+
__global__ void butterfly_ifft_bf16_cuda_kernel_64(
|
| 16 |
+
const __nv_bfloat162 *__restrict__ x_real,
|
| 17 |
+
const __nv_bfloat162 *__restrict__ x_imag,
|
| 18 |
+
const __nv_bfloat162 *__restrict__ d_f_real,
|
| 19 |
+
const __nv_bfloat162 *__restrict__ d_f_imag,
|
| 20 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_real,
|
| 21 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
|
| 22 |
+
__nv_bfloat162 *__restrict__ out_real,
|
| 23 |
+
__nv_bfloat162 *__restrict__ out_gate,
|
| 24 |
+
uint B,
|
| 25 |
+
uint H,
|
| 26 |
+
int N)
|
| 27 |
+
{
|
| 28 |
+
const int offset = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * 16 * 64 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 29 |
+
const int tw_offset = blockIdx.x * 32 + threadIdx.x;
|
| 30 |
+
int idx;
|
| 31 |
+
int shared_offset;
|
| 32 |
+
const int B_Y = blockDim.y;
|
| 33 |
+
const int n = N / B_Y;
|
| 34 |
+
|
| 35 |
+
extern __shared__ __nv_bfloat16 x_real_shared[];
|
| 36 |
+
__nv_bfloat16 *x_imag_shared = &x_real_shared[N * N];
|
| 37 |
+
__nv_bfloat16 *d_f_real_shared = &x_imag_shared[N * N];
|
| 38 |
+
__nv_bfloat16 *d_f_imag_shared = &d_f_real_shared[N * N];
|
| 39 |
+
__nv_bfloat16 *twiddles_real_shared = &d_f_imag_shared[N * N];
|
| 40 |
+
__nv_bfloat16 *twiddles_imag_shared = &twiddles_real_shared[N * N];
|
| 41 |
+
float *out_real_shared = reinterpret_cast<float*>(&twiddles_imag_shared[N * N]);
|
| 42 |
+
|
| 43 |
+
__nv_bfloat16 tmp_real, tmp_imag;
|
| 44 |
+
|
| 45 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[4][4];
|
| 46 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[4][4];
|
| 47 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[4];
|
| 48 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[4];
|
| 49 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_real[4];
|
| 50 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_imag[4];
|
| 51 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[4];
|
| 52 |
+
|
| 53 |
+
// #pragma unroll
|
| 54 |
+
for (int i = 0; i < n; i++)
|
| 55 |
+
{
|
| 56 |
+
idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 57 |
+
shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
|
| 58 |
+
reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
|
| 59 |
+
reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
|
| 60 |
+
|
| 61 |
+
// #pragma unroll
|
| 62 |
+
shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
|
| 63 |
+
reinterpret_cast<__nv_bfloat162 *>(d_f_real_shared)[shared_offset] = d_f_real[shared_offset];
|
| 64 |
+
reinterpret_cast<__nv_bfloat162 *>(d_f_imag_shared)[shared_offset] = d_f_imag[shared_offset];
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
__syncthreads();
|
| 68 |
+
|
| 69 |
+
for (int i = 0; i < 4; i++)
|
| 70 |
+
{
|
| 71 |
+
#pragma unroll
|
| 72 |
+
for (int j = 0; j < 4; j++)
|
| 73 |
+
{
|
| 74 |
+
wmma::load_matrix_sync(a_frag_real[i][j], d_f_real_shared + j * N * 16 + i * 16, N);
|
| 75 |
+
wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag_shared + j * N * 16 + i * 16, N);
|
| 76 |
+
}
|
| 77 |
+
wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + i * N * 16 + threadIdx.y * 16, N);
|
| 78 |
+
wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + i * N * 16 + threadIdx.y * 16, N);
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
for (int t = 0; t < 16; t++)
|
| 82 |
+
{
|
| 83 |
+
|
| 84 |
+
for (int i = 0; i < n; i++)
|
| 85 |
+
{
|
| 86 |
+
idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
|
| 87 |
+
shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
|
| 88 |
+
reinterpret_cast<__nv_bfloat162 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
|
| 89 |
+
reinterpret_cast<__nv_bfloat162 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
__syncthreads();
|
| 93 |
+
|
| 94 |
+
for (int i = 0; i < 4; i++)
|
| 95 |
+
{
|
| 96 |
+
wmma::load_matrix_sync(b_frag_real[i], x_real_shared + i * N * 16 + threadIdx.y * 16, N);
|
| 97 |
+
wmma::load_matrix_sync(b_frag_imag[i], x_imag_shared + i * N * 16 + threadIdx.y * 16, N);
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
for (int j = 0; j < 4; j++)
|
| 101 |
+
{
|
| 102 |
+
for (int k = 0; k < tw_frag_real[j].num_elements; k++)
|
| 103 |
+
{
|
| 104 |
+
tmp_real = __hsub(__hmul(tw_frag_real[j].x[k], b_frag_real[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_imag[j].x[k]));
|
| 105 |
+
tmp_imag = __hadd(__hmul(tw_frag_real[j].x[k], b_frag_imag[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_real[j].x[k]));
|
| 106 |
+
b_frag_real[j].x[k] = tmp_real;
|
| 107 |
+
b_frag_imag[j].x[k] = tmp_imag;
|
| 108 |
+
}
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
for (int i = 0; i < 4; i++)
|
| 112 |
+
{
|
| 113 |
+
wmma::fill_fragment(acc_frag_real[i], 0.0f);
|
| 114 |
+
|
| 115 |
+
// bd
|
| 116 |
+
#pragma unroll
|
| 117 |
+
for (int k = 0; k < 4; k++)
|
| 118 |
+
{
|
| 119 |
+
wmma::mma_sync(acc_frag_real[i], a_frag_imag[i][k], b_frag_imag[k], acc_frag_real[i]);
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
for (int k = 0; k < acc_frag_real[i].num_elements; k++)
|
| 123 |
+
{
|
| 124 |
+
acc_frag_real[i].x[k] = - acc_frag_real[i].x[k];
|
| 125 |
+
}
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
for (int i = 0; i < 4; i++)
|
| 129 |
+
{
|
| 130 |
+
// ac - bd
|
| 131 |
+
#pragma unroll
|
| 132 |
+
for (int k = 0; k < 4; k++)
|
| 133 |
+
{
|
| 134 |
+
wmma::mma_sync(acc_frag_real[i], a_frag_real[i][k], b_frag_real[k], acc_frag_real[i]);
|
| 135 |
+
}
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
#pragma unroll
|
| 139 |
+
for (int i = 0; i < 4; i++)
|
| 140 |
+
{
|
| 141 |
+
wmma::store_matrix_sync(out_real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
__syncthreads();
|
| 145 |
+
|
| 146 |
+
#pragma unroll
|
| 147 |
+
for (int i = 0; i < n; i++)
|
| 148 |
+
{
|
| 149 |
+
idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x + t * 64 * 32 * gridDim.x;
|
| 150 |
+
if(out_gate != nullptr){
|
| 151 |
+
out_real[idx] = __hmul2(__float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]), out_gate[idx]); ;
|
| 152 |
+
}else{
|
| 153 |
+
out_real[idx] = __float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
|
| 154 |
+
}
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
__syncthreads();
|
| 158 |
+
}
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
__global__ void butterfly_ifft_bf16_cuda_kernel_32(
|
| 162 |
+
const __nv_bfloat162 *__restrict__ x_real,
|
| 163 |
+
const __nv_bfloat162 *__restrict__ x_imag,
|
| 164 |
+
const __nv_bfloat16 *__restrict__ d_f_real,
|
| 165 |
+
const __nv_bfloat16 *__restrict__ d_f_imag,
|
| 166 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_real,
|
| 167 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
|
| 168 |
+
__nv_bfloat162 *__restrict__ out_real,
|
| 169 |
+
__nv_bfloat162 *__restrict__ out_gate,
|
| 170 |
+
uint B,
|
| 171 |
+
uint H,
|
| 172 |
+
int N)
|
| 173 |
+
{
|
| 174 |
+
const int offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 175 |
+
const int tw_offset = blockIdx.x * 32 + threadIdx.x;
|
| 176 |
+
int idx;
|
| 177 |
+
int shared_offset;
|
| 178 |
+
const int B_Y = blockDim.y;
|
| 179 |
+
const int n = N / B_Y;
|
| 180 |
+
|
| 181 |
+
__shared__ __nv_bfloat16 x_real_shared[32 * 64];
|
| 182 |
+
__shared__ __nv_bfloat16 x_imag_shared[32 * 64];
|
| 183 |
+
__shared__ __nv_bfloat16 twiddles_real_shared[32 * 64];
|
| 184 |
+
__shared__ __nv_bfloat16 twiddles_imag_shared[32 * 64];
|
| 185 |
+
__shared__ float out_real_shared[32 * 64];
|
| 186 |
+
|
| 187 |
+
// #pragma unroll
|
| 188 |
+
for (int i = 0; i < n; i++)
|
| 189 |
+
{
|
| 190 |
+
idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 191 |
+
shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
|
| 192 |
+
reinterpret_cast<__nv_bfloat162 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
|
| 193 |
+
reinterpret_cast<__nv_bfloat162 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
|
| 194 |
+
reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
|
| 195 |
+
reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
__syncthreads();
|
| 199 |
+
|
| 200 |
+
if (threadIdx.y < N / 16)
|
| 201 |
+
{
|
| 202 |
+
__nv_bfloat16 tmp_real, tmp_imag;
|
| 203 |
+
|
| 204 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[2][2];
|
| 205 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[2][2];
|
| 206 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[2][2];
|
| 207 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[2][2];
|
| 208 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_real[2][2];
|
| 209 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_imag[2][2];
|
| 210 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[2][2];
|
| 211 |
+
|
| 212 |
+
int t = threadIdx.y * 32;
|
| 213 |
+
|
| 214 |
+
for (int i = 0; i < 2; i++)
|
| 215 |
+
{
|
| 216 |
+
for (int j = 0; j < 2; j++)
|
| 217 |
+
{
|
| 218 |
+
wmma::load_matrix_sync(a_frag_real[i][j], d_f_real + j * N * 16 + i * 16, N);
|
| 219 |
+
wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag + j * N * 16 + i * 16, N);
|
| 220 |
+
wmma::load_matrix_sync(b_frag_real[i][j], x_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 221 |
+
wmma::load_matrix_sync(b_frag_imag[i][j], x_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 222 |
+
wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 223 |
+
wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 224 |
+
}
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
for (int i = 0; i < 2; i++)
|
| 228 |
+
{
|
| 229 |
+
for (int j = 0; j < 2; j++)
|
| 230 |
+
{
|
| 231 |
+
for (int k = 0; k < tw_frag_real[i][j].num_elements; k++)
|
| 232 |
+
{
|
| 233 |
+
tmp_real = __hsub(__hmul(tw_frag_real[i][j].x[k], b_frag_real[i][j].x[k]), __hmul(tw_frag_imag[i][j].x[k], b_frag_imag[i][j].x[k]));
|
| 234 |
+
tmp_imag = __hadd(__hmul(tw_frag_real[i][j].x[k], b_frag_imag[i][j].x[k]), __hmul(tw_frag_imag[i][j].x[k], b_frag_real[i][j].x[k]));
|
| 235 |
+
b_frag_real[i][j].x[k] = tmp_real;
|
| 236 |
+
b_frag_imag[i][j].x[k] = tmp_imag;
|
| 237 |
+
}
|
| 238 |
+
}
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
for (int i = 0; i < 2; i++)
|
| 242 |
+
{
|
| 243 |
+
for (int j = 0; j < 2; j++)
|
| 244 |
+
{
|
| 245 |
+
wmma::fill_fragment(acc_frag_real[i][j], 0.0f);
|
| 246 |
+
|
| 247 |
+
// bd
|
| 248 |
+
for (int k = 0; k < 2; k++)
|
| 249 |
+
{
|
| 250 |
+
wmma::mma_sync(acc_frag_real[i][j], a_frag_imag[i][k], b_frag_imag[k][j], acc_frag_real[i][j]);
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
for (int k = 0; k < acc_frag_real[i][j].num_elements; k++)
|
| 254 |
+
{
|
| 255 |
+
acc_frag_real[i][j].x[k] = - acc_frag_real[i][j].x[k];
|
| 256 |
+
}
|
| 257 |
+
}
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
for (int i = 0; i < 2; i++)
|
| 261 |
+
{
|
| 262 |
+
for (int j = 0; j < 2; j++)
|
| 263 |
+
{
|
| 264 |
+
// ac - bd
|
| 265 |
+
for (int k = 0; k < 2; k++)
|
| 266 |
+
{
|
| 267 |
+
wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag_real[k][j], acc_frag_real[i][j]);
|
| 268 |
+
}
|
| 269 |
+
}
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
for (int i = 0; i < 2; i++)
|
| 273 |
+
{
|
| 274 |
+
for (int j = 0; j < 2; j++)
|
| 275 |
+
{
|
| 276 |
+
wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
|
| 277 |
+
}
|
| 278 |
+
}
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
__syncthreads();
|
| 282 |
+
|
| 283 |
+
#pragma unroll
|
| 284 |
+
for (int i = 0; i < n; i++)
|
| 285 |
+
{
|
| 286 |
+
idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 287 |
+
if(out_gate != nullptr){
|
| 288 |
+
out_real[idx] = __hmul2(__float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]), out_gate[idx]);
|
| 289 |
+
}else{
|
| 290 |
+
out_real[idx] = __float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
|
| 291 |
+
}
|
| 292 |
+
}
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
__global__ void butterfly_ifft_bf16_cuda_kernel_128(
|
| 297 |
+
const __nv_bfloat162 *__restrict__ x_real,
|
| 298 |
+
const __nv_bfloat162 *__restrict__ x_imag,
|
| 299 |
+
const __nv_bfloat162 *__restrict__ d_f_real,
|
| 300 |
+
const __nv_bfloat162 *__restrict__ d_f_imag,
|
| 301 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_real,
|
| 302 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
|
| 303 |
+
__nv_bfloat162 *__restrict__ out_real,
|
| 304 |
+
__nv_bfloat162 *__restrict__ out_gate,
|
| 305 |
+
uint B,
|
| 306 |
+
uint H,
|
| 307 |
+
int N)
|
| 308 |
+
{
|
| 309 |
+
const int offset = blockIdx.y * H * 128 * 32 * 2 * gridDim.x + blockIdx.z * 16 * 128 * 32 * 2 * gridDim.x + blockIdx.x * 64 + threadIdx.x;
|
| 310 |
+
const int tw_offset = blockIdx.x * 64 + threadIdx.x;
|
| 311 |
+
int idx;
|
| 312 |
+
int shared_offset;
|
| 313 |
+
const int B_Y = blockDim.y;
|
| 314 |
+
const int n = N / B_Y;
|
| 315 |
+
|
| 316 |
+
extern __shared__ __nv_bfloat16 real_shared[];
|
| 317 |
+
__nv_bfloat16 *imag_shared = &real_shared[128 * 128];
|
| 318 |
+
__nv_bfloat16 *real_shared_2 = &imag_shared[128 * 128];
|
| 319 |
+
__nv_bfloat16 *imag_shared_2 = &real_shared_2[128 * 128];
|
| 320 |
+
|
| 321 |
+
__nv_bfloat16 tmp_real, tmp_imag;
|
| 322 |
+
|
| 323 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag[8][8];
|
| 324 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[8];
|
| 325 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[8];
|
| 326 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_real[8];
|
| 327 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_imag[8];
|
| 328 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[8];
|
| 329 |
+
|
| 330 |
+
for (int i = 0; i < n; i++)
|
| 331 |
+
{
|
| 332 |
+
for(int j=0; j< 2; j++){
|
| 333 |
+
shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
|
| 334 |
+
reinterpret_cast<__nv_bfloat162*>(real_shared_2)[shared_offset] = d_f_real[shared_offset];
|
| 335 |
+
reinterpret_cast<__nv_bfloat162*>(imag_shared_2)[shared_offset] = d_f_imag[shared_offset];
|
| 336 |
+
}
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
for (int i = 0; i < n; i++)
|
| 340 |
+
{
|
| 341 |
+
for(int j=0; j< 2; j++){
|
| 342 |
+
idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x;
|
| 343 |
+
shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
|
| 344 |
+
reinterpret_cast<__nv_bfloat162*>(real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
|
| 345 |
+
reinterpret_cast<__nv_bfloat162*>(imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
|
| 346 |
+
}
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
__syncthreads();
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
for (int i = 0; i < 8; i++){
|
| 353 |
+
wmma::load_matrix_sync(tw_frag_real[i], real_shared + i * 128 * 16 + threadIdx.y * 16, 128);
|
| 354 |
+
wmma::load_matrix_sync(tw_frag_imag[i], imag_shared + i * 128 * 16 + threadIdx.y * 16, 128);
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
__syncthreads();
|
| 358 |
+
|
| 359 |
+
for (int t = 0; t < 16; t++)
|
| 360 |
+
{
|
| 361 |
+
for (int i = 0; i < 8; i++){
|
| 362 |
+
for (int j = 0; j < 8; j++){
|
| 363 |
+
wmma::load_matrix_sync(a_frag[i][j], imag_shared_2 + j * 128 * 16 + i * 16, 128);
|
| 364 |
+
}
|
| 365 |
+
}
|
| 366 |
+
|
| 367 |
+
for (int i = 0; i < n; i++)
|
| 368 |
+
{
|
| 369 |
+
for(int j=0; j< 2; j++){
|
| 370 |
+
idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
|
| 371 |
+
shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
|
| 372 |
+
reinterpret_cast<__nv_bfloat162*>(real_shared)[shared_offset] = x_real[offset + idx];
|
| 373 |
+
reinterpret_cast<__nv_bfloat162*>(imag_shared)[shared_offset] = x_imag[offset + idx];
|
| 374 |
+
}
|
| 375 |
+
}
|
| 376 |
+
|
| 377 |
+
__syncthreads();
|
| 378 |
+
|
| 379 |
+
for (int i = 0; i < 8; i++)
|
| 380 |
+
{
|
| 381 |
+
wmma::load_matrix_sync(b_frag_real[i], real_shared + i * N * 16 + threadIdx.y * 16, N);
|
| 382 |
+
wmma::load_matrix_sync(b_frag_imag[i], imag_shared + i * N * 16 + threadIdx.y * 16, N);
|
| 383 |
+
}
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
for (int j = 0; j < 8; j++)
|
| 387 |
+
{
|
| 388 |
+
for (int k = 0; k < tw_frag_real[j].num_elements; k++)
|
| 389 |
+
{
|
| 390 |
+
tmp_real = __hsub(__hmul(tw_frag_real[j].x[k], b_frag_real[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_imag[j].x[k]));
|
| 391 |
+
tmp_imag = __hadd(__hmul(tw_frag_real[j].x[k], b_frag_imag[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_real[j].x[k]));
|
| 392 |
+
b_frag_real[j].x[k] = tmp_real;
|
| 393 |
+
b_frag_imag[j].x[k] = tmp_imag;
|
| 394 |
+
}
|
| 395 |
+
}
|
| 396 |
+
|
| 397 |
+
for (int i = 0; i < 8; i++)
|
| 398 |
+
{
|
| 399 |
+
wmma::fill_fragment(acc_frag_real[i], 0.0f);
|
| 400 |
+
|
| 401 |
+
// bd
|
| 402 |
+
#pragma unroll
|
| 403 |
+
for (int k = 0; k < 8; k++)
|
| 404 |
+
{
|
| 405 |
+
wmma::mma_sync(acc_frag_real[i], a_frag[i][k], b_frag_imag[k], acc_frag_real[i]);
|
| 406 |
+
}
|
| 407 |
+
|
| 408 |
+
for (int k = 0; k < acc_frag_real[i].num_elements; k++)
|
| 409 |
+
{
|
| 410 |
+
acc_frag_real[i].x[k] = - acc_frag_real[i].x[k];
|
| 411 |
+
}
|
| 412 |
+
}
|
| 413 |
+
|
| 414 |
+
for (int i = 0; i < 8; i++){
|
| 415 |
+
for (int j = 0; j < 8; j++){
|
| 416 |
+
wmma::load_matrix_sync(a_frag[i][j], real_shared_2 + j * 128 * 16 + i * 16, 128);
|
| 417 |
+
}
|
| 418 |
+
}
|
| 419 |
+
|
| 420 |
+
for (int i = 0; i < 8; i++)
|
| 421 |
+
{
|
| 422 |
+
// ac - bd
|
| 423 |
+
#pragma unroll
|
| 424 |
+
for (int k = 0; k < 8; k++)
|
| 425 |
+
{
|
| 426 |
+
wmma::mma_sync(acc_frag_real[i], a_frag[i][k], b_frag_real[k], acc_frag_real[i]);
|
| 427 |
+
}
|
| 428 |
+
}
|
| 429 |
+
|
| 430 |
+
#pragma unroll
|
| 431 |
+
for (int i = 0; i < 8; i++)
|
| 432 |
+
{
|
| 433 |
+
//wmma::store_matrix_sync(real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
|
| 434 |
+
wmma::store_matrix_sync(reinterpret_cast<float*>(real_shared) + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
|
| 435 |
+
}
|
| 436 |
+
|
| 437 |
+
__syncthreads();
|
| 438 |
+
|
| 439 |
+
#pragma unroll
|
| 440 |
+
for (int i = 0; i < n; i++)
|
| 441 |
+
{
|
| 442 |
+
for(int j=0; j< 2; j++){
|
| 443 |
+
idx = (threadIdx.y + i * B_Y) * 32 * 2 * gridDim.x + j * blockDim.x + t * 128 * 32 * 2 * gridDim.x;
|
| 444 |
+
shared_offset = (threadIdx.y + i * B_Y) * 64 + threadIdx.x + j * blockDim.x;
|
| 445 |
+
if(out_gate != nullptr){
|
| 446 |
+
out_real[offset + idx] = __hmul2(__float22bfloat162_rn(reinterpret_cast<float2*>(real_shared)[shared_offset]), out_gate[offset + idx]);
|
| 447 |
+
}else{
|
| 448 |
+
out_real[offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(real_shared)[shared_offset]);
|
| 449 |
+
}
|
| 450 |
+
}
|
| 451 |
+
}
|
| 452 |
+
|
| 453 |
+
__syncthreads();
|
| 454 |
+
}
|
| 455 |
+
}
|
| 456 |
+
|
| 457 |
+
__global__ void butterfly_ifft_bf16_cuda_kernel_16(
|
| 458 |
+
const __nv_bfloat162 *__restrict__ x_real,
|
| 459 |
+
const __nv_bfloat162 *__restrict__ x_imag,
|
| 460 |
+
const __nv_bfloat16 *__restrict__ d_f_real,
|
| 461 |
+
const __nv_bfloat16 *__restrict__ d_f_imag,
|
| 462 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_real,
|
| 463 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
|
| 464 |
+
__nv_bfloat162 *__restrict__ out_real,
|
| 465 |
+
__nv_bfloat162 *__restrict__ out_gate,
|
| 466 |
+
uint B,
|
| 467 |
+
uint H,
|
| 468 |
+
int N)
|
| 469 |
+
{
|
| 470 |
+
const int offset = blockIdx.y * H * 16 * 32 * gridDim.x + blockIdx.z * 16 * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 471 |
+
const int tw_offset = blockIdx.x * 32 + threadIdx.x;
|
| 472 |
+
int idx;
|
| 473 |
+
int shared_offset;
|
| 474 |
+
const int B_Y = blockDim.y;
|
| 475 |
+
const int n = N / B_Y;
|
| 476 |
+
|
| 477 |
+
__shared__ __nv_bfloat16 x_real_shared[16 * 64];
|
| 478 |
+
__shared__ __nv_bfloat16 x_imag_shared[16 * 64];
|
| 479 |
+
__shared__ __nv_bfloat16 twiddles_real_shared[16 * 64];
|
| 480 |
+
__shared__ __nv_bfloat16 twiddles_imag_shared[16 * 64];
|
| 481 |
+
__shared__ float out_real_shared[16 * 64];
|
| 482 |
+
|
| 483 |
+
// #pragma unroll
|
| 484 |
+
for (int i = 0; i < n; i++)
|
| 485 |
+
{
|
| 486 |
+
idx = (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 487 |
+
shared_offset = (threadIdx.y + i * B_Y) * 32 + threadIdx.x;
|
| 488 |
+
reinterpret_cast<__nv_bfloat162 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
|
| 489 |
+
reinterpret_cast<__nv_bfloat162 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
|
| 490 |
+
reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[tw_offset + idx];
|
| 491 |
+
reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[tw_offset + idx];
|
| 492 |
+
}
|
| 493 |
+
|
| 494 |
+
__syncthreads();
|
| 495 |
+
|
| 496 |
+
if (threadIdx.y < 4)
|
| 497 |
+
{
|
| 498 |
+
__nv_bfloat16 tmp_real, tmp_imag;
|
| 499 |
+
|
| 500 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real;
|
| 501 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag;
|
| 502 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real;
|
| 503 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag;
|
| 504 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_real;
|
| 505 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_imag;
|
| 506 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real;
|
| 507 |
+
|
| 508 |
+
wmma::load_matrix_sync(a_frag_real, d_f_real, N);
|
| 509 |
+
wmma::load_matrix_sync(a_frag_imag, d_f_imag, N);
|
| 510 |
+
wmma::load_matrix_sync(b_frag_real, x_real_shared + threadIdx.y * 16, 64);
|
| 511 |
+
wmma::load_matrix_sync(b_frag_imag, x_imag_shared + threadIdx.y * 16, 64);
|
| 512 |
+
wmma::load_matrix_sync(tw_frag_real, twiddles_real_shared + threadIdx.y * 16, 64);
|
| 513 |
+
wmma::load_matrix_sync(tw_frag_imag, twiddles_imag_shared + threadIdx.y * 16, 64);
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
for (int k = 0; k < tw_frag_real.num_elements; k++)
|
| 517 |
+
{
|
| 518 |
+
tmp_real = __hsub(__hmul(tw_frag_real.x[k], b_frag_real.x[k]), __hmul(tw_frag_imag.x[k], b_frag_imag.x[k]));
|
| 519 |
+
tmp_imag = __hadd(__hmul(tw_frag_real.x[k], b_frag_imag.x[k]), __hmul(tw_frag_imag.x[k], b_frag_real.x[k]));
|
| 520 |
+
b_frag_real.x[k] = tmp_real;
|
| 521 |
+
b_frag_imag.x[k] = tmp_imag;
|
| 522 |
+
}
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
wmma::fill_fragment(acc_frag_real, 0.0f);
|
| 527 |
+
|
| 528 |
+
wmma::mma_sync(acc_frag_real, a_frag_imag, b_frag_imag, acc_frag_real);
|
| 529 |
+
|
| 530 |
+
for(int k=0; k< acc_frag_real.num_elements; k++){
|
| 531 |
+
acc_frag_real.x[k] = - acc_frag_real.x[k];
|
| 532 |
+
}
|
| 533 |
+
|
| 534 |
+
wmma::mma_sync(acc_frag_real, a_frag_real, b_frag_real, acc_frag_real);
|
| 535 |
+
|
| 536 |
+
wmma::store_matrix_sync(out_real_shared + threadIdx.y * 16, acc_frag_real, 64, wmma::mem_row_major);
|
| 537 |
+
|
| 538 |
+
}
|
| 539 |
+
|
| 540 |
+
__syncthreads();
|
| 541 |
+
|
| 542 |
+
#pragma unroll
|
| 543 |
+
for (int i = 0; i < n; i++)
|
| 544 |
+
{
|
| 545 |
+
idx = offset + (threadIdx.y + i * B_Y) * 32 * gridDim.x;
|
| 546 |
+
if(out_gate != nullptr){
|
| 547 |
+
out_real[idx] = __hmul2(__float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]), out_gate[idx]);
|
| 548 |
+
}else{
|
| 549 |
+
out_real[idx] = __float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[(threadIdx.y + i * B_Y) * 32 + threadIdx.x]);
|
| 550 |
+
}
|
| 551 |
+
}
|
| 552 |
+
}
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
torch::Tensor butterfly_ifft_bf16_cuda(
|
| 556 |
+
torch::Tensor x_real,
|
| 557 |
+
torch::Tensor x_imag,
|
| 558 |
+
torch::Tensor d_f_real,
|
| 559 |
+
torch::Tensor d_f_imag,
|
| 560 |
+
torch::Tensor twiddle_factors_real,
|
| 561 |
+
torch::Tensor twiddle_factors_imag,
|
| 562 |
+
std::optional<at::Tensor> out_gate = std::nullopt
|
| 563 |
+
)
|
| 564 |
+
{
|
| 565 |
+
|
| 566 |
+
uint B = x_real.size(0);
|
| 567 |
+
uint H = x_real.size(1);
|
| 568 |
+
// uint m = x.size(1);
|
| 569 |
+
|
| 570 |
+
// const int TILE_SIZE = 16;
|
| 571 |
+
|
| 572 |
+
dim3 gridDim;
|
| 573 |
+
dim3 blockDim;
|
| 574 |
+
|
| 575 |
+
uint N = x_real.size(2);
|
| 576 |
+
uint M = x_real.size(3);
|
| 577 |
+
gridDim.y = B;
|
| 578 |
+
|
| 579 |
+
blockDim.x = 32;
|
| 580 |
+
blockDim.y = 4;
|
| 581 |
+
|
| 582 |
+
torch::Tensor out = torch::empty({B, H, N, M}, x_real.options());
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
//set blockDims
|
| 586 |
+
switch(N){
|
| 587 |
+
case 128:
|
| 588 |
+
blockDim.x = 32;
|
| 589 |
+
blockDim.y = 8;
|
| 590 |
+
break;
|
| 591 |
+
default:
|
| 592 |
+
blockDim.x = 32;
|
| 593 |
+
blockDim.y = 4;
|
| 594 |
+
break;
|
| 595 |
+
}
|
| 596 |
+
|
| 597 |
+
//set gridDim.x
|
| 598 |
+
switch(N){
|
| 599 |
+
case 128:
|
| 600 |
+
switch (M){
|
| 601 |
+
case 16384:
|
| 602 |
+
gridDim.x = 128;
|
| 603 |
+
break;
|
| 604 |
+
case 8192:
|
| 605 |
+
gridDim.x = 64;
|
| 606 |
+
break;
|
| 607 |
+
case 4096:
|
| 608 |
+
gridDim.x = 32;
|
| 609 |
+
break;
|
| 610 |
+
default:
|
| 611 |
+
gridDim.x = 256;
|
| 612 |
+
break;
|
| 613 |
+
}
|
| 614 |
+
break;
|
| 615 |
+
default:
|
| 616 |
+
switch (M){
|
| 617 |
+
case 16384:
|
| 618 |
+
gridDim.x = 256;
|
| 619 |
+
break;
|
| 620 |
+
case 8192:
|
| 621 |
+
gridDim.x = 128;
|
| 622 |
+
break;
|
| 623 |
+
case 4096:
|
| 624 |
+
gridDim.x = 64;
|
| 625 |
+
break;
|
| 626 |
+
default:
|
| 627 |
+
gridDim.x = 512;
|
| 628 |
+
break;
|
| 629 |
+
}
|
| 630 |
+
break;
|
| 631 |
+
}
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
switch (N)
|
| 635 |
+
{
|
| 636 |
+
case 16:
|
| 637 |
+
gridDim.z = H;
|
| 638 |
+
butterfly_ifft_bf16_cuda_kernel_16<<<gridDim, blockDim>>>(
|
| 639 |
+
static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
|
| 640 |
+
static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
|
| 641 |
+
static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
|
| 642 |
+
static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
|
| 643 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 644 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 645 |
+
static_cast<__nv_bfloat162 *>(out.data_ptr()),
|
| 646 |
+
out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
|
| 647 |
+
B,
|
| 648 |
+
H,
|
| 649 |
+
N);
|
| 650 |
+
break;
|
| 651 |
+
|
| 652 |
+
case 32:
|
| 653 |
+
gridDim.z = H;
|
| 654 |
+
butterfly_ifft_bf16_cuda_kernel_32<<<gridDim, blockDim>>>(
|
| 655 |
+
static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
|
| 656 |
+
static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
|
| 657 |
+
static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
|
| 658 |
+
static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
|
| 659 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 660 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 661 |
+
static_cast<__nv_bfloat162 *>(out.data_ptr()),
|
| 662 |
+
out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
|
| 663 |
+
B,
|
| 664 |
+
H,
|
| 665 |
+
N);
|
| 666 |
+
break;
|
| 667 |
+
case 64:
|
| 668 |
+
gridDim.z = H / 16;
|
| 669 |
+
cudaFuncSetAttribute(&butterfly_ifft_bf16_cuda_kernel_64, cudaFuncAttributeMaxDynamicSharedMemorySize, 78000);
|
| 670 |
+
butterfly_ifft_bf16_cuda_kernel_64<<<gridDim, blockDim, 78000>>>(
|
| 671 |
+
static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
|
| 672 |
+
static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
|
| 673 |
+
static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
|
| 674 |
+
static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
|
| 675 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 676 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 677 |
+
static_cast<__nv_bfloat162 *>(out.data_ptr()),
|
| 678 |
+
out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
|
| 679 |
+
B,
|
| 680 |
+
H,
|
| 681 |
+
N);
|
| 682 |
+
break;
|
| 683 |
+
|
| 684 |
+
case 128:
|
| 685 |
+
gridDim.z = H / 16;
|
| 686 |
+
cudaFuncSetAttribute(&butterfly_ifft_bf16_cuda_kernel_128, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
|
| 687 |
+
butterfly_ifft_bf16_cuda_kernel_128<<<gridDim, blockDim, 65536 * 2>>>(
|
| 688 |
+
static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
|
| 689 |
+
static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
|
| 690 |
+
static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
|
| 691 |
+
static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
|
| 692 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 693 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 694 |
+
static_cast<__nv_bfloat162 *>(out.data_ptr()),
|
| 695 |
+
out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
|
| 696 |
+
B,
|
| 697 |
+
H,
|
| 698 |
+
N);
|
| 699 |
+
break;
|
| 700 |
+
default:
|
| 701 |
+
printf("Not implemented\n");
|
| 702 |
+
}
|
| 703 |
+
|
| 704 |
+
return out;
|
| 705 |
+
}
|
overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_padded_cuda.cu
ADDED
|
@@ -0,0 +1,871 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) 2023 Dan Fu, Hermann Kumbong
|
| 2 |
+
|
| 3 |
+
#include <torch/extension.h>
|
| 4 |
+
|
| 5 |
+
#include <vector>
|
| 6 |
+
#include <stdio.h>
|
| 7 |
+
#include <mma.h>
|
| 8 |
+
#include <cmath>
|
| 9 |
+
#include <cuda_fp16.h>
|
| 10 |
+
#include <cuda_bf16.h>
|
| 11 |
+
#include "shared.h"
|
| 12 |
+
|
| 13 |
+
using namespace nvcuda;
|
| 14 |
+
|
| 15 |
+
template <int K>
|
| 16 |
+
__global__ void butterfly_padded_cuda_kernel_64(
|
| 17 |
+
const __half2 *__restrict__ x,
|
| 18 |
+
const __half2 *__restrict__ x_gate,
|
| 19 |
+
const complex_half_t *__restrict__ d_f,
|
| 20 |
+
const __half2 *__restrict__ twiddle_factors_real,
|
| 21 |
+
const __half2 *__restrict__ twiddle_factors_imag,
|
| 22 |
+
__half2 *__restrict__ out_real,
|
| 23 |
+
__half2 *__restrict__ out_imag,
|
| 24 |
+
uint B,
|
| 25 |
+
uint H,
|
| 26 |
+
int M)
|
| 27 |
+
{
|
| 28 |
+
const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
|
| 29 |
+
const int offset = blockIdx.y * H * M/2 + blockIdx.z * 16 * M/2;
|
| 30 |
+
const int out_offset = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * 16 * 64 * 32 * gridDim.x;
|
| 31 |
+
int idx;
|
| 32 |
+
int t_offset;
|
| 33 |
+
int out_t_offset;
|
| 34 |
+
int shared_offset;
|
| 35 |
+
const int N = 64;
|
| 36 |
+
|
| 37 |
+
extern __shared__ half x_shared[];
|
| 38 |
+
half *d_f_real = &x_shared[K * 16 * N];
|
| 39 |
+
half *d_f_imag = &d_f_real[N * N];
|
| 40 |
+
half *twiddles_real_shared = &d_f_imag[N * N];
|
| 41 |
+
half *twiddles_imag_shared = &twiddles_real_shared[N * N];
|
| 42 |
+
half *out_real_shared = &twiddles_imag_shared[N * N];
|
| 43 |
+
half *out_imag_shared = &out_real_shared[N * N];
|
| 44 |
+
|
| 45 |
+
// #pragma unroll
|
| 46 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 47 |
+
{
|
| 48 |
+
idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 49 |
+
shared_offset = i * 32 + threadIdx.x;
|
| 50 |
+
reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[idx];
|
| 51 |
+
reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[idx];
|
| 52 |
+
|
| 53 |
+
// #pragma unroll
|
| 54 |
+
shared_offset = i * 64 + threadIdx.x;
|
| 55 |
+
d_f_real[shared_offset] = d_f[shared_offset].real();
|
| 56 |
+
d_f_imag[shared_offset] = d_f[shared_offset].imag();
|
| 57 |
+
|
| 58 |
+
d_f_real[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].real();
|
| 59 |
+
d_f_imag[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].imag();
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
__half2 tmp_real, tmp_imag;
|
| 63 |
+
|
| 64 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[4];
|
| 65 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real[4];
|
| 66 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag[4];
|
| 67 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[4];
|
| 68 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag[K][4];
|
| 69 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[4];
|
| 70 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag[4];
|
| 71 |
+
|
| 72 |
+
__syncthreads();
|
| 73 |
+
|
| 74 |
+
for (int i = 0; i < 4; i++)
|
| 75 |
+
{
|
| 76 |
+
wmma::load_matrix_sync(a_frag_real[i], d_f_real + i * N * 16 + threadIdx.y * 16, N);
|
| 77 |
+
wmma::load_matrix_sync(a_frag_imag[i], d_f_imag + i * N * 16 + threadIdx.y * 16, N);
|
| 78 |
+
wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + threadIdx.y * N * 16 + i * 16, N);
|
| 79 |
+
wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + threadIdx.y * N * 16 + i * 16, N);
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
for (int t = 0; t < 16; t++)
|
| 83 |
+
{
|
| 84 |
+
t_offset = t * M/2;
|
| 85 |
+
out_t_offset = t * 64 * 32 * gridDim.x;
|
| 86 |
+
|
| 87 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 88 |
+
{
|
| 89 |
+
if(i < K * 16){
|
| 90 |
+
idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 91 |
+
shared_offset = i * 32 + threadIdx.x;
|
| 92 |
+
if(x_gate != nullptr){
|
| 93 |
+
reinterpret_cast<__half2 *>(x_shared)[shared_offset] = idx < max_idx ? __hmul2(x[idx + offset + t_offset], x_gate[idx + offset + t_offset]) : __floats2half2_rn(0.0f, 0.0f);
|
| 94 |
+
}
|
| 95 |
+
else{
|
| 96 |
+
reinterpret_cast<__half2 *>(x_shared)[shared_offset] = idx < max_idx ? x[idx + offset + t_offset] : __floats2half2_rn(0.0f, 0.0f);
|
| 97 |
+
}
|
| 98 |
+
}
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
__syncthreads();
|
| 102 |
+
|
| 103 |
+
for (int i = 0; i < K; i++)
|
| 104 |
+
{
|
| 105 |
+
for (int j = 0; j < 4; j++)
|
| 106 |
+
{
|
| 107 |
+
wmma::load_matrix_sync(b_frag[i][j], x_shared + i * N * 16 + j * 16, N);
|
| 108 |
+
}
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
#pragma unroll
|
| 112 |
+
for (int j = 0; j < 4; j++)
|
| 113 |
+
{
|
| 114 |
+
wmma::fill_fragment(acc_frag_real[j], __float2half(0.0f));
|
| 115 |
+
|
| 116 |
+
for (int k = 0; k < K; k++)
|
| 117 |
+
{
|
| 118 |
+
wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
|
| 119 |
+
}
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
#pragma unroll
|
| 123 |
+
|
| 124 |
+
for (int j = 0; j < 4; j++)
|
| 125 |
+
{
|
| 126 |
+
wmma::fill_fragment(acc_frag_imag[j], __float2half(0.0f));
|
| 127 |
+
|
| 128 |
+
for (int k = 0; k < K; k++)
|
| 129 |
+
{
|
| 130 |
+
wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
|
| 131 |
+
}
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
#pragma unroll
|
| 135 |
+
for (int j = 0; j < 4; j++)
|
| 136 |
+
{
|
| 137 |
+
for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
|
| 138 |
+
{
|
| 139 |
+
tmp_real = reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k];
|
| 140 |
+
tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k];
|
| 141 |
+
reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]));
|
| 142 |
+
reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]));
|
| 143 |
+
}
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
for (int j = 0; j < 4; j++)
|
| 147 |
+
{
|
| 148 |
+
wmma::store_matrix_sync(out_real_shared + threadIdx.y * N * 16 + j * 16, acc_frag_real[j], N, wmma::mem_row_major);
|
| 149 |
+
wmma::store_matrix_sync(out_imag_shared + threadIdx.y * N * 16 + j * 16, acc_frag_imag[j], N, wmma::mem_row_major);
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
__syncthreads();
|
| 153 |
+
|
| 154 |
+
#pragma unroll
|
| 155 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 156 |
+
{
|
| 157 |
+
idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 158 |
+
shared_offset = i * 32 + threadIdx.x;
|
| 159 |
+
|
| 160 |
+
out_real[out_offset + out_t_offset + idx] = reinterpret_cast<__half2 *>(out_real_shared)[shared_offset];
|
| 161 |
+
out_imag[out_offset + out_t_offset + idx] = reinterpret_cast<__half2 *>(out_imag_shared)[shared_offset];
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
__syncthreads();
|
| 165 |
+
|
| 166 |
+
}
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
template <int K>
|
| 171 |
+
__global__ void butterfly_padded_cuda_kernel_128(
|
| 172 |
+
const __half2 *__restrict__ x,
|
| 173 |
+
const __half2 *__restrict__ x_gate,
|
| 174 |
+
const complex_half_t *__restrict__ d_f,
|
| 175 |
+
const __half2 *__restrict__ twiddle_factors_real,
|
| 176 |
+
const __half2 *__restrict__ twiddle_factors_imag,
|
| 177 |
+
__half2 *__restrict__ out_real,
|
| 178 |
+
__half2 *__restrict__ out_imag,
|
| 179 |
+
uint B,
|
| 180 |
+
uint H,
|
| 181 |
+
int M)
|
| 182 |
+
{
|
| 183 |
+
const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
|
| 184 |
+
const int offset = blockIdx.y * H * M/2 + blockIdx.z * 16 * M/2;
|
| 185 |
+
const int out_offset = blockIdx.y * H * 128 * 32 * 2 * gridDim.x + blockIdx.z * 16 * 128 * 32 * 2 * gridDim.x;
|
| 186 |
+
const int N = 128;
|
| 187 |
+
int idx;
|
| 188 |
+
int t_offset;
|
| 189 |
+
int out_t_offset;
|
| 190 |
+
int shared_offset;
|
| 191 |
+
|
| 192 |
+
extern __shared__ half shared_real[];
|
| 193 |
+
half *shared_imag = &shared_real[128 * 128];
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[8];
|
| 197 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real[8];
|
| 198 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag[8];
|
| 199 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[8];
|
| 200 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag[K][8];
|
| 201 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[8];
|
| 202 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag[8];
|
| 203 |
+
|
| 204 |
+
for (int i = threadIdx.y ; i < N; i+=blockDim.y)
|
| 205 |
+
{
|
| 206 |
+
for(int j=0; j< 4; j++){
|
| 207 |
+
shared_offset = i * 128 + threadIdx.x + j * blockDim.x;
|
| 208 |
+
shared_real[shared_offset] = d_f[shared_offset].real();
|
| 209 |
+
shared_imag[shared_offset] = d_f[shared_offset].imag();
|
| 210 |
+
}
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
__syncthreads();
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
for (int i = 0; i < 8; i++){
|
| 217 |
+
wmma::load_matrix_sync(a_frag_real[i], shared_real + i * 128 * 16 + threadIdx.y * 16, 128);
|
| 218 |
+
wmma::load_matrix_sync(a_frag_imag[i], shared_imag + i * 128 * 16 + threadIdx.y * 16, 128);
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
__syncthreads();
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 227 |
+
{
|
| 228 |
+
for(int j=0; j< 2; j++){
|
| 229 |
+
idx = i * 32 * 2 * gridDim.x + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
|
| 230 |
+
shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
|
| 231 |
+
reinterpret_cast<__half2*>(shared_real)[shared_offset] = twiddle_factors_real[idx];
|
| 232 |
+
reinterpret_cast<__half2*>(shared_imag)[shared_offset] = twiddle_factors_imag[idx];
|
| 233 |
+
}
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
__syncthreads();
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
for (int i = 0; i < 8; i++){
|
| 240 |
+
wmma::load_matrix_sync(tw_frag_real[i], shared_real + threadIdx.y * 128 * 16 + i * 16, 128);
|
| 241 |
+
wmma::load_matrix_sync(tw_frag_imag[i], shared_imag + threadIdx.y * 128 * 16 + i * 16, 128);
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
__syncthreads();
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
for(int t=0; t< 16; t++){
|
| 248 |
+
t_offset = t * M/2;
|
| 249 |
+
out_t_offset = t * 128 * 32 * 2 * gridDim.x;
|
| 250 |
+
|
| 251 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 252 |
+
{
|
| 253 |
+
if(i < K * 16){
|
| 254 |
+
for(int j=0; j< 2; j++){
|
| 255 |
+
idx = i * 32 * 2 * gridDim.x + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
|
| 256 |
+
shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
|
| 257 |
+
if(x_gate != nullptr){
|
| 258 |
+
reinterpret_cast<__half2*>(shared_real)[shared_offset] = idx < max_idx ? __hmul2(x[idx + offset + t_offset], x_gate[idx + offset + t_offset]) : __floats2half2_rn(0.0f, 0.0f);
|
| 259 |
+
}
|
| 260 |
+
else{
|
| 261 |
+
reinterpret_cast<__half2*>(shared_real)[shared_offset] = idx < max_idx ? x[idx + offset + t_offset] : __floats2half2_rn(0.0f, 0.0f);
|
| 262 |
+
}
|
| 263 |
+
}
|
| 264 |
+
}
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
__syncthreads();
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
for (int i = 0; i < K; i++)
|
| 272 |
+
{
|
| 273 |
+
for (int j = 0; j < 8; j++)
|
| 274 |
+
{
|
| 275 |
+
wmma::load_matrix_sync(b_frag[i][j], shared_real + i * 128 * 16 + j * 16, 128);
|
| 276 |
+
}
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
__syncthreads();
|
| 280 |
+
|
| 281 |
+
#pragma unroll
|
| 282 |
+
for (int j = 0; j < 8; j++)
|
| 283 |
+
{
|
| 284 |
+
wmma::fill_fragment(acc_frag_real[j], __float2half(0.0f));
|
| 285 |
+
|
| 286 |
+
for (int k = 0; k < K; k++)
|
| 287 |
+
{
|
| 288 |
+
wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
|
| 289 |
+
}
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
#pragma unroll
|
| 293 |
+
|
| 294 |
+
for (int j = 0; j < 8; j++)
|
| 295 |
+
{
|
| 296 |
+
wmma::fill_fragment(acc_frag_imag[j], __float2half(0.0f));
|
| 297 |
+
|
| 298 |
+
for (int k = 0; k < K; k++)
|
| 299 |
+
{
|
| 300 |
+
wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
|
| 301 |
+
}
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
__half2 tmp_real, tmp_imag;
|
| 305 |
+
#pragma unroll
|
| 306 |
+
for (int j = 0; j < 8; j++)
|
| 307 |
+
{
|
| 308 |
+
for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
|
| 309 |
+
{
|
| 310 |
+
tmp_real = reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k];
|
| 311 |
+
tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k];
|
| 312 |
+
reinterpret_cast<__half2 *>(acc_frag_real[j].x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]));
|
| 313 |
+
reinterpret_cast<__half2 *>(acc_frag_imag[j].x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag[j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real[j].x)[k]));
|
| 314 |
+
}
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
for (int j = 0; j < 8; j++)
|
| 318 |
+
{
|
| 319 |
+
wmma::store_matrix_sync(shared_real + threadIdx.y * 128 * 16 + j * 16, acc_frag_real[j], 128, wmma::mem_row_major);
|
| 320 |
+
wmma::store_matrix_sync(shared_imag + threadIdx.y * 128 * 16 + j * 16, acc_frag_imag[j], 128, wmma::mem_row_major);
|
| 321 |
+
}
|
| 322 |
+
|
| 323 |
+
__syncthreads();
|
| 324 |
+
|
| 325 |
+
#pragma unroll
|
| 326 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 327 |
+
{
|
| 328 |
+
for(int j=0; j< 2; j++){
|
| 329 |
+
idx = i * 32 * 2 * gridDim.x + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
|
| 330 |
+
shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
|
| 331 |
+
|
| 332 |
+
out_real[idx + out_offset + out_t_offset] = reinterpret_cast<__half2*>(shared_real)[shared_offset];
|
| 333 |
+
out_imag[idx + out_offset + out_t_offset] = reinterpret_cast<__half2*>(shared_imag)[shared_offset];
|
| 334 |
+
|
| 335 |
+
}
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
__syncthreads();
|
| 339 |
+
}
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
template <int K>
|
| 343 |
+
__global__ void butterfly_padded_cuda_kernel_32(
|
| 344 |
+
const __half2 *__restrict__ x,
|
| 345 |
+
const __half2 *__restrict__ x_gate,
|
| 346 |
+
const complex_half_t *__restrict__ d_f,
|
| 347 |
+
const __half2 *__restrict__ twiddle_factors_real,
|
| 348 |
+
const __half2 *__restrict__ twiddle_factors_imag,
|
| 349 |
+
__half2 *__restrict__ out_real,
|
| 350 |
+
__half2 *__restrict__ out_imag,
|
| 351 |
+
uint B,
|
| 352 |
+
uint H,
|
| 353 |
+
int M)
|
| 354 |
+
{
|
| 355 |
+
const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
|
| 356 |
+
const int N = 32;
|
| 357 |
+
__shared__ half x_shared[K * 16 * 64];
|
| 358 |
+
__shared__ half d_f_real[32 * 32];
|
| 359 |
+
__shared__ half d_f_imag[32 * 32];
|
| 360 |
+
__shared__ half twiddles_real_shared[32 * 64];
|
| 361 |
+
__shared__ half twiddles_imag_shared[32 * 64];
|
| 362 |
+
__shared__ half out_real_shared[32 * 64];
|
| 363 |
+
__shared__ half out_imag_shared[32 * 64];
|
| 364 |
+
|
| 365 |
+
const int offset = blockIdx.y * H * M / 2 + blockIdx.z * M / 2;
|
| 366 |
+
const int out_offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x;
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
for(int i = threadIdx.y; i<32; i+=blockDim.y){
|
| 370 |
+
int idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 371 |
+
int shared_offset = i * 32 + threadIdx.x;
|
| 372 |
+
|
| 373 |
+
if(i < K * 16){
|
| 374 |
+
if(x_gate != nullptr){
|
| 375 |
+
reinterpret_cast<__half2*>(x_shared)[shared_offset] = idx < max_idx ? __hmul2(x[offset + idx], x_gate[offset + idx]) : __floats2half2_rn(0.0f, 0.0f);
|
| 376 |
+
}
|
| 377 |
+
else{
|
| 378 |
+
reinterpret_cast<__half2*>(x_shared)[shared_offset] = idx < max_idx ? x[offset + idx] : __floats2half2_rn(0.0f, 0.0f);
|
| 379 |
+
}
|
| 380 |
+
}
|
| 381 |
+
reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[idx];
|
| 382 |
+
reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[idx];
|
| 383 |
+
|
| 384 |
+
// #pragma unroll
|
| 385 |
+
d_f_real[shared_offset] = d_f[shared_offset].real();
|
| 386 |
+
d_f_imag[shared_offset] = d_f[shared_offset].imag();
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
__syncthreads();
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
if (threadIdx.y < N / 16)
|
| 394 |
+
{
|
| 395 |
+
__half2 tmp_real, tmp_imag;
|
| 396 |
+
|
| 397 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[2][2];
|
| 398 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real[2][2];
|
| 399 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag[2][2];
|
| 400 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[2][2];
|
| 401 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag[K][2];
|
| 402 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[2][2];
|
| 403 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag[2][2];
|
| 404 |
+
|
| 405 |
+
int t = threadIdx.y * 32;
|
| 406 |
+
|
| 407 |
+
for (int i = 0; i < 2; i++)
|
| 408 |
+
{
|
| 409 |
+
for (int j = 0; j < 2; j++)
|
| 410 |
+
{
|
| 411 |
+
wmma::load_matrix_sync(a_frag_real[i][j], d_f_real + j * N * 16 + i * 16, N);
|
| 412 |
+
wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag + j * N * 16 + i * 16, N);
|
| 413 |
+
if(i<K){
|
| 414 |
+
wmma::load_matrix_sync(b_frag[i][j], x_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 415 |
+
}
|
| 416 |
+
wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 417 |
+
wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 418 |
+
}
|
| 419 |
+
}
|
| 420 |
+
|
| 421 |
+
#pragma unroll
|
| 422 |
+
for (int i = 0; i < 2; i++)
|
| 423 |
+
{
|
| 424 |
+
for (int j = 0; j < 2; j++)
|
| 425 |
+
{
|
| 426 |
+
wmma::fill_fragment(acc_frag_real[i][j], __float2half(0.0f));
|
| 427 |
+
|
| 428 |
+
for (int k = 0; k < K; k++)
|
| 429 |
+
{
|
| 430 |
+
wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag[k][j], acc_frag_real[i][j]);
|
| 431 |
+
}
|
| 432 |
+
}
|
| 433 |
+
}
|
| 434 |
+
|
| 435 |
+
#pragma unroll
|
| 436 |
+
for (int i = 0; i < 2; i++)
|
| 437 |
+
{
|
| 438 |
+
for (int j = 0; j < 2; j++)
|
| 439 |
+
{
|
| 440 |
+
wmma::fill_fragment(acc_frag_imag[i][j], __float2half(0.0f));
|
| 441 |
+
|
| 442 |
+
for (int k = 0; k < K; k++)
|
| 443 |
+
{
|
| 444 |
+
wmma::mma_sync(acc_frag_imag[i][j], a_frag_imag[i][k], b_frag[k][j], acc_frag_imag[i][j]);
|
| 445 |
+
}
|
| 446 |
+
}
|
| 447 |
+
}
|
| 448 |
+
|
| 449 |
+
#pragma unroll
|
| 450 |
+
for (int i = 0; i < 2; i++)
|
| 451 |
+
{
|
| 452 |
+
for (int j = 0; j < 2; j++)
|
| 453 |
+
{
|
| 454 |
+
for (int k = 0; k < acc_frag_real[i][j].num_elements / 2; k++)
|
| 455 |
+
{
|
| 456 |
+
tmp_real = reinterpret_cast<__half2 *>(acc_frag_real[i][j].x)[k];
|
| 457 |
+
tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag[i][j].x)[k];
|
| 458 |
+
reinterpret_cast<__half2 *>(acc_frag_real[i][j].x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real[i][j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag[i][j].x)[k]));
|
| 459 |
+
reinterpret_cast<__half2 *>(acc_frag_imag[i][j].x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag[i][j].x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real[i][j].x)[k]));
|
| 460 |
+
}
|
| 461 |
+
}
|
| 462 |
+
}
|
| 463 |
+
|
| 464 |
+
for (int i = 0; i < 2; i++)
|
| 465 |
+
{
|
| 466 |
+
for (int j = 0; j < 2; j++)
|
| 467 |
+
{
|
| 468 |
+
wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
|
| 469 |
+
wmma::store_matrix_sync(out_imag_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_imag[i][j], 2 * N, wmma::mem_row_major);
|
| 470 |
+
}
|
| 471 |
+
}
|
| 472 |
+
}
|
| 473 |
+
|
| 474 |
+
__syncthreads();
|
| 475 |
+
|
| 476 |
+
// int idx = offset + threadIdx.y * 32 + blockIdx.x * 32 + threadIdx.x;
|
| 477 |
+
for(int i = threadIdx.y; i<32; i+=blockDim.y){
|
| 478 |
+
int idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 479 |
+
out_real[out_offset + idx] = reinterpret_cast<__half2*>(out_real_shared)[i * 32 + threadIdx.x];
|
| 480 |
+
out_imag[out_offset + idx] = reinterpret_cast<__half2*>(out_imag_shared)[i * 32 + threadIdx.x];
|
| 481 |
+
}
|
| 482 |
+
}
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
__global__ void butterfly_padded_cuda_kernel_16(
|
| 486 |
+
const __half2 *__restrict__ x,
|
| 487 |
+
const __half2 *__restrict__ x_gate,
|
| 488 |
+
const complex_half_t *__restrict__ d_f,
|
| 489 |
+
const __half2 *__restrict__ twiddle_factors_real,
|
| 490 |
+
const __half2 *__restrict__ twiddle_factors_imag,
|
| 491 |
+
__half2 *__restrict__ out_real,
|
| 492 |
+
__half2 *__restrict__ out_imag,
|
| 493 |
+
uint B,
|
| 494 |
+
uint H,
|
| 495 |
+
int M)
|
| 496 |
+
{
|
| 497 |
+
const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
|
| 498 |
+
const int N = 16;
|
| 499 |
+
const int offset = blockIdx.y * H * M / 2 + blockIdx.z * M / 2;
|
| 500 |
+
const int out_offset = blockIdx.y * H * N * blockDim.x * gridDim.x + blockIdx.z * N * blockDim.x * gridDim.x;
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
__shared__ half x_shared[N * 64];
|
| 505 |
+
__shared__ half d_f_real[N * N];
|
| 506 |
+
__shared__ half d_f_imag[N * N];
|
| 507 |
+
__shared__ half twiddles_real_shared[N * 64];
|
| 508 |
+
__shared__ half twiddles_imag_shared[N * 64];
|
| 509 |
+
__shared__ half out_real_shared[N * 64];
|
| 510 |
+
__shared__ half out_imag_shared[N * 64];
|
| 511 |
+
|
| 512 |
+
// #pragma unroll
|
| 513 |
+
for(int i = threadIdx.y; i<N; i+=blockDim.y){
|
| 514 |
+
int idx = i * blockDim.x * gridDim.x + blockIdx.x * blockDim.x + threadIdx.x;
|
| 515 |
+
int shared_offset = i * blockDim.x + threadIdx.x;
|
| 516 |
+
|
| 517 |
+
if(x_gate != NULL){
|
| 518 |
+
reinterpret_cast<__half2 *>(x_shared)[shared_offset] = idx < max_idx ? __hmul2(x[idx + offset], x_gate[idx + offset]) : __floats2half2_rn(0.0f, 0.0f);
|
| 519 |
+
}
|
| 520 |
+
else{
|
| 521 |
+
reinterpret_cast<__half2 *>(x_shared)[shared_offset] = idx < max_idx ? x[idx + offset] : __floats2half2_rn(0.0f, 0.0f);
|
| 522 |
+
}
|
| 523 |
+
reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[idx];
|
| 524 |
+
reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[idx];
|
| 525 |
+
|
| 526 |
+
// #pragma unroll
|
| 527 |
+
|
| 528 |
+
if(threadIdx.x < 16 ){
|
| 529 |
+
shared_offset = i * 16 + threadIdx.x;
|
| 530 |
+
d_f_real[shared_offset] = d_f[shared_offset].real();
|
| 531 |
+
d_f_imag[shared_offset] = d_f[shared_offset].imag();
|
| 532 |
+
}
|
| 533 |
+
}
|
| 534 |
+
|
| 535 |
+
__syncthreads();
|
| 536 |
+
|
| 537 |
+
if (threadIdx.y < 4)
|
| 538 |
+
{
|
| 539 |
+
__half2 tmp_real, tmp_imag;
|
| 540 |
+
|
| 541 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real;
|
| 542 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_real;
|
| 543 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> tw_frag_imag;
|
| 544 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag;
|
| 545 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
|
| 546 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real;
|
| 547 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag;
|
| 548 |
+
|
| 549 |
+
wmma::load_matrix_sync(a_frag_real, d_f_real, N);
|
| 550 |
+
wmma::load_matrix_sync(a_frag_imag, d_f_imag, N);
|
| 551 |
+
wmma::load_matrix_sync(b_frag, x_shared + threadIdx.y * 16, 64);
|
| 552 |
+
wmma::load_matrix_sync(tw_frag_real, twiddles_real_shared + threadIdx.y * 16, 64);
|
| 553 |
+
wmma::load_matrix_sync(tw_frag_imag, twiddles_imag_shared + threadIdx.y * 16, 64);
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
wmma::fill_fragment(acc_frag_real, __float2half(0.0f));
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
wmma::mma_sync(acc_frag_real, a_frag_real, b_frag, acc_frag_real);
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
wmma::fill_fragment(acc_frag_imag, __float2half(0.0f));
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
wmma::mma_sync(acc_frag_imag, a_frag_imag, b_frag, acc_frag_imag);
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
for (int k = 0; k < acc_frag_real.num_elements / 2; k++)
|
| 570 |
+
{
|
| 571 |
+
tmp_real = reinterpret_cast<__half2 *>(acc_frag_real.x)[k];
|
| 572 |
+
tmp_imag = reinterpret_cast<__half2 *>(acc_frag_imag.x)[k];
|
| 573 |
+
reinterpret_cast<__half2 *>(acc_frag_real.x)[k] = __hsub2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_real.x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_imag.x)[k]));
|
| 574 |
+
reinterpret_cast<__half2 *>(acc_frag_imag.x)[k] = __hadd2(__hmul2(tmp_real, reinterpret_cast<__half2 *>(tw_frag_imag.x)[k]), __hmul2(tmp_imag, reinterpret_cast<__half2 *>(tw_frag_real.x)[k]));
|
| 575 |
+
}
|
| 576 |
+
|
| 577 |
+
wmma::store_matrix_sync(out_real_shared + threadIdx.y * 16, acc_frag_real, 64, wmma::mem_row_major);
|
| 578 |
+
wmma::store_matrix_sync(out_imag_shared + threadIdx.y * 16, acc_frag_imag, 64, wmma::mem_row_major);
|
| 579 |
+
}
|
| 580 |
+
|
| 581 |
+
__syncthreads();
|
| 582 |
+
|
| 583 |
+
#pragma unroll
|
| 584 |
+
for (int i = threadIdx.y; i<N; i+=blockDim.y)
|
| 585 |
+
{
|
| 586 |
+
int idx = i * blockDim.x * gridDim.x + blockIdx.x * blockDim.x + threadIdx.x;
|
| 587 |
+
out_real[out_offset + idx] = reinterpret_cast<__half2 *>(out_real_shared)[i * 32 + threadIdx.x];
|
| 588 |
+
out_imag[out_offset + idx] = reinterpret_cast<__half2 *>(out_imag_shared)[i * 32 + threadIdx.x];
|
| 589 |
+
}
|
| 590 |
+
}
|
| 591 |
+
|
| 592 |
+
std::vector<torch::Tensor> butterfly_padded_cuda(
|
| 593 |
+
torch::Tensor x,
|
| 594 |
+
torch::Tensor d_f,
|
| 595 |
+
torch::Tensor twiddle_factors_real,
|
| 596 |
+
torch::Tensor twiddle_factors_imag,
|
| 597 |
+
int M,
|
| 598 |
+
std::optional<at::Tensor> x_gate = std::nullopt
|
| 599 |
+
)
|
| 600 |
+
{
|
| 601 |
+
|
| 602 |
+
uint B = x.size(0);
|
| 603 |
+
uint H = x.size(1);
|
| 604 |
+
uint N = x.size(2);
|
| 605 |
+
|
| 606 |
+
uint d_f_size = d_f.size(1);
|
| 607 |
+
|
| 608 |
+
//need to make sure that N is less that the M to which we are padding
|
| 609 |
+
assert(N <= d_f_size * M);
|
| 610 |
+
// printf("B: %d, H: %d, N: %d\n", B, H, N);
|
| 611 |
+
dim3 gridDim;
|
| 612 |
+
dim3 blockDim;
|
| 613 |
+
|
| 614 |
+
gridDim.y = B;
|
| 615 |
+
gridDim.z = H;
|
| 616 |
+
|
| 617 |
+
blockDim.x = 32;
|
| 618 |
+
blockDim.y = 4;
|
| 619 |
+
|
| 620 |
+
torch::Tensor out_real = torch::empty({B, H, d_f_size * M}, x.options());
|
| 621 |
+
torch::Tensor out_imag = torch::empty({B, H, d_f_size * M}, x.options());
|
| 622 |
+
|
| 623 |
+
gridDim.x = 512 / (32 * 1024/ M);
|
| 624 |
+
|
| 625 |
+
const int K = ceil(N / (1.0 * 16 * M));
|
| 626 |
+
|
| 627 |
+
|
| 628 |
+
switch(d_f_size){
|
| 629 |
+
case 16:
|
| 630 |
+
butterfly_padded_cuda_kernel_16<<<gridDim, blockDim>>>(
|
| 631 |
+
static_cast<__half2 *>(x.data_ptr()),
|
| 632 |
+
x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
|
| 633 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 634 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 635 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 636 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 637 |
+
static_cast<__half2 *>(out_imag.data_ptr()),
|
| 638 |
+
B,
|
| 639 |
+
H,
|
| 640 |
+
N);
|
| 641 |
+
break;
|
| 642 |
+
case 32:
|
| 643 |
+
switch (K)
|
| 644 |
+
{
|
| 645 |
+
case 1:
|
| 646 |
+
butterfly_padded_cuda_kernel_32<1><<<gridDim, blockDim>>>(
|
| 647 |
+
static_cast<__half2 *>(x.data_ptr()),
|
| 648 |
+
x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
|
| 649 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 650 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 651 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 652 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 653 |
+
static_cast<__half2 *>(out_imag.data_ptr()),
|
| 654 |
+
B,
|
| 655 |
+
H,
|
| 656 |
+
N);
|
| 657 |
+
break;
|
| 658 |
+
case 2:
|
| 659 |
+
butterfly_padded_cuda_kernel_32<2><<<gridDim, blockDim>>>(
|
| 660 |
+
static_cast<__half2 *>(x.data_ptr()),
|
| 661 |
+
x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
|
| 662 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 663 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 664 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 665 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 666 |
+
static_cast<__half2 *>(out_imag.data_ptr()),
|
| 667 |
+
B,
|
| 668 |
+
H,
|
| 669 |
+
N);
|
| 670 |
+
break;
|
| 671 |
+
default:
|
| 672 |
+
printf("Invalid K, df size 32: %d\n", K);
|
| 673 |
+
}
|
| 674 |
+
break;
|
| 675 |
+
case 64:
|
| 676 |
+
gridDim.z = H / 16;
|
| 677 |
+
|
| 678 |
+
switch (K)
|
| 679 |
+
{
|
| 680 |
+
case 1:
|
| 681 |
+
cudaFuncSetAttribute(&butterfly_padded_cuda_kernel_64<1>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 682 |
+
butterfly_padded_cuda_kernel_64<1><<<gridDim, blockDim, 65536>>>(
|
| 683 |
+
static_cast<__half2 *>(x.data_ptr()),
|
| 684 |
+
x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
|
| 685 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 686 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 687 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 688 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 689 |
+
static_cast<__half2 *>(out_imag.data_ptr()),
|
| 690 |
+
B,
|
| 691 |
+
H,
|
| 692 |
+
N);
|
| 693 |
+
break;
|
| 694 |
+
|
| 695 |
+
case 2:
|
| 696 |
+
cudaFuncSetAttribute(&butterfly_padded_cuda_kernel_64<2>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 697 |
+
butterfly_padded_cuda_kernel_64<2><<<gridDim, blockDim, 65536>>>(
|
| 698 |
+
static_cast<__half2 *>(x.data_ptr()),
|
| 699 |
+
x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
|
| 700 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 701 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 702 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 703 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 704 |
+
static_cast<__half2 *>(out_imag.data_ptr()),
|
| 705 |
+
B,
|
| 706 |
+
H,
|
| 707 |
+
N);
|
| 708 |
+
break;
|
| 709 |
+
|
| 710 |
+
case 3:
|
| 711 |
+
cudaFuncSetAttribute(&butterfly_padded_cuda_kernel_64<3>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 712 |
+
butterfly_padded_cuda_kernel_64<3><<<gridDim, blockDim, 65536>>>(
|
| 713 |
+
static_cast<__half2 *>(x.data_ptr()),
|
| 714 |
+
x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
|
| 715 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 716 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 717 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 718 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 719 |
+
static_cast<__half2 *>(out_imag.data_ptr()),
|
| 720 |
+
B,
|
| 721 |
+
H,
|
| 722 |
+
N);
|
| 723 |
+
break;
|
| 724 |
+
|
| 725 |
+
case 4:
|
| 726 |
+
cudaFuncSetAttribute(&butterfly_padded_cuda_kernel_64<4>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 727 |
+
butterfly_padded_cuda_kernel_64<4><<<gridDim, blockDim, 65536>>>(
|
| 728 |
+
static_cast<__half2 *>(x.data_ptr()),
|
| 729 |
+
x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
|
| 730 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 731 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 732 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 733 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 734 |
+
static_cast<__half2 *>(out_imag.data_ptr()),
|
| 735 |
+
B,
|
| 736 |
+
H,
|
| 737 |
+
N);
|
| 738 |
+
break;
|
| 739 |
+
|
| 740 |
+
default:
|
| 741 |
+
printf("Invalid K, df size 64: %d\n", K);
|
| 742 |
+
}
|
| 743 |
+
break;
|
| 744 |
+
case 128:
|
| 745 |
+
blockDim.x = 32;
|
| 746 |
+
blockDim.y = 8;
|
| 747 |
+
gridDim.x = 256 / (32 * 1024/ M);
|
| 748 |
+
gridDim.z = H / 16;
|
| 749 |
+
|
| 750 |
+
switch(K){
|
| 751 |
+
case 1:
|
| 752 |
+
cudaFuncSetAttribute(&butterfly_padded_cuda_kernel_128<1>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 753 |
+
butterfly_padded_cuda_kernel_128<1><<<gridDim, blockDim, 65536>>>(
|
| 754 |
+
static_cast<__half2 *>(x.data_ptr()),
|
| 755 |
+
x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
|
| 756 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 757 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 758 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 759 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 760 |
+
static_cast<__half2 *>(out_imag.data_ptr()),
|
| 761 |
+
B,
|
| 762 |
+
H,
|
| 763 |
+
N);
|
| 764 |
+
break;
|
| 765 |
+
case 2:
|
| 766 |
+
cudaFuncSetAttribute(&butterfly_padded_cuda_kernel_128<2>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 767 |
+
butterfly_padded_cuda_kernel_128<2><<<gridDim, blockDim, 65536>>>(
|
| 768 |
+
static_cast<__half2 *>(x.data_ptr()),
|
| 769 |
+
x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
|
| 770 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 771 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 772 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 773 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 774 |
+
static_cast<__half2 *>(out_imag.data_ptr()),
|
| 775 |
+
B,
|
| 776 |
+
H,
|
| 777 |
+
N);
|
| 778 |
+
break;
|
| 779 |
+
case 3:
|
| 780 |
+
cudaFuncSetAttribute(&butterfly_padded_cuda_kernel_128<3>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 781 |
+
butterfly_padded_cuda_kernel_128<3><<<gridDim, blockDim, 65536>>>(
|
| 782 |
+
static_cast<__half2 *>(x.data_ptr()),
|
| 783 |
+
x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
|
| 784 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 785 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 786 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 787 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 788 |
+
static_cast<__half2 *>(out_imag.data_ptr()),
|
| 789 |
+
B,
|
| 790 |
+
H,
|
| 791 |
+
N);
|
| 792 |
+
break;
|
| 793 |
+
case 4:
|
| 794 |
+
cudaFuncSetAttribute(&butterfly_padded_cuda_kernel_128<4>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 795 |
+
butterfly_padded_cuda_kernel_128<4><<<gridDim, blockDim, 65536>>>(
|
| 796 |
+
static_cast<__half2 *>(x.data_ptr()),
|
| 797 |
+
x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
|
| 798 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 799 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 800 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 801 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 802 |
+
static_cast<__half2 *>(out_imag.data_ptr()),
|
| 803 |
+
B,
|
| 804 |
+
H,
|
| 805 |
+
N);
|
| 806 |
+
break;
|
| 807 |
+
case 5:
|
| 808 |
+
cudaFuncSetAttribute(&butterfly_padded_cuda_kernel_128<5>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 809 |
+
butterfly_padded_cuda_kernel_128<5><<<gridDim, blockDim, 65536>>>(
|
| 810 |
+
static_cast<__half2 *>(x.data_ptr()),
|
| 811 |
+
x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
|
| 812 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 813 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 814 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 815 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 816 |
+
static_cast<__half2 *>(out_imag.data_ptr()),
|
| 817 |
+
B,
|
| 818 |
+
H,
|
| 819 |
+
N);
|
| 820 |
+
break;
|
| 821 |
+
case 6:
|
| 822 |
+
cudaFuncSetAttribute(&butterfly_padded_cuda_kernel_128<6>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 823 |
+
butterfly_padded_cuda_kernel_128<6><<<gridDim, blockDim, 65536>>>(
|
| 824 |
+
static_cast<__half2 *>(x.data_ptr()),
|
| 825 |
+
x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
|
| 826 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 827 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 828 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 829 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 830 |
+
static_cast<__half2 *>(out_imag.data_ptr()),
|
| 831 |
+
B,
|
| 832 |
+
H,
|
| 833 |
+
N);
|
| 834 |
+
break;
|
| 835 |
+
case 7:
|
| 836 |
+
cudaFuncSetAttribute(&butterfly_padded_cuda_kernel_128<7>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 837 |
+
butterfly_padded_cuda_kernel_128<7><<<gridDim, blockDim, 65536>>>(
|
| 838 |
+
static_cast<__half2 *>(x.data_ptr()),
|
| 839 |
+
x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
|
| 840 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 841 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 842 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 843 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 844 |
+
static_cast<__half2 *>(out_imag.data_ptr()),
|
| 845 |
+
B,
|
| 846 |
+
H,
|
| 847 |
+
N);
|
| 848 |
+
break;
|
| 849 |
+
case 8:
|
| 850 |
+
cudaFuncSetAttribute(&butterfly_padded_cuda_kernel_128<8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 851 |
+
butterfly_padded_cuda_kernel_128<8><<<gridDim, blockDim, 65536>>>(
|
| 852 |
+
static_cast<__half2 *>(x.data_ptr()),
|
| 853 |
+
x_gate ? static_cast<__half2 *>(x_gate.value().data_ptr()) : nullptr,
|
| 854 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 855 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 856 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 857 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 858 |
+
static_cast<__half2 *>(out_imag.data_ptr()),
|
| 859 |
+
B,
|
| 860 |
+
H,
|
| 861 |
+
N);
|
| 862 |
+
break;
|
| 863 |
+
default:
|
| 864 |
+
printf("Invalid K, df size 128: %d\n", K);
|
| 865 |
+
}
|
| 866 |
+
break;
|
| 867 |
+
default:
|
| 868 |
+
printf("Invalid d_f size: %d\n", d_f_size);
|
| 869 |
+
}
|
| 870 |
+
return {out_real, out_imag};
|
| 871 |
+
}
|
overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_padded_cuda_bf16.cu
ADDED
|
@@ -0,0 +1,897 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) 2023 Dan Fu, Hermann Kumbong
|
| 2 |
+
|
| 3 |
+
#include <torch/extension.h>
|
| 4 |
+
|
| 5 |
+
#include <vector>
|
| 6 |
+
#include <stdio.h>
|
| 7 |
+
#include <mma.h>
|
| 8 |
+
#include <cuda_runtime.h>
|
| 9 |
+
#include <cuda_fp16.h>
|
| 10 |
+
#include <cuda_bf16.h>
|
| 11 |
+
#include "shared.h"
|
| 12 |
+
|
| 13 |
+
using namespace nvcuda;
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
template <int K>
|
| 17 |
+
__global__ void butterfly_cuda_kernel_64(
|
| 18 |
+
const __nv_bfloat162 *__restrict__ x,
|
| 19 |
+
const __nv_bfloat162 *__restrict__ x_gate,
|
| 20 |
+
const __nv_bfloat162 *__restrict__ d_f_real,
|
| 21 |
+
const __nv_bfloat162 *__restrict__ d_f_imag,
|
| 22 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_real,
|
| 23 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
|
| 24 |
+
__nv_bfloat162 *__restrict__ out_real,
|
| 25 |
+
__nv_bfloat162 *__restrict__ out_imag,
|
| 26 |
+
uint B,
|
| 27 |
+
uint H,
|
| 28 |
+
int M)
|
| 29 |
+
{
|
| 30 |
+
const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
|
| 31 |
+
const int offset = blockIdx.y * H * M/2 + blockIdx.z * 16 * M/2;
|
| 32 |
+
const int out_offset = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * 16 * 64 * 32 * gridDim.x;
|
| 33 |
+
int idx;
|
| 34 |
+
int t_offset;
|
| 35 |
+
int out_t_offset;
|
| 36 |
+
int shared_offset;
|
| 37 |
+
const int N = 64;
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
extern __shared__ __nv_bfloat16 x_shared[];
|
| 41 |
+
__nv_bfloat16 *d_f_real_shared = &x_shared[K * 16 * N];
|
| 42 |
+
__nv_bfloat16 *d_f_imag_shared = &d_f_real_shared[N * N];
|
| 43 |
+
__nv_bfloat16 *twiddles_real_shared = &d_f_imag_shared[N * N];
|
| 44 |
+
__nv_bfloat16 *twiddles_imag_shared = &twiddles_real_shared[N * N];
|
| 45 |
+
float *out_real_shared = reinterpret_cast<float*>(&twiddles_imag_shared[N * N]);
|
| 46 |
+
float *out_imag_shared = &out_real_shared[N * N];
|
| 47 |
+
|
| 48 |
+
// #pragma unroll
|
| 49 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 50 |
+
{
|
| 51 |
+
idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 52 |
+
shared_offset = i * 32 + threadIdx.x;
|
| 53 |
+
reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[idx];
|
| 54 |
+
reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[idx];
|
| 55 |
+
|
| 56 |
+
// #pragma unroll
|
| 57 |
+
shared_offset = i * 32 + threadIdx.x;
|
| 58 |
+
reinterpret_cast<__nv_bfloat162 *>(d_f_real_shared)[shared_offset] = d_f_real[shared_offset];
|
| 59 |
+
reinterpret_cast<__nv_bfloat162 *>(d_f_imag_shared)[shared_offset] = d_f_imag[shared_offset];
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
float2 tmp_real, tmp_imag;
|
| 63 |
+
|
| 64 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[4];
|
| 65 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[4];
|
| 66 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[4];
|
| 67 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[4];
|
| 68 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag[4][4];
|
| 69 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[4];
|
| 70 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_imag[4];
|
| 71 |
+
|
| 72 |
+
__syncthreads();
|
| 73 |
+
|
| 74 |
+
for (int i = 0; i < 4; i++)
|
| 75 |
+
{
|
| 76 |
+
wmma::load_matrix_sync(a_frag_real[i], d_f_real_shared + i * N * 16 + threadIdx.y * 16, N);
|
| 77 |
+
wmma::load_matrix_sync(a_frag_imag[i], d_f_imag_shared + i * N * 16 + threadIdx.y * 16, N);
|
| 78 |
+
wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + threadIdx.y * N * 16 + i * 16, N);
|
| 79 |
+
wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + threadIdx.y * N * 16 + i * 16, N);
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
for (int t = 0; t < 16; t++)
|
| 83 |
+
{
|
| 84 |
+
t_offset = t * M/2;
|
| 85 |
+
out_t_offset = t * 64 * 32 * gridDim.x;
|
| 86 |
+
|
| 87 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 88 |
+
{
|
| 89 |
+
if(i < K * 16){
|
| 90 |
+
idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 91 |
+
shared_offset = i * 32 + threadIdx.x;
|
| 92 |
+
if(x_gate != nullptr){
|
| 93 |
+
reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = idx < max_idx ? __hmul2(x[idx + offset + t_offset], x_gate[idx + offset + t_offset]) : __floats2bfloat162_rn(0.0f, 0.0f);
|
| 94 |
+
}else{
|
| 95 |
+
reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = idx < max_idx ? x[idx + offset + t_offset] : __floats2bfloat162_rn(0.0f, 0.0f);
|
| 96 |
+
}
|
| 97 |
+
}
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
__syncthreads();
|
| 101 |
+
|
| 102 |
+
for (int i = 0; i < K; i++)
|
| 103 |
+
{
|
| 104 |
+
for (int j = 0; j < 4; j++)
|
| 105 |
+
{
|
| 106 |
+
wmma::load_matrix_sync(b_frag[i][j], x_shared + i * N * 16 + j * 16, N);
|
| 107 |
+
}
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
#pragma unroll
|
| 111 |
+
for (int j = 0; j < 4; j++)
|
| 112 |
+
{
|
| 113 |
+
wmma::fill_fragment(acc_frag_real[j], 0.0f);
|
| 114 |
+
|
| 115 |
+
for (int k = 0; k < K; k++)
|
| 116 |
+
{
|
| 117 |
+
wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
|
| 118 |
+
}
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
#pragma unroll
|
| 122 |
+
|
| 123 |
+
for (int j = 0; j < 4; j++)
|
| 124 |
+
{
|
| 125 |
+
wmma::fill_fragment(acc_frag_imag[j], 0.0f);
|
| 126 |
+
|
| 127 |
+
for (int k = 0; k < K; k++)
|
| 128 |
+
{
|
| 129 |
+
wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
|
| 130 |
+
}
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
#pragma unroll
|
| 134 |
+
for (int j = 0; j < 4; j++)
|
| 135 |
+
{
|
| 136 |
+
for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
|
| 137 |
+
{
|
| 138 |
+
tmp_real = reinterpret_cast<float2 *>(acc_frag_real[j].x)[k];
|
| 139 |
+
tmp_imag = reinterpret_cast<float2 *>(acc_frag_imag[j].x)[k];
|
| 140 |
+
|
| 141 |
+
reinterpret_cast<float2 *>(acc_frag_real[j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[j].x)[k]) - tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[j].x)[k]);
|
| 142 |
+
reinterpret_cast<float2 *>(acc_frag_imag[j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[j].x)[k]) + tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[j].x)[k]);
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
wmma::store_matrix_sync(out_real_shared + threadIdx.y * N * 16 + j * 16, acc_frag_real[j], N, wmma::mem_row_major);
|
| 146 |
+
wmma::store_matrix_sync(out_imag_shared + threadIdx.y * N * 16 + j * 16, acc_frag_imag[j], N, wmma::mem_row_major);
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
__syncthreads();
|
| 150 |
+
|
| 151 |
+
#pragma unroll
|
| 152 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 153 |
+
{
|
| 154 |
+
idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 155 |
+
shared_offset = i * 32 + threadIdx.x;
|
| 156 |
+
out_real[out_offset + out_t_offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_real_shared)[shared_offset]);
|
| 157 |
+
out_imag[out_offset + out_t_offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_imag_shared)[shared_offset]);
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
__syncthreads();
|
| 161 |
+
}
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
template <int K>
|
| 165 |
+
__global__ void butterfly_cuda_kernel_32(
|
| 166 |
+
const __nv_bfloat162 *__restrict__ x,
|
| 167 |
+
const __nv_bfloat162 *__restrict__ x_gate,
|
| 168 |
+
const __nv_bfloat16 *__restrict__ d_f_real,
|
| 169 |
+
const __nv_bfloat16 *__restrict__ d_f_imag,
|
| 170 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_real,
|
| 171 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
|
| 172 |
+
__nv_bfloat162 *__restrict__ out_real,
|
| 173 |
+
__nv_bfloat162 *__restrict__ out_imag,
|
| 174 |
+
uint B,
|
| 175 |
+
uint H,
|
| 176 |
+
int M)
|
| 177 |
+
{
|
| 178 |
+
const int N = 32;
|
| 179 |
+
const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
|
| 180 |
+
|
| 181 |
+
const int offset = blockIdx.y * H * M / 2 + blockIdx.z * M / 2;
|
| 182 |
+
const int out_offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x;
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
__shared__ __nv_bfloat16 x_shared[K * 16 * 64];
|
| 186 |
+
__shared__ __nv_bfloat16 d_f_real_shared[32 * 32];
|
| 187 |
+
__shared__ __nv_bfloat16 d_f_imag_shared[32 * 32];
|
| 188 |
+
__shared__ __nv_bfloat16 twiddles_real_shared[32 * 64];
|
| 189 |
+
__shared__ __nv_bfloat16 twiddles_imag_shared[32 * 64];
|
| 190 |
+
__shared__ float out_real_shared[32 * 64];
|
| 191 |
+
__shared__ float out_imag_shared[32 * 64];
|
| 192 |
+
|
| 193 |
+
// #pragma unroll
|
| 194 |
+
for (int i = threadIdx.y; i<32; i+=blockDim.y)
|
| 195 |
+
{
|
| 196 |
+
int idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 197 |
+
int shared_offset = i * 32 + threadIdx.x;
|
| 198 |
+
|
| 199 |
+
if(i < K * 16){
|
| 200 |
+
if(x_gate != nullptr){
|
| 201 |
+
reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = idx < max_idx ? __hmul2(x[idx + offset], x_gate[idx + offset]) : __floats2bfloat162_rn(0.0f, 0.0f);
|
| 202 |
+
}else{
|
| 203 |
+
reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = idx < max_idx ? x[idx + offset] : __floats2bfloat162_rn(0.0f, 0.0f);
|
| 204 |
+
}
|
| 205 |
+
}
|
| 206 |
+
reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[idx];
|
| 207 |
+
reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[idx];
|
| 208 |
+
|
| 209 |
+
// #pragma unroll
|
| 210 |
+
d_f_real_shared[shared_offset] = d_f_real[shared_offset];
|
| 211 |
+
d_f_imag_shared[shared_offset] = d_f_imag[shared_offset];
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
__syncthreads();
|
| 215 |
+
|
| 216 |
+
if (threadIdx.y < N / 16)
|
| 217 |
+
{
|
| 218 |
+
float2 tmp_real, tmp_imag;
|
| 219 |
+
|
| 220 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[2][2];
|
| 221 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[2][2];
|
| 222 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[2][2];
|
| 223 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[2][2];
|
| 224 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag[K][2];
|
| 225 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[2][2];
|
| 226 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_imag[2][2];
|
| 227 |
+
|
| 228 |
+
int t = threadIdx.y * 32;
|
| 229 |
+
|
| 230 |
+
for (int i = 0; i < 2; i++)
|
| 231 |
+
{
|
| 232 |
+
for (int j = 0; j < 2; j++)
|
| 233 |
+
{
|
| 234 |
+
wmma::load_matrix_sync(a_frag_real[i][j], d_f_real_shared + j * N * 16 + i * 16, N);
|
| 235 |
+
wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag_shared + j * N * 16 + i * 16, N);
|
| 236 |
+
if(i < K){
|
| 237 |
+
wmma::load_matrix_sync(b_frag[i][j], x_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 238 |
+
}
|
| 239 |
+
wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 240 |
+
wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 241 |
+
}
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
#pragma unroll
|
| 245 |
+
for (int i = 0; i < 2; i++)
|
| 246 |
+
{
|
| 247 |
+
for (int j = 0; j < 2; j++)
|
| 248 |
+
{
|
| 249 |
+
wmma::fill_fragment(acc_frag_real[i][j], 0.0f);
|
| 250 |
+
|
| 251 |
+
for (int k = 0; k < K; k++)
|
| 252 |
+
{
|
| 253 |
+
wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag[k][j], acc_frag_real[i][j]);
|
| 254 |
+
}
|
| 255 |
+
}
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
#pragma unroll
|
| 259 |
+
for (int i = 0; i < 2; i++)
|
| 260 |
+
{
|
| 261 |
+
for (int j = 0; j < 2; j++)
|
| 262 |
+
{
|
| 263 |
+
wmma::fill_fragment(acc_frag_imag[i][j], 0.0f);
|
| 264 |
+
|
| 265 |
+
for (int k = 0; k < K; k++)
|
| 266 |
+
{
|
| 267 |
+
wmma::mma_sync(acc_frag_imag[i][j], a_frag_imag[i][k], b_frag[k][j], acc_frag_imag[i][j]);
|
| 268 |
+
}
|
| 269 |
+
}
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
#pragma unroll
|
| 273 |
+
for (int i = 0; i < 2; i++)
|
| 274 |
+
{
|
| 275 |
+
for (int j = 0; j < 2; j++)
|
| 276 |
+
{
|
| 277 |
+
for (int k = 0; k < acc_frag_real[i][j].num_elements / 2; k++)
|
| 278 |
+
{
|
| 279 |
+
tmp_real = reinterpret_cast<float2 *>(acc_frag_real[i][j].x)[k];
|
| 280 |
+
tmp_imag = reinterpret_cast<float2 *>(acc_frag_imag[i][j].x)[k];
|
| 281 |
+
reinterpret_cast<float2 *>(acc_frag_real[i][j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[i][j].x)[k]) - tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[i][j].x)[k]);
|
| 282 |
+
reinterpret_cast<float2 *>(acc_frag_imag[i][j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[i][j].x)[k]) + tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[i][j].x)[k]);
|
| 283 |
+
}
|
| 284 |
+
wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
|
| 285 |
+
wmma::store_matrix_sync(out_imag_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_imag[i][j], 2 * N, wmma::mem_row_major);
|
| 286 |
+
}
|
| 287 |
+
}
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
__syncthreads();
|
| 291 |
+
|
| 292 |
+
#pragma unroll
|
| 293 |
+
for (int i = threadIdx.y; i<32; i+=blockDim.y)
|
| 294 |
+
{
|
| 295 |
+
int idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 296 |
+
out_real[out_offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_real_shared)[i * 32 + threadIdx.x]);
|
| 297 |
+
out_imag[out_offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_imag_shared)[i * 32 + threadIdx.x]);
|
| 298 |
+
}
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
template <int K>
|
| 302 |
+
__global__ void butterfly_cuda_kernel_128(
|
| 303 |
+
const __nv_bfloat162 *__restrict__ x,
|
| 304 |
+
const __nv_bfloat162 *__restrict__ x_gate,
|
| 305 |
+
const __nv_bfloat162 *__restrict__ d_f_real,
|
| 306 |
+
const __nv_bfloat162 *__restrict__ d_f_imag,
|
| 307 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_real,
|
| 308 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
|
| 309 |
+
__nv_bfloat162 *__restrict__ out_real,
|
| 310 |
+
__nv_bfloat162 *__restrict__ out_imag,
|
| 311 |
+
uint B,
|
| 312 |
+
uint H,
|
| 313 |
+
int M)
|
| 314 |
+
{
|
| 315 |
+
const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
|
| 316 |
+
const int offset = blockIdx.y * H * M/2 + blockIdx.z * 16 * M/2;
|
| 317 |
+
const int out_offset = blockIdx.y * H * 128 * 32 * 2 * gridDim.x + blockIdx.z * 16 * 128 * 32 * 2 * gridDim.x;
|
| 318 |
+
const int N = 128;
|
| 319 |
+
int idx;
|
| 320 |
+
int t_offset;
|
| 321 |
+
int out_t_offset;
|
| 322 |
+
int shared_offset;
|
| 323 |
+
|
| 324 |
+
extern __shared__ __nv_bfloat16 shared_real[];
|
| 325 |
+
__nv_bfloat16 *shared_imag = &shared_real[128 * 128];
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[8];
|
| 329 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[8];
|
| 330 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[8];
|
| 331 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[8];
|
| 332 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag[K][8];
|
| 333 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[8];
|
| 334 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_imag[8];
|
| 335 |
+
|
| 336 |
+
for (int i = threadIdx.y ; i < N; i+=blockDim.y)
|
| 337 |
+
{
|
| 338 |
+
for(int j=0; j< 2; j++){
|
| 339 |
+
shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
|
| 340 |
+
reinterpret_cast<__nv_bfloat162 *>(shared_real)[shared_offset] = d_f_real[shared_offset];
|
| 341 |
+
reinterpret_cast<__nv_bfloat162 *>(shared_imag)[shared_offset] = d_f_imag[shared_offset];
|
| 342 |
+
}
|
| 343 |
+
}
|
| 344 |
+
|
| 345 |
+
__syncthreads();
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
for (int i = 0; i < 8; i++){
|
| 349 |
+
wmma::load_matrix_sync(a_frag_real[i], shared_real + i * 128 * 16 + threadIdx.y * 16, 128);
|
| 350 |
+
wmma::load_matrix_sync(a_frag_imag[i], shared_imag + i * 128 * 16 + threadIdx.y * 16, 128);
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
__syncthreads();
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 359 |
+
{
|
| 360 |
+
for(int j=0; j< 2; j++){
|
| 361 |
+
idx = i * 32 * 2 * gridDim.x + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
|
| 362 |
+
shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
|
| 363 |
+
reinterpret_cast<__nv_bfloat162*>(shared_real)[shared_offset] = twiddle_factors_real[idx];
|
| 364 |
+
reinterpret_cast<__nv_bfloat162*>(shared_imag)[shared_offset] = twiddle_factors_imag[idx];
|
| 365 |
+
}
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
__syncthreads();
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
for (int i = 0; i < 8; i++){
|
| 372 |
+
wmma::load_matrix_sync(tw_frag_real[i], shared_real + threadIdx.y * 128 * 16 + i * 16, 128);
|
| 373 |
+
wmma::load_matrix_sync(tw_frag_imag[i], shared_imag + threadIdx.y * 128 * 16 + i * 16, 128);
|
| 374 |
+
}
|
| 375 |
+
|
| 376 |
+
__syncthreads();
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
for(int t=0; t< 16; t++){
|
| 380 |
+
t_offset = t * M/2;
|
| 381 |
+
out_t_offset = t * 128 * 32 * 2 * gridDim.x;
|
| 382 |
+
|
| 383 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 384 |
+
{
|
| 385 |
+
if(i < K * 16){
|
| 386 |
+
for(int j=0; j< 2; j++){
|
| 387 |
+
idx = i * 32 * 2 * gridDim.x + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
|
| 388 |
+
shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
|
| 389 |
+
if(x_gate != nullptr){
|
| 390 |
+
reinterpret_cast<__nv_bfloat162*>(shared_real)[shared_offset] = idx < max_idx ? __hmul2(x[idx + offset + t_offset], x_gate[idx + offset + t_offset]) : __floats2bfloat162_rn(0.0f, 0.0f);
|
| 391 |
+
}else{
|
| 392 |
+
reinterpret_cast<__nv_bfloat162*>(shared_real)[shared_offset] = idx < max_idx ? x[idx + offset + t_offset] : __floats2bfloat162_rn(0.0f, 0.0f);
|
| 393 |
+
}
|
| 394 |
+
}
|
| 395 |
+
}
|
| 396 |
+
}
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
__syncthreads();
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
for (int i = 0; i < K; i++)
|
| 403 |
+
{
|
| 404 |
+
for (int j = 0; j < 8; j++)
|
| 405 |
+
{
|
| 406 |
+
wmma::load_matrix_sync(b_frag[i][j], shared_real + i * 128 * 16 + j * 16, 128);
|
| 407 |
+
}
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
__syncthreads();
|
| 411 |
+
|
| 412 |
+
#pragma unroll
|
| 413 |
+
for (int j = 0; j < 8; j++)
|
| 414 |
+
{
|
| 415 |
+
wmma::fill_fragment(acc_frag_real[j], 0.0f);
|
| 416 |
+
|
| 417 |
+
for (int k = 0; k < K; k++)
|
| 418 |
+
{
|
| 419 |
+
wmma::mma_sync(acc_frag_real[j], a_frag_real[k], b_frag[k][j], acc_frag_real[j]);
|
| 420 |
+
}
|
| 421 |
+
}
|
| 422 |
+
|
| 423 |
+
#pragma unroll
|
| 424 |
+
|
| 425 |
+
for (int j = 0; j < 8; j++)
|
| 426 |
+
{
|
| 427 |
+
wmma::fill_fragment(acc_frag_imag[j], 0.0f);
|
| 428 |
+
|
| 429 |
+
for (int k = 0; k < K; k++)
|
| 430 |
+
{
|
| 431 |
+
wmma::mma_sync(acc_frag_imag[j], a_frag_imag[k], b_frag[k][j], acc_frag_imag[j]);
|
| 432 |
+
}
|
| 433 |
+
}
|
| 434 |
+
|
| 435 |
+
float2 tmp_real, tmp_imag;
|
| 436 |
+
#pragma unroll
|
| 437 |
+
for (int j = 0; j < 8; j++)
|
| 438 |
+
{
|
| 439 |
+
for (int k = 0; k < acc_frag_real[j].num_elements / 2; k++)
|
| 440 |
+
{
|
| 441 |
+
tmp_real = reinterpret_cast<float2 *>(acc_frag_real[j].x)[k];
|
| 442 |
+
tmp_imag = reinterpret_cast<float2 *>(acc_frag_imag[j].x)[k];
|
| 443 |
+
|
| 444 |
+
reinterpret_cast<float2 *>(acc_frag_real[j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[j].x)[k]) - tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[j].x)[k]);
|
| 445 |
+
reinterpret_cast<float2 *>(acc_frag_imag[j].x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag[j].x)[k]) + tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real[j].x)[k]);
|
| 446 |
+
}
|
| 447 |
+
}
|
| 448 |
+
|
| 449 |
+
for (int j = 0; j < 8; j++)
|
| 450 |
+
{
|
| 451 |
+
wmma::store_matrix_sync(reinterpret_cast<float*>(shared_real) + threadIdx.y * 128 * 16 + j * 16, acc_frag_real[j], 128, wmma::mem_row_major);
|
| 452 |
+
}
|
| 453 |
+
|
| 454 |
+
__syncthreads();
|
| 455 |
+
|
| 456 |
+
#pragma unroll
|
| 457 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 458 |
+
{
|
| 459 |
+
for(int j=0; j< 2; j++){
|
| 460 |
+
idx = i * 32 * 2 * gridDim.x + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
|
| 461 |
+
shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
|
| 462 |
+
out_real[idx + out_offset + out_t_offset] = __float22bfloat162_rn(reinterpret_cast<float2*>(shared_real)[shared_offset]);
|
| 463 |
+
}
|
| 464 |
+
}
|
| 465 |
+
|
| 466 |
+
__syncthreads();
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
for (int j = 0; j < 8; j++)
|
| 470 |
+
{
|
| 471 |
+
wmma::store_matrix_sync(reinterpret_cast<float*>(shared_real) + threadIdx.y * 128 * 16 + j * 16, acc_frag_imag[j], 128, wmma::mem_row_major);
|
| 472 |
+
}
|
| 473 |
+
|
| 474 |
+
__syncthreads();
|
| 475 |
+
|
| 476 |
+
#pragma unroll
|
| 477 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 478 |
+
{
|
| 479 |
+
for(int j=0; j< 2; j++){
|
| 480 |
+
idx = i * 32 * 2 * gridDim.x + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
|
| 481 |
+
shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
|
| 482 |
+
out_imag[idx + out_offset + out_t_offset] = __float22bfloat162_rn(reinterpret_cast<float2*>(shared_real)[shared_offset]);
|
| 483 |
+
}
|
| 484 |
+
}
|
| 485 |
+
}
|
| 486 |
+
}
|
| 487 |
+
|
| 488 |
+
template<int K>
|
| 489 |
+
__global__ void butterfly_cuda_kernel_16(
|
| 490 |
+
const __nv_bfloat162 *__restrict__ x,
|
| 491 |
+
const __nv_bfloat162 *__restrict__ x_gate,
|
| 492 |
+
const __nv_bfloat16 *__restrict__ d_f_real,
|
| 493 |
+
const __nv_bfloat16 *__restrict__ d_f_imag,
|
| 494 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_real,
|
| 495 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
|
| 496 |
+
__nv_bfloat162 *__restrict__ out_real,
|
| 497 |
+
__nv_bfloat162 *__restrict__ out_imag,
|
| 498 |
+
uint B,
|
| 499 |
+
uint H,
|
| 500 |
+
int M)
|
| 501 |
+
{
|
| 502 |
+
const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
|
| 503 |
+
const int N = 16;
|
| 504 |
+
const int offset = blockIdx.y * H * M / 2 + blockIdx.z * M / 2;
|
| 505 |
+
const int out_offset = blockIdx.y * H * N * blockDim.x * gridDim.x + blockIdx.z * N * blockDim.x * gridDim.x;
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
__shared__ __nv_bfloat16 x_shared[N * 64];
|
| 510 |
+
__shared__ __nv_bfloat16 d_f_real_shared[N * N];
|
| 511 |
+
__shared__ __nv_bfloat16 d_f_imag_shared[N * N];
|
| 512 |
+
__shared__ __nv_bfloat16 twiddles_real_shared[N * 64];
|
| 513 |
+
__shared__ __nv_bfloat16 twiddles_imag_shared[N * 64];
|
| 514 |
+
__shared__ float out_real_shared[N * 64];
|
| 515 |
+
__shared__ float out_imag_shared[N * 64];
|
| 516 |
+
|
| 517 |
+
// #pragma unroll
|
| 518 |
+
for (int i = threadIdx.y; i < N; i++)
|
| 519 |
+
{
|
| 520 |
+
int idx = i * blockDim.x * gridDim.x + blockIdx.x * blockDim.x + threadIdx.x;
|
| 521 |
+
int shared_offset = i * blockDim.x + threadIdx.x;
|
| 522 |
+
|
| 523 |
+
if(x_gate != nullptr){
|
| 524 |
+
reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = idx < max_idx ? __hmul2(x[idx + offset], x_gate[idx + offset]) : __floats2bfloat162_rn(0.0f, 0.0f);
|
| 525 |
+
}else{
|
| 526 |
+
reinterpret_cast<__nv_bfloat162 *>(x_shared)[shared_offset] = idx < max_idx ? x[idx + offset] : __floats2bfloat162_rn(0.0f, 0.0f);
|
| 527 |
+
}
|
| 528 |
+
reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[idx];
|
| 529 |
+
reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[idx];
|
| 530 |
+
|
| 531 |
+
// #pragma unroll
|
| 532 |
+
if(threadIdx.x < 16 ){
|
| 533 |
+
shared_offset = i * 16 + threadIdx.x;
|
| 534 |
+
d_f_real_shared[shared_offset] = d_f_real[shared_offset];
|
| 535 |
+
d_f_imag_shared[shared_offset] = d_f_imag[shared_offset];
|
| 536 |
+
}
|
| 537 |
+
}
|
| 538 |
+
|
| 539 |
+
__syncthreads();
|
| 540 |
+
|
| 541 |
+
if (threadIdx.y < 4)
|
| 542 |
+
{
|
| 543 |
+
float2 tmp_real, tmp_imag;
|
| 544 |
+
|
| 545 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real;
|
| 546 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real;
|
| 547 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag;
|
| 548 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag;
|
| 549 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag;
|
| 550 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real;
|
| 551 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_imag;
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
wmma::load_matrix_sync(a_frag_real, d_f_real_shared, N);
|
| 555 |
+
wmma::load_matrix_sync(a_frag_imag, d_f_imag_shared, N);
|
| 556 |
+
wmma::load_matrix_sync(b_frag, x_shared + threadIdx.y * 16, 64);
|
| 557 |
+
wmma::load_matrix_sync(tw_frag_real, twiddles_real_shared + threadIdx.y * 16, 64);
|
| 558 |
+
wmma::load_matrix_sync(tw_frag_imag, twiddles_imag_shared + threadIdx.y * 16, 64);
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
wmma::fill_fragment(acc_frag_real, 0.0f);
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
wmma::mma_sync(acc_frag_real, a_frag_real, b_frag, acc_frag_real);
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
wmma::fill_fragment(acc_frag_imag, 0.0f);
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
wmma::mma_sync(acc_frag_imag, a_frag_imag, b_frag, acc_frag_imag);
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
#pragma unroll
|
| 576 |
+
for (int k = 0; k < acc_frag_real.num_elements / 2; k++)
|
| 577 |
+
{
|
| 578 |
+
tmp_real = reinterpret_cast<float2 *>(acc_frag_real.x)[k];
|
| 579 |
+
tmp_imag = reinterpret_cast<float2 *>(acc_frag_imag.x)[k];
|
| 580 |
+
reinterpret_cast<float2 *>(acc_frag_real.x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real.x)[k]) - tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag.x)[k]);
|
| 581 |
+
reinterpret_cast<float2 *>(acc_frag_imag.x)[k] = tmp_real * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_imag.x)[k]) + tmp_imag * __bfloat1622float2(reinterpret_cast<__nv_bfloat162 *>(tw_frag_real.x)[k]);
|
| 582 |
+
}
|
| 583 |
+
wmma::store_matrix_sync(out_real_shared + threadIdx.y * 16, acc_frag_real, 64, wmma::mem_row_major);
|
| 584 |
+
wmma::store_matrix_sync(out_imag_shared + threadIdx.y * 16, acc_frag_imag, 64, wmma::mem_row_major);
|
| 585 |
+
|
| 586 |
+
}
|
| 587 |
+
__syncthreads();
|
| 588 |
+
|
| 589 |
+
#pragma unroll
|
| 590 |
+
for (int i = threadIdx.y; i < N; i++)
|
| 591 |
+
{
|
| 592 |
+
int idx = i * blockDim.x * gridDim.x + blockIdx.x * blockDim.x + threadIdx.x;;
|
| 593 |
+
out_real[out_offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_real_shared)[i * 32 + threadIdx.x]);
|
| 594 |
+
out_imag[out_offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2*>(out_imag_shared)[i * 32 + threadIdx.x]);
|
| 595 |
+
}
|
| 596 |
+
}
|
| 597 |
+
|
| 598 |
+
std::vector<torch::Tensor> butterfly_padded_bf16_cuda(
|
| 599 |
+
torch::Tensor x,
|
| 600 |
+
torch::Tensor d_f_real,
|
| 601 |
+
torch::Tensor d_f_imag,
|
| 602 |
+
torch::Tensor twiddle_factors_real,
|
| 603 |
+
torch::Tensor twiddle_factors_imag,
|
| 604 |
+
int M,
|
| 605 |
+
std::optional<at::Tensor> x_gate = std::nullopt
|
| 606 |
+
)
|
| 607 |
+
{
|
| 608 |
+
|
| 609 |
+
uint B = x.size(0);
|
| 610 |
+
uint H = x.size(1);
|
| 611 |
+
|
| 612 |
+
uint d_f_size = d_f_real.size(1);
|
| 613 |
+
|
| 614 |
+
uint N = x.size(2);
|
| 615 |
+
|
| 616 |
+
//need to make sure that N is less that the M to which we are padding
|
| 617 |
+
assert(N <= d_f_size * M);
|
| 618 |
+
|
| 619 |
+
dim3 gridDim;
|
| 620 |
+
dim3 blockDim;
|
| 621 |
+
|
| 622 |
+
gridDim.y = B;
|
| 623 |
+
gridDim.z = H;
|
| 624 |
+
|
| 625 |
+
blockDim.x = 32;
|
| 626 |
+
blockDim.y = 4;
|
| 627 |
+
|
| 628 |
+
torch::Tensor out_real = torch::empty({B, H, d_f_size * M}, x.options());
|
| 629 |
+
torch::Tensor out_imag = torch::empty({B, H, d_f_size * M}, x.options());
|
| 630 |
+
|
| 631 |
+
gridDim.x = 512 / (32 * 1024/ M);
|
| 632 |
+
|
| 633 |
+
const int K = ceil(N / (1.0 * 16 * M));
|
| 634 |
+
|
| 635 |
+
switch (d_f_size)
|
| 636 |
+
{
|
| 637 |
+
case 16:
|
| 638 |
+
butterfly_cuda_kernel_16<1><<<gridDim, blockDim>>>(
|
| 639 |
+
static_cast<__nv_bfloat162 *>(x.data_ptr()),
|
| 640 |
+
x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
|
| 641 |
+
static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
|
| 642 |
+
static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
|
| 643 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 644 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 645 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 646 |
+
static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
|
| 647 |
+
B,
|
| 648 |
+
H,
|
| 649 |
+
N);
|
| 650 |
+
break;
|
| 651 |
+
case 32:
|
| 652 |
+
switch(K){
|
| 653 |
+
case 1:
|
| 654 |
+
butterfly_cuda_kernel_32<1><<<gridDim, blockDim>>>(
|
| 655 |
+
static_cast<__nv_bfloat162 *>(x.data_ptr()),
|
| 656 |
+
x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
|
| 657 |
+
static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
|
| 658 |
+
static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
|
| 659 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 660 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 661 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 662 |
+
static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
|
| 663 |
+
B,
|
| 664 |
+
H,
|
| 665 |
+
N);
|
| 666 |
+
break;
|
| 667 |
+
case 2:
|
| 668 |
+
butterfly_cuda_kernel_32<2><<<gridDim, blockDim>>>(
|
| 669 |
+
static_cast<__nv_bfloat162 *>(x.data_ptr()),
|
| 670 |
+
x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
|
| 671 |
+
static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
|
| 672 |
+
static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
|
| 673 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 674 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 675 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 676 |
+
static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
|
| 677 |
+
B,
|
| 678 |
+
H,
|
| 679 |
+
N);
|
| 680 |
+
break;
|
| 681 |
+
default:
|
| 682 |
+
printf("Invalid K, df size 32: %d\n", K);
|
| 683 |
+
}
|
| 684 |
+
break;
|
| 685 |
+
case 64:
|
| 686 |
+
gridDim.z = H / 16;
|
| 687 |
+
|
| 688 |
+
switch(K){
|
| 689 |
+
case 1:
|
| 690 |
+
cudaFuncSetAttribute(&butterfly_cuda_kernel_64<1>, cudaFuncAttributeMaxDynamicSharedMemorySize, 78000);
|
| 691 |
+
butterfly_cuda_kernel_64<1><<<gridDim, blockDim, 78000>>>(
|
| 692 |
+
static_cast<__nv_bfloat162 *>(x.data_ptr()),
|
| 693 |
+
x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
|
| 694 |
+
static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
|
| 695 |
+
static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
|
| 696 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 697 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 698 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 699 |
+
static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
|
| 700 |
+
B,
|
| 701 |
+
H,
|
| 702 |
+
N);
|
| 703 |
+
break;
|
| 704 |
+
case 2:
|
| 705 |
+
cudaFuncSetAttribute(&butterfly_cuda_kernel_64<2>, cudaFuncAttributeMaxDynamicSharedMemorySize, 78000);
|
| 706 |
+
butterfly_cuda_kernel_64<2><<<gridDim, blockDim, 78000>>>(
|
| 707 |
+
static_cast<__nv_bfloat162 *>(x.data_ptr()),
|
| 708 |
+
x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
|
| 709 |
+
static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
|
| 710 |
+
static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
|
| 711 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 712 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 713 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 714 |
+
static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
|
| 715 |
+
B,
|
| 716 |
+
H,
|
| 717 |
+
N);
|
| 718 |
+
break;
|
| 719 |
+
case 3:
|
| 720 |
+
cudaFuncSetAttribute(&butterfly_cuda_kernel_64<3>, cudaFuncAttributeMaxDynamicSharedMemorySize, 78000);
|
| 721 |
+
butterfly_cuda_kernel_64<3><<<gridDim, blockDim, 78000>>>(
|
| 722 |
+
static_cast<__nv_bfloat162 *>(x.data_ptr()),
|
| 723 |
+
x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
|
| 724 |
+
static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
|
| 725 |
+
static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
|
| 726 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 727 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 728 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 729 |
+
static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
|
| 730 |
+
B,
|
| 731 |
+
H,
|
| 732 |
+
N);
|
| 733 |
+
break;
|
| 734 |
+
case 4:
|
| 735 |
+
cudaFuncSetAttribute(&butterfly_cuda_kernel_64<4>, cudaFuncAttributeMaxDynamicSharedMemorySize, 78000);
|
| 736 |
+
butterfly_cuda_kernel_64<4><<<gridDim, blockDim, 78000>>>(
|
| 737 |
+
static_cast<__nv_bfloat162 *>(x.data_ptr()),
|
| 738 |
+
x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
|
| 739 |
+
static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
|
| 740 |
+
static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
|
| 741 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 742 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 743 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 744 |
+
static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
|
| 745 |
+
B,
|
| 746 |
+
H,
|
| 747 |
+
N);
|
| 748 |
+
break;
|
| 749 |
+
default:
|
| 750 |
+
printf("Invalid K, df size 64: %d\n", K);
|
| 751 |
+
}
|
| 752 |
+
break;
|
| 753 |
+
case 128:
|
| 754 |
+
blockDim.x = 32;
|
| 755 |
+
blockDim.y = 8;
|
| 756 |
+
gridDim.x = 256 / (32 * 1024/ M);
|
| 757 |
+
gridDim.z = H / 16;
|
| 758 |
+
switch(K){
|
| 759 |
+
case 1:
|
| 760 |
+
cudaFuncSetAttribute(&butterfly_cuda_kernel_128<1>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 761 |
+
butterfly_cuda_kernel_128<1><<<gridDim, blockDim, 65536>>>(
|
| 762 |
+
static_cast<__nv_bfloat162 *>(x.data_ptr()),
|
| 763 |
+
x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
|
| 764 |
+
static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
|
| 765 |
+
static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
|
| 766 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 767 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 768 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 769 |
+
static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
|
| 770 |
+
B,
|
| 771 |
+
H,
|
| 772 |
+
N);
|
| 773 |
+
break;
|
| 774 |
+
case 2:
|
| 775 |
+
cudaFuncSetAttribute(&butterfly_cuda_kernel_128<2>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 776 |
+
butterfly_cuda_kernel_128<2><<<gridDim, blockDim, 65536>>>(
|
| 777 |
+
static_cast<__nv_bfloat162 *>(x.data_ptr()),
|
| 778 |
+
x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
|
| 779 |
+
static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
|
| 780 |
+
static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
|
| 781 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 782 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 783 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 784 |
+
static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
|
| 785 |
+
B,
|
| 786 |
+
H,
|
| 787 |
+
N);
|
| 788 |
+
break;
|
| 789 |
+
case 3:
|
| 790 |
+
cudaFuncSetAttribute(&butterfly_cuda_kernel_128<3>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 791 |
+
|
| 792 |
+
butterfly_cuda_kernel_128<3><<<gridDim, blockDim, 65536>>>(
|
| 793 |
+
static_cast<__nv_bfloat162 *>(x.data_ptr()),
|
| 794 |
+
x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
|
| 795 |
+
static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
|
| 796 |
+
static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
|
| 797 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 798 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 799 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 800 |
+
static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
|
| 801 |
+
B,
|
| 802 |
+
H,
|
| 803 |
+
N);
|
| 804 |
+
break;
|
| 805 |
+
case 4:
|
| 806 |
+
cudaFuncSetAttribute(&butterfly_cuda_kernel_128<4>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 807 |
+
|
| 808 |
+
butterfly_cuda_kernel_128<4><<<gridDim, blockDim, 65536>>>(
|
| 809 |
+
static_cast<__nv_bfloat162 *>(x.data_ptr()),
|
| 810 |
+
x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
|
| 811 |
+
static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
|
| 812 |
+
static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
|
| 813 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 814 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 815 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 816 |
+
static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
|
| 817 |
+
B,
|
| 818 |
+
H,
|
| 819 |
+
N);
|
| 820 |
+
break;
|
| 821 |
+
case 5:
|
| 822 |
+
cudaFuncSetAttribute(&butterfly_cuda_kernel_128<5>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 823 |
+
|
| 824 |
+
butterfly_cuda_kernel_128<5><<<gridDim, blockDim, 65536>>>(
|
| 825 |
+
static_cast<__nv_bfloat162 *>(x.data_ptr()),
|
| 826 |
+
x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
|
| 827 |
+
static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
|
| 828 |
+
static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
|
| 829 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 830 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 831 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 832 |
+
static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
|
| 833 |
+
B,
|
| 834 |
+
H,
|
| 835 |
+
N);
|
| 836 |
+
break;
|
| 837 |
+
case 6:
|
| 838 |
+
cudaFuncSetAttribute(&butterfly_cuda_kernel_128<6>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 839 |
+
|
| 840 |
+
butterfly_cuda_kernel_128<6><<<gridDim, blockDim, 65536>>>(
|
| 841 |
+
static_cast<__nv_bfloat162 *>(x.data_ptr()),
|
| 842 |
+
x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
|
| 843 |
+
static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
|
| 844 |
+
static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
|
| 845 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 846 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 847 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 848 |
+
static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
|
| 849 |
+
B,
|
| 850 |
+
H,
|
| 851 |
+
N);
|
| 852 |
+
break;
|
| 853 |
+
case 7:
|
| 854 |
+
cudaFuncSetAttribute(&butterfly_cuda_kernel_128<7>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 855 |
+
|
| 856 |
+
butterfly_cuda_kernel_128<7><<<gridDim, blockDim, 65536>>>(
|
| 857 |
+
static_cast<__nv_bfloat162 *>(x.data_ptr()),
|
| 858 |
+
x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
|
| 859 |
+
static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
|
| 860 |
+
static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
|
| 861 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 862 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 863 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 864 |
+
static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
|
| 865 |
+
B,
|
| 866 |
+
H,
|
| 867 |
+
N);
|
| 868 |
+
break;
|
| 869 |
+
case 8:
|
| 870 |
+
cudaFuncSetAttribute(&butterfly_cuda_kernel_128<8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 871 |
+
|
| 872 |
+
butterfly_cuda_kernel_128<8><<<gridDim, blockDim, 65536>>>(
|
| 873 |
+
static_cast<__nv_bfloat162 *>(x.data_ptr()),
|
| 874 |
+
x_gate ? static_cast<__nv_bfloat162 *>(x_gate.value().data_ptr()) : nullptr,
|
| 875 |
+
static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
|
| 876 |
+
static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
|
| 877 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 878 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 879 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 880 |
+
static_cast<__nv_bfloat162 *>(out_imag.data_ptr()),
|
| 881 |
+
B,
|
| 882 |
+
H,
|
| 883 |
+
N);
|
| 884 |
+
break;
|
| 885 |
+
default:
|
| 886 |
+
printf("Invalid K, df size 128: %d\n", K);
|
| 887 |
+
|
| 888 |
+
}
|
| 889 |
+
break;
|
| 890 |
+
|
| 891 |
+
default:
|
| 892 |
+
printf("Not yet implemented \n");
|
| 893 |
+
break;
|
| 894 |
+
}
|
| 895 |
+
|
| 896 |
+
return {out_real, out_imag};
|
| 897 |
+
}
|
overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_padded_ifft_cuda.cu
ADDED
|
@@ -0,0 +1,905 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) 2023 Dan Fu, Hermann Kumbong
|
| 2 |
+
|
| 3 |
+
#include <torch/extension.h>
|
| 4 |
+
|
| 5 |
+
#include <vector>
|
| 6 |
+
#include <stdio.h>
|
| 7 |
+
#include <mma.h>
|
| 8 |
+
#include <cuda_fp16.h>
|
| 9 |
+
#include <cuda_bf16.h>
|
| 10 |
+
#include "shared.h"
|
| 11 |
+
|
| 12 |
+
using namespace nvcuda;
|
| 13 |
+
|
| 14 |
+
template <int TILE_H, int K>
|
| 15 |
+
__global__ void butterfly_ifft_padded_cuda_kernel_64(
|
| 16 |
+
const __half2 *__restrict__ x_real,
|
| 17 |
+
const __half2 *__restrict__ x_imag,
|
| 18 |
+
const complex_half_t *__restrict__ d_f,
|
| 19 |
+
const __half2 *__restrict__ twiddle_factors_real,
|
| 20 |
+
const __half2 *__restrict__ twiddle_factors_imag,
|
| 21 |
+
__half2 *__restrict__ out_real,
|
| 22 |
+
__half2 *__restrict__ out_gate,
|
| 23 |
+
uint B,
|
| 24 |
+
uint H,
|
| 25 |
+
int M)
|
| 26 |
+
{
|
| 27 |
+
const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
|
| 28 |
+
const int out_offset = blockIdx.y * H * M/2 + blockIdx.z * TILE_H * M/2;
|
| 29 |
+
const int in_offset = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * TILE_H * 64 * 32 * gridDim.x;
|
| 30 |
+
int idx;
|
| 31 |
+
int t_offset;
|
| 32 |
+
int out_t_offset;
|
| 33 |
+
int shared_offset;
|
| 34 |
+
const int N = 64;
|
| 35 |
+
|
| 36 |
+
extern __shared__ half x_real_shared[];
|
| 37 |
+
half *x_imag_shared = &x_real_shared[N * N];
|
| 38 |
+
half *d_f_real = &x_imag_shared[N * N];
|
| 39 |
+
half *d_f_imag = &d_f_real[N * N];
|
| 40 |
+
half *twiddles_real_shared = &d_f_imag[N * N];
|
| 41 |
+
half *twiddles_imag_shared = &twiddles_real_shared[N * N];
|
| 42 |
+
half *out_real_shared = &twiddles_imag_shared[N * N];
|
| 43 |
+
|
| 44 |
+
half tmp_real, tmp_imag;
|
| 45 |
+
|
| 46 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[K][4];
|
| 47 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[K][4];
|
| 48 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real[4];
|
| 49 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag[4];
|
| 50 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real[4];
|
| 51 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag[4];
|
| 52 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[K];
|
| 53 |
+
|
| 54 |
+
// #pragma unroll
|
| 55 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 56 |
+
{
|
| 57 |
+
idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 58 |
+
shared_offset = i * 32 + threadIdx.x;
|
| 59 |
+
reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[idx];
|
| 60 |
+
reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[idx];
|
| 61 |
+
|
| 62 |
+
// #pragma unroll
|
| 63 |
+
shared_offset = i * 64 + threadIdx.x;
|
| 64 |
+
d_f_real[shared_offset] = d_f[shared_offset].real();
|
| 65 |
+
d_f_imag[shared_offset] = d_f[shared_offset].imag();
|
| 66 |
+
|
| 67 |
+
d_f_real[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].real();
|
| 68 |
+
d_f_imag[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].imag();
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
__syncthreads();
|
| 72 |
+
|
| 73 |
+
for (int i = 0; i < 4; i++)
|
| 74 |
+
{
|
| 75 |
+
if(i < K){
|
| 76 |
+
#pragma unroll
|
| 77 |
+
for (int j = 0; j < 4; j++)
|
| 78 |
+
{
|
| 79 |
+
wmma::load_matrix_sync(a_frag_real[i][j], d_f_real + j * N * 16 + i * 16, N);
|
| 80 |
+
wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag + j * N * 16 + i * 16, N);
|
| 81 |
+
}
|
| 82 |
+
}
|
| 83 |
+
wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + i * N * 16 + threadIdx.y * 16, N);
|
| 84 |
+
wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + i * N * 16 + threadIdx.y * 16, N);
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
for (int t = 0; t < TILE_H; t++)
|
| 88 |
+
{
|
| 89 |
+
|
| 90 |
+
out_t_offset = t * M/2;
|
| 91 |
+
t_offset = t * 64 * 32 * gridDim.x;
|
| 92 |
+
|
| 93 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 94 |
+
{
|
| 95 |
+
idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 96 |
+
shared_offset = i * 32 + threadIdx.x;
|
| 97 |
+
reinterpret_cast<__half2 *>(x_real_shared)[shared_offset] = x_real[idx + in_offset + t_offset];
|
| 98 |
+
reinterpret_cast<__half2 *>(x_imag_shared)[shared_offset] = x_imag[idx + in_offset + t_offset];
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
__syncthreads();
|
| 102 |
+
|
| 103 |
+
for (int i = 0; i < 4; i++)
|
| 104 |
+
{
|
| 105 |
+
wmma::load_matrix_sync(b_frag_real[i], x_real_shared + i * N * 16 + threadIdx.y * 16, N);
|
| 106 |
+
wmma::load_matrix_sync(b_frag_imag[i], x_imag_shared + i * N * 16 + threadIdx.y * 16, N);
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
for (int j = 0; j < 4; j++)
|
| 110 |
+
{
|
| 111 |
+
for (int k = 0; k < tw_frag_real[j].num_elements; k++)
|
| 112 |
+
{
|
| 113 |
+
tmp_real = __hsub(__hmul(tw_frag_real[j].x[k], b_frag_real[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_imag[j].x[k]));
|
| 114 |
+
tmp_imag = __hadd(__hmul(tw_frag_real[j].x[k], b_frag_imag[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_real[j].x[k]));
|
| 115 |
+
b_frag_real[j].x[k] = tmp_real;
|
| 116 |
+
b_frag_imag[j].x[k] = tmp_imag;
|
| 117 |
+
}
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
for (int i = 0; i < K; i++)
|
| 121 |
+
{
|
| 122 |
+
wmma::fill_fragment(acc_frag_real[i], __float2half(0.0f));
|
| 123 |
+
|
| 124 |
+
// bd
|
| 125 |
+
#pragma unroll
|
| 126 |
+
for (int k = 0; k < 4; k++)
|
| 127 |
+
{
|
| 128 |
+
wmma::mma_sync(acc_frag_real[i], a_frag_imag[i][k], b_frag_imag[k], acc_frag_real[i]);
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
for (int k = 0; k < acc_frag_real[i].num_elements; k++)
|
| 132 |
+
{
|
| 133 |
+
acc_frag_real[i].x[k] = __hneg(acc_frag_real[i].x[k]);
|
| 134 |
+
}
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
for (int i = 0; i < K; i++)
|
| 138 |
+
{
|
| 139 |
+
// ac - bd
|
| 140 |
+
#pragma unroll
|
| 141 |
+
for (int k = 0; k < 4; k++)
|
| 142 |
+
{
|
| 143 |
+
wmma::mma_sync(acc_frag_real[i], a_frag_real[i][k], b_frag_real[k], acc_frag_real[i]);
|
| 144 |
+
}
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
#pragma unroll
|
| 148 |
+
for (int i = 0; i < K; i++)
|
| 149 |
+
{
|
| 150 |
+
wmma::store_matrix_sync(out_real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
__syncthreads();
|
| 154 |
+
|
| 155 |
+
#pragma unroll
|
| 156 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 157 |
+
{
|
| 158 |
+
idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 159 |
+
shared_offset = i * 32 + threadIdx.x;
|
| 160 |
+
|
| 161 |
+
if(idx < max_idx){
|
| 162 |
+
if(out_gate != nullptr)
|
| 163 |
+
out_real[out_offset + out_t_offset + idx] = __hmul2(reinterpret_cast<__half2 *>(out_real_shared)[shared_offset], out_gate[out_offset + out_t_offset + idx]);
|
| 164 |
+
else
|
| 165 |
+
out_real[out_offset + out_t_offset + idx] = reinterpret_cast<__half2 *>(out_real_shared)[shared_offset];
|
| 166 |
+
}
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
__syncthreads();
|
| 170 |
+
}
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
template <int K>
|
| 175 |
+
__global__ void butterfly_ifft_padded_cuda_kernel_32(
|
| 176 |
+
const __half2 *__restrict__ x_real,
|
| 177 |
+
const __half2 *__restrict__ x_imag,
|
| 178 |
+
const complex_half_t *__restrict__ d_f,
|
| 179 |
+
const __half2 *__restrict__ twiddle_factors_real,
|
| 180 |
+
const __half2 *__restrict__ twiddle_factors_imag,
|
| 181 |
+
__half2 *__restrict__ out_real,
|
| 182 |
+
__half2 *__restrict__ out_gate,
|
| 183 |
+
uint B,
|
| 184 |
+
uint H,
|
| 185 |
+
int M)
|
| 186 |
+
{
|
| 187 |
+
const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
|
| 188 |
+
const int N = 32;
|
| 189 |
+
int idx;
|
| 190 |
+
int shared_offset;
|
| 191 |
+
|
| 192 |
+
const int out_offset = blockIdx.y * H * M / 2 + blockIdx.z * M / 2;
|
| 193 |
+
const int in_offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x;
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
__shared__ half x_real_shared[32 * 64];
|
| 197 |
+
__shared__ half x_imag_shared[32 * 64];
|
| 198 |
+
__shared__ half d_f_real[32 * 32];
|
| 199 |
+
__shared__ half d_f_imag[32 * 32];
|
| 200 |
+
__shared__ half twiddles_real_shared[32 * 64];
|
| 201 |
+
__shared__ half twiddles_imag_shared[32 * 64];
|
| 202 |
+
__shared__ half out_real_shared[32 * 64];
|
| 203 |
+
|
| 204 |
+
// #pragma unroll
|
| 205 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 206 |
+
{
|
| 207 |
+
idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 208 |
+
int shared_offset = i * 32 + threadIdx.x;
|
| 209 |
+
|
| 210 |
+
reinterpret_cast<__half2 *>(x_real_shared)[shared_offset] = x_real[in_offset + idx];
|
| 211 |
+
reinterpret_cast<__half2 *>(x_imag_shared)[shared_offset] = x_imag[in_offset + idx];
|
| 212 |
+
reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[idx];
|
| 213 |
+
reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[idx];
|
| 214 |
+
|
| 215 |
+
// #pragma unroll
|
| 216 |
+
shared_offset = i * 32 + threadIdx.x;
|
| 217 |
+
d_f_real[shared_offset] = d_f[shared_offset].real();
|
| 218 |
+
d_f_imag[shared_offset] = d_f[shared_offset].imag();
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
__syncthreads();
|
| 222 |
+
|
| 223 |
+
if (threadIdx.y < N/16)
|
| 224 |
+
{
|
| 225 |
+
half tmp_real, tmp_imag;
|
| 226 |
+
|
| 227 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[K][2];
|
| 228 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[K][2];
|
| 229 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real[2][2];
|
| 230 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag[2][2];
|
| 231 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real[2][2];
|
| 232 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag[2][2];
|
| 233 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[K][2];
|
| 234 |
+
|
| 235 |
+
int t = threadIdx.y * 32;
|
| 236 |
+
|
| 237 |
+
for (int i = 0; i < 2; i++)
|
| 238 |
+
{
|
| 239 |
+
for (int j = 0; j < 2; j++)
|
| 240 |
+
{
|
| 241 |
+
if(i < K){
|
| 242 |
+
wmma::load_matrix_sync(a_frag_real[i][j], d_f_real + j * N * 16 + i * 16, N);
|
| 243 |
+
wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag + j * N * 16 + i * 16, N);
|
| 244 |
+
}
|
| 245 |
+
wmma::load_matrix_sync(b_frag_real[i][j], x_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 246 |
+
wmma::load_matrix_sync(b_frag_imag[i][j], x_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 247 |
+
wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 248 |
+
wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 249 |
+
}
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
for (int i = 0; i < 2; i++)
|
| 253 |
+
{
|
| 254 |
+
for (int j = 0; j < 2; j++)
|
| 255 |
+
{
|
| 256 |
+
for (int k = 0; k < tw_frag_real[i][j].num_elements; k++)
|
| 257 |
+
{
|
| 258 |
+
tmp_real = __hsub(__hmul(tw_frag_real[i][j].x[k], b_frag_real[i][j].x[k]), __hmul(tw_frag_imag[i][j].x[k], b_frag_imag[i][j].x[k]));
|
| 259 |
+
tmp_imag = __hadd(__hmul(tw_frag_real[i][j].x[k], b_frag_imag[i][j].x[k]), __hmul(tw_frag_imag[i][j].x[k], b_frag_real[i][j].x[k]));
|
| 260 |
+
b_frag_real[i][j].x[k] = tmp_real;
|
| 261 |
+
b_frag_imag[i][j].x[k] = tmp_imag;
|
| 262 |
+
}
|
| 263 |
+
}
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
for (int i = 0; i < K; i++)
|
| 267 |
+
{
|
| 268 |
+
for (int j = 0; j < 2; j++)
|
| 269 |
+
{
|
| 270 |
+
wmma::fill_fragment(acc_frag_real[i][j], __float2half(0.0f));
|
| 271 |
+
|
| 272 |
+
// bd
|
| 273 |
+
for (int k = 0; k < 2; k++)
|
| 274 |
+
{
|
| 275 |
+
wmma::mma_sync(acc_frag_real[i][j], a_frag_imag[i][k], b_frag_imag[k][j], acc_frag_real[i][j]);
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
for (int k = 0; k < acc_frag_real[i][j].num_elements; k++)
|
| 279 |
+
{
|
| 280 |
+
acc_frag_real[i][j].x[k] = __hneg(acc_frag_real[i][j].x[k]);
|
| 281 |
+
}
|
| 282 |
+
}
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
for (int i = 0; i < K; i++)
|
| 286 |
+
{
|
| 287 |
+
for (int j = 0; j < 2; j++)
|
| 288 |
+
{
|
| 289 |
+
// ac - bd
|
| 290 |
+
for (int k = 0; k < 2; k++)
|
| 291 |
+
{
|
| 292 |
+
wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag_real[k][j], acc_frag_real[i][j]);
|
| 293 |
+
}
|
| 294 |
+
}
|
| 295 |
+
}
|
| 296 |
+
|
| 297 |
+
for (int i = 0; i < K; i++)
|
| 298 |
+
{
|
| 299 |
+
for (int j = 0; j < 2; j++)
|
| 300 |
+
{
|
| 301 |
+
wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
|
| 302 |
+
}
|
| 303 |
+
}
|
| 304 |
+
}
|
| 305 |
+
|
| 306 |
+
__syncthreads();
|
| 307 |
+
|
| 308 |
+
#pragma unroll
|
| 309 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 310 |
+
{
|
| 311 |
+
idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 312 |
+
shared_offset = i * 32 + threadIdx.x;
|
| 313 |
+
|
| 314 |
+
if(idx < max_idx){
|
| 315 |
+
if(out_gate != nullptr){
|
| 316 |
+
out_real[idx + out_offset] = __hmul2(reinterpret_cast<__half2 *>(out_real_shared)[shared_offset], out_gate[idx + out_offset]);
|
| 317 |
+
}else{
|
| 318 |
+
out_real[idx + out_offset] = reinterpret_cast<__half2 *>(out_real_shared)[shared_offset];
|
| 319 |
+
}
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
}
|
| 323 |
+
}
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
template <int TILE_H, int K>
|
| 327 |
+
__global__ void butterfly_ifft_padded_cuda_kernel_128(
|
| 328 |
+
const __half2 *__restrict__ x_real,
|
| 329 |
+
const __half2 *__restrict__ x_imag,
|
| 330 |
+
const complex_half_t *__restrict__ d_f,
|
| 331 |
+
const __half2 *__restrict__ twiddle_factors_real,
|
| 332 |
+
const __half2 *__restrict__ twiddle_factors_imag,
|
| 333 |
+
__half2 *__restrict__ out_real,
|
| 334 |
+
__half2 *__restrict__ out_gate,
|
| 335 |
+
uint B,
|
| 336 |
+
uint H,
|
| 337 |
+
int M)
|
| 338 |
+
{
|
| 339 |
+
const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
|
| 340 |
+
const int out_offset = blockIdx.y * H * M/2 + blockIdx.z * TILE_H * M/2;
|
| 341 |
+
const int in_offset = blockIdx.y * H * 128 * 32 * 2 * gridDim.x + blockIdx.z * TILE_H * 128 * 32 * 2 * gridDim.x;
|
| 342 |
+
const int N = 128;
|
| 343 |
+
int idx;
|
| 344 |
+
int t_offset;
|
| 345 |
+
int out_t_offset;
|
| 346 |
+
int shared_offset;
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
extern __shared__ half real_shared[];
|
| 350 |
+
half *imag_shared = &real_shared[128 * 128];
|
| 351 |
+
half *real_shared_2 = &imag_shared[128 * 128];
|
| 352 |
+
half *imag_shared_2 = &real_shared_2[128 * 128];
|
| 353 |
+
|
| 354 |
+
half tmp_real, tmp_imag;
|
| 355 |
+
|
| 356 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag[K][8];
|
| 357 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real[8];
|
| 358 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag[8];
|
| 359 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real[8];
|
| 360 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag[8];
|
| 361 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[K];
|
| 362 |
+
|
| 363 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 364 |
+
{
|
| 365 |
+
for(int j=0; j< 4; j++){
|
| 366 |
+
shared_offset = i * 128 + threadIdx.x + j * blockDim.x;
|
| 367 |
+
real_shared_2[shared_offset] = d_f[shared_offset].real();
|
| 368 |
+
imag_shared_2[shared_offset] = d_f[shared_offset].imag();
|
| 369 |
+
}
|
| 370 |
+
}
|
| 371 |
+
|
| 372 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 373 |
+
{
|
| 374 |
+
for(int j=0; j< 2; j++){
|
| 375 |
+
idx = i * 32 * 2 * gridDim.x + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
|
| 376 |
+
shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
|
| 377 |
+
reinterpret_cast<__half2*>(real_shared)[shared_offset] = twiddle_factors_real[idx];
|
| 378 |
+
reinterpret_cast<__half2*>(imag_shared)[shared_offset] = twiddle_factors_imag[idx];
|
| 379 |
+
}
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
__syncthreads();
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
for (int i = 0; i < 8; i++){
|
| 386 |
+
wmma::load_matrix_sync(tw_frag_real[i], real_shared + i * 128 * 16 + threadIdx.y * 16, 128);
|
| 387 |
+
wmma::load_matrix_sync(tw_frag_imag[i], imag_shared + i * 128 * 16 + threadIdx.y * 16, 128);
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
__syncthreads();
|
| 391 |
+
|
| 392 |
+
for (int t = 0; t < TILE_H; t++)
|
| 393 |
+
{
|
| 394 |
+
|
| 395 |
+
out_t_offset = t * M/2;
|
| 396 |
+
t_offset = t * 128 * 32 * 2 * gridDim.x;
|
| 397 |
+
|
| 398 |
+
for (int i = 0; i < K; i++){
|
| 399 |
+
for (int j = 0; j < 8; j++){
|
| 400 |
+
wmma::load_matrix_sync(a_frag[i][j], imag_shared_2 + j * 128 * 16 + i * 16, 128);
|
| 401 |
+
}
|
| 402 |
+
}
|
| 403 |
+
|
| 404 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 405 |
+
{
|
| 406 |
+
for(int j=0; j< 2; j++){
|
| 407 |
+
idx = i * 32 * 2 * gridDim.x + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
|
| 408 |
+
shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
|
| 409 |
+
reinterpret_cast<__half2*>(real_shared)[shared_offset] = x_real[idx + in_offset + t_offset];
|
| 410 |
+
reinterpret_cast<__half2*>(imag_shared)[shared_offset] = x_imag[idx + in_offset + t_offset];
|
| 411 |
+
}
|
| 412 |
+
}
|
| 413 |
+
|
| 414 |
+
__syncthreads();
|
| 415 |
+
|
| 416 |
+
for (int i = 0; i < 8; i++)
|
| 417 |
+
{
|
| 418 |
+
wmma::load_matrix_sync(b_frag_real[i], real_shared + i * N * 16 + threadIdx.y * 16, N);
|
| 419 |
+
wmma::load_matrix_sync(b_frag_imag[i], imag_shared + i * N * 16 + threadIdx.y * 16, N);
|
| 420 |
+
}
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
for (int j = 0; j < 8; j++)
|
| 424 |
+
{
|
| 425 |
+
for (int k = 0; k < tw_frag_real[j].num_elements; k++)
|
| 426 |
+
{
|
| 427 |
+
tmp_real = __hsub(__hmul(tw_frag_real[j].x[k], b_frag_real[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_imag[j].x[k]));
|
| 428 |
+
tmp_imag = __hadd(__hmul(tw_frag_real[j].x[k], b_frag_imag[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_real[j].x[k]));
|
| 429 |
+
b_frag_real[j].x[k] = tmp_real;
|
| 430 |
+
b_frag_imag[j].x[k] = tmp_imag;
|
| 431 |
+
}
|
| 432 |
+
}
|
| 433 |
+
|
| 434 |
+
for (int i = 0; i < K; i++)
|
| 435 |
+
{
|
| 436 |
+
wmma::fill_fragment(acc_frag_real[i], __float2half(0.0f));
|
| 437 |
+
|
| 438 |
+
// bd
|
| 439 |
+
#pragma unroll
|
| 440 |
+
for (int k = 0; k < 8; k++)
|
| 441 |
+
{
|
| 442 |
+
wmma::mma_sync(acc_frag_real[i], a_frag[i][k], b_frag_imag[k], acc_frag_real[i]);
|
| 443 |
+
}
|
| 444 |
+
|
| 445 |
+
for (int k = 0; k < acc_frag_real[i].num_elements; k++)
|
| 446 |
+
{
|
| 447 |
+
acc_frag_real[i].x[k] = __hneg(acc_frag_real[i].x[k]);
|
| 448 |
+
}
|
| 449 |
+
}
|
| 450 |
+
|
| 451 |
+
for (int i = 0; i < K; i++){
|
| 452 |
+
for (int j = 0; j < 8; j++){
|
| 453 |
+
wmma::load_matrix_sync(a_frag[i][j], real_shared_2 + j * 128 * 16 + i * 16, 128);
|
| 454 |
+
}
|
| 455 |
+
}
|
| 456 |
+
|
| 457 |
+
for (int i = 0; i < K; i++)
|
| 458 |
+
{
|
| 459 |
+
// ac - bd
|
| 460 |
+
#pragma unroll
|
| 461 |
+
for (int k = 0; k < 8; k++)
|
| 462 |
+
{
|
| 463 |
+
wmma::mma_sync(acc_frag_real[i], a_frag[i][k], b_frag_real[k], acc_frag_real[i]);
|
| 464 |
+
}
|
| 465 |
+
}
|
| 466 |
+
|
| 467 |
+
#pragma unroll
|
| 468 |
+
for (int i = 0; i < K; i++)
|
| 469 |
+
{
|
| 470 |
+
//wmma::store_matrix_sync(real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
|
| 471 |
+
wmma::store_matrix_sync(real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
|
| 472 |
+
}
|
| 473 |
+
|
| 474 |
+
__syncthreads();
|
| 475 |
+
|
| 476 |
+
#pragma unroll
|
| 477 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 478 |
+
{
|
| 479 |
+
for(int j=0; j< 2; j++){
|
| 480 |
+
idx = i * 32 * 2 * gridDim.x + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
|
| 481 |
+
shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
|
| 482 |
+
if(idx < max_idx){
|
| 483 |
+
if(out_gate != nullptr){
|
| 484 |
+
out_real[idx + out_offset + out_t_offset] = __hmul2(reinterpret_cast<__half2*>(real_shared)[shared_offset], out_gate[idx + out_offset + out_t_offset]);
|
| 485 |
+
}else{
|
| 486 |
+
out_real[idx + out_offset + out_t_offset] = reinterpret_cast<__half2*>(real_shared)[shared_offset];
|
| 487 |
+
}
|
| 488 |
+
}
|
| 489 |
+
}
|
| 490 |
+
}
|
| 491 |
+
|
| 492 |
+
__syncthreads();
|
| 493 |
+
}
|
| 494 |
+
}
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
__global__ void butterfly_ifft_padded_cuda_kernel_16(
|
| 498 |
+
const __half2 *__restrict__ x_real,
|
| 499 |
+
const __half2 *__restrict__ x_imag,
|
| 500 |
+
const complex_half_t *__restrict__ d_f,
|
| 501 |
+
const __half2 *__restrict__ twiddle_factors_real,
|
| 502 |
+
const __half2 *__restrict__ twiddle_factors_imag,
|
| 503 |
+
__half2 *__restrict__ out_real,
|
| 504 |
+
__half2 *__restrict__ out_gate,
|
| 505 |
+
uint B,
|
| 506 |
+
uint H,
|
| 507 |
+
int M)
|
| 508 |
+
{
|
| 509 |
+
const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
|
| 510 |
+
const int N = 16;
|
| 511 |
+
const int out_offset = blockIdx.y * H * M / 2 + blockIdx.z * M / 2;
|
| 512 |
+
const int offset = blockIdx.y * H * N * blockDim.x * gridDim.x + blockIdx.z * N * blockDim.x * gridDim.x;
|
| 513 |
+
|
| 514 |
+
__shared__ half x_real_shared[N * 64];
|
| 515 |
+
__shared__ half x_imag_shared[N * 64];
|
| 516 |
+
__shared__ half d_f_real[N * N];
|
| 517 |
+
__shared__ half d_f_imag[N * N];
|
| 518 |
+
__shared__ half twiddles_real_shared[N * 64];
|
| 519 |
+
__shared__ half twiddles_imag_shared[N * 64];
|
| 520 |
+
__shared__ half out_real_shared[N * 64];
|
| 521 |
+
|
| 522 |
+
// #pragma unroll
|
| 523 |
+
for (int i = threadIdx.y; i < N; i++)
|
| 524 |
+
{
|
| 525 |
+
int idx = i * blockDim.x * gridDim.x + blockIdx.x * blockDim.x + threadIdx.x;
|
| 526 |
+
int shared_offset = i * blockDim.x + threadIdx.x;
|
| 527 |
+
reinterpret_cast<__half2 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
|
| 528 |
+
reinterpret_cast<__half2 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
|
| 529 |
+
reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[idx];
|
| 530 |
+
reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[idx];
|
| 531 |
+
|
| 532 |
+
if(threadIdx.x < 16 ){
|
| 533 |
+
shared_offset = i * 16 + threadIdx.x;
|
| 534 |
+
d_f_real[shared_offset] = d_f[shared_offset].real();
|
| 535 |
+
d_f_imag[shared_offset] = d_f[shared_offset].imag();
|
| 536 |
+
}
|
| 537 |
+
}
|
| 538 |
+
|
| 539 |
+
__syncthreads();
|
| 540 |
+
|
| 541 |
+
//check if it is better to have one warp do all the multiplication or split between warps
|
| 542 |
+
if (threadIdx.y < 4)
|
| 543 |
+
{
|
| 544 |
+
half tmp_real, tmp_imag;
|
| 545 |
+
|
| 546 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real;
|
| 547 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag;
|
| 548 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real;
|
| 549 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag;
|
| 550 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real;
|
| 551 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag;
|
| 552 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real;
|
| 553 |
+
|
| 554 |
+
wmma::load_matrix_sync(a_frag_real, d_f_real, N);
|
| 555 |
+
wmma::load_matrix_sync(a_frag_imag, d_f_imag, N);
|
| 556 |
+
wmma::load_matrix_sync(b_frag_real, x_real_shared + threadIdx.y * 16, 64);
|
| 557 |
+
wmma::load_matrix_sync(b_frag_imag, x_imag_shared + threadIdx.y * 16, 64);
|
| 558 |
+
wmma::load_matrix_sync(tw_frag_real, twiddles_real_shared + threadIdx.y * 16, 64);
|
| 559 |
+
wmma::load_matrix_sync(tw_frag_imag, twiddles_imag_shared + threadIdx.y * 16, 64);
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
for (int k = 0; k < tw_frag_real.num_elements; k++)
|
| 564 |
+
{
|
| 565 |
+
tmp_real = __hsub(__hmul(tw_frag_real.x[k], b_frag_real.x[k]), __hmul(tw_frag_imag.x[k], b_frag_imag.x[k]));
|
| 566 |
+
tmp_imag = __hadd(__hmul(tw_frag_real.x[k], b_frag_imag.x[k]), __hmul(tw_frag_imag.x[k], b_frag_real.x[k]));
|
| 567 |
+
b_frag_real.x[k] = tmp_real;
|
| 568 |
+
b_frag_imag.x[k] = tmp_imag;
|
| 569 |
+
}
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
wmma::fill_fragment(acc_frag_real, __float2half(0.0f));
|
| 573 |
+
|
| 574 |
+
wmma::mma_sync(acc_frag_real, a_frag_imag, b_frag_imag, acc_frag_real);
|
| 575 |
+
|
| 576 |
+
for(int k=0; k< acc_frag_real.num_elements; k++){
|
| 577 |
+
acc_frag_real.x[k] = __hneg(acc_frag_real.x[k]);
|
| 578 |
+
}
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
wmma::mma_sync(acc_frag_real, a_frag_real, b_frag_real, acc_frag_real);
|
| 582 |
+
|
| 583 |
+
wmma::store_matrix_sync(out_real_shared + threadIdx.y * 16, acc_frag_real, 64, wmma::mem_row_major);
|
| 584 |
+
|
| 585 |
+
}
|
| 586 |
+
|
| 587 |
+
__syncthreads();
|
| 588 |
+
|
| 589 |
+
#pragma unroll
|
| 590 |
+
for (int i = threadIdx.y; i < N; i++)
|
| 591 |
+
{
|
| 592 |
+
int idx = i * blockDim.x * gridDim.x + blockIdx.x * blockDim.x + threadIdx.x;
|
| 593 |
+
if(idx < max_idx){
|
| 594 |
+
if(out_gate != nullptr){
|
| 595 |
+
out_real[out_offset + idx] = __hmul2(reinterpret_cast<__half2 *>(out_real_shared)[i * 32 + threadIdx.x], out_gate[out_offset + idx]);
|
| 596 |
+
}
|
| 597 |
+
else{
|
| 598 |
+
out_real[out_offset + idx] = reinterpret_cast<__half2 *>(out_real_shared)[i * 32 + threadIdx.x];
|
| 599 |
+
}
|
| 600 |
+
}
|
| 601 |
+
}
|
| 602 |
+
}
|
| 603 |
+
|
| 604 |
+
torch::Tensor butterfly_ifft_padded_cuda(
|
| 605 |
+
torch::Tensor x_real,
|
| 606 |
+
torch::Tensor x_imag,
|
| 607 |
+
torch::Tensor d_f,
|
| 608 |
+
torch::Tensor twiddle_factors_real,
|
| 609 |
+
torch::Tensor twiddle_factors_imag,
|
| 610 |
+
int fft_size,
|
| 611 |
+
std::optional<at::Tensor> out_gate = std::nullopt
|
| 612 |
+
)
|
| 613 |
+
{
|
| 614 |
+
|
| 615 |
+
uint B = x_real.size(0);
|
| 616 |
+
uint H = x_real.size(1);
|
| 617 |
+
uint N_M = x_real.size(2);
|
| 618 |
+
const int d_f_size = d_f.size(0);
|
| 619 |
+
// const int TILE_SIZE = 16;
|
| 620 |
+
|
| 621 |
+
dim3 gridDim;
|
| 622 |
+
dim3 blockDim;
|
| 623 |
+
|
| 624 |
+
// uint N = x_real.size(2);
|
| 625 |
+
gridDim.y = B;
|
| 626 |
+
|
| 627 |
+
blockDim.x = 32;
|
| 628 |
+
blockDim.y = 4;
|
| 629 |
+
gridDim.x = 512 / (32 * 1024/ (N_M / d_f_size));
|
| 630 |
+
gridDim.z = H;
|
| 631 |
+
|
| 632 |
+
const int TILE_H = 16;
|
| 633 |
+
torch::Tensor out_real = torch::empty({B, H, fft_size}, x_real.options());
|
| 634 |
+
const int K = ceil(fft_size / (1.0 * 16 * (N_M / d_f_size)));
|
| 635 |
+
|
| 636 |
+
switch(d_f_size){
|
| 637 |
+
case 16:
|
| 638 |
+
butterfly_ifft_padded_cuda_kernel_16<<<gridDim, blockDim>>>(
|
| 639 |
+
static_cast<__half2 *>(x_real.data_ptr()),
|
| 640 |
+
static_cast<__half2 *>(x_imag.data_ptr()),
|
| 641 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 642 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 643 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 644 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 645 |
+
out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
|
| 646 |
+
B,
|
| 647 |
+
H,
|
| 648 |
+
fft_size
|
| 649 |
+
);
|
| 650 |
+
break;
|
| 651 |
+
case 32:
|
| 652 |
+
switch (K)
|
| 653 |
+
{
|
| 654 |
+
case 1:
|
| 655 |
+
butterfly_ifft_padded_cuda_kernel_32<1><<<gridDim, blockDim>>>(
|
| 656 |
+
static_cast<__half2 *>(x_real.data_ptr()),
|
| 657 |
+
static_cast<__half2 *>(x_imag.data_ptr()),
|
| 658 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 659 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 660 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 661 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 662 |
+
out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
|
| 663 |
+
B,
|
| 664 |
+
H,
|
| 665 |
+
fft_size
|
| 666 |
+
);
|
| 667 |
+
break;
|
| 668 |
+
case 2:
|
| 669 |
+
butterfly_ifft_padded_cuda_kernel_32<2><<<gridDim, blockDim>>>(
|
| 670 |
+
static_cast<__half2 *>(x_real.data_ptr()),
|
| 671 |
+
static_cast<__half2 *>(x_imag.data_ptr()),
|
| 672 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 673 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 674 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 675 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 676 |
+
out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
|
| 677 |
+
B,
|
| 678 |
+
H,
|
| 679 |
+
fft_size
|
| 680 |
+
);
|
| 681 |
+
break;
|
| 682 |
+
default:
|
| 683 |
+
printf("Invalid K: %d\n", K);
|
| 684 |
+
break;
|
| 685 |
+
}
|
| 686 |
+
break;
|
| 687 |
+
|
| 688 |
+
case 64:
|
| 689 |
+
gridDim.z = H / TILE_H;
|
| 690 |
+
switch (K)
|
| 691 |
+
{
|
| 692 |
+
case 1:
|
| 693 |
+
cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_64<TILE_H, 1>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 694 |
+
butterfly_ifft_padded_cuda_kernel_64<TILE_H, 1><<<gridDim, blockDim, 65536>>>(
|
| 695 |
+
static_cast<__half2 *>(x_real.data_ptr()),
|
| 696 |
+
static_cast<__half2 *>(x_imag.data_ptr()),
|
| 697 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 698 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 699 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 700 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 701 |
+
out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
|
| 702 |
+
B,
|
| 703 |
+
H,
|
| 704 |
+
fft_size);
|
| 705 |
+
break;
|
| 706 |
+
|
| 707 |
+
case 2:
|
| 708 |
+
cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_64<TILE_H, 2>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 709 |
+
butterfly_ifft_padded_cuda_kernel_64<TILE_H, 2><<<gridDim, blockDim, 65536>>>(
|
| 710 |
+
static_cast<__half2 *>(x_real.data_ptr()),
|
| 711 |
+
static_cast<__half2 *>(x_imag.data_ptr()),
|
| 712 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 713 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 714 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 715 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 716 |
+
out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
|
| 717 |
+
B,
|
| 718 |
+
H,
|
| 719 |
+
fft_size);
|
| 720 |
+
break;
|
| 721 |
+
|
| 722 |
+
case 3:
|
| 723 |
+
cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_64<TILE_H, 3>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 724 |
+
butterfly_ifft_padded_cuda_kernel_64<TILE_H, 3><<<gridDim, blockDim, 65536>>>(
|
| 725 |
+
static_cast<__half2 *>(x_real.data_ptr()),
|
| 726 |
+
static_cast<__half2 *>(x_imag.data_ptr()),
|
| 727 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 728 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 729 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 730 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 731 |
+
out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
|
| 732 |
+
B,
|
| 733 |
+
H,
|
| 734 |
+
fft_size);
|
| 735 |
+
break;
|
| 736 |
+
|
| 737 |
+
case 4:
|
| 738 |
+
cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_64<TILE_H, 4>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 739 |
+
butterfly_ifft_padded_cuda_kernel_64<TILE_H, 4><<<gridDim, blockDim, 65536>>>(
|
| 740 |
+
static_cast<__half2 *>(x_real.data_ptr()),
|
| 741 |
+
static_cast<__half2 *>(x_imag.data_ptr()),
|
| 742 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 743 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 744 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 745 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 746 |
+
out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
|
| 747 |
+
B,
|
| 748 |
+
H,
|
| 749 |
+
fft_size);
|
| 750 |
+
break;
|
| 751 |
+
|
| 752 |
+
default:
|
| 753 |
+
break;
|
| 754 |
+
}
|
| 755 |
+
|
| 756 |
+
break;
|
| 757 |
+
case 128:
|
| 758 |
+
blockDim.x = 32;
|
| 759 |
+
blockDim.y = 8;
|
| 760 |
+
gridDim.x = 256 / (32 * 1024/ (N_M / d_f_size));
|
| 761 |
+
gridDim.z = H / TILE_H;
|
| 762 |
+
|
| 763 |
+
switch (K)
|
| 764 |
+
{
|
| 765 |
+
case 1:
|
| 766 |
+
cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 1>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
|
| 767 |
+
|
| 768 |
+
butterfly_ifft_padded_cuda_kernel_128<TILE_H, 1><<<gridDim, blockDim, 65536 * 2>>>(
|
| 769 |
+
static_cast<__half2 *>(x_real.data_ptr()),
|
| 770 |
+
static_cast<__half2 *>(x_imag.data_ptr()),
|
| 771 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 772 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 773 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 774 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 775 |
+
out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
|
| 776 |
+
B,
|
| 777 |
+
H,
|
| 778 |
+
fft_size);
|
| 779 |
+
break;
|
| 780 |
+
|
| 781 |
+
case 2:
|
| 782 |
+
cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 2>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
|
| 783 |
+
|
| 784 |
+
butterfly_ifft_padded_cuda_kernel_128<TILE_H, 2><<<gridDim, blockDim, 65536 * 2>>>(
|
| 785 |
+
static_cast<__half2 *>(x_real.data_ptr()),
|
| 786 |
+
static_cast<__half2 *>(x_imag.data_ptr()),
|
| 787 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 788 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 789 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 790 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 791 |
+
out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
|
| 792 |
+
B,
|
| 793 |
+
H,
|
| 794 |
+
fft_size);
|
| 795 |
+
break;
|
| 796 |
+
|
| 797 |
+
case 3:
|
| 798 |
+
cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 3>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
|
| 799 |
+
|
| 800 |
+
butterfly_ifft_padded_cuda_kernel_128<TILE_H, 3><<<gridDim, blockDim, 65536 * 2>>>(
|
| 801 |
+
static_cast<__half2 *>(x_real.data_ptr()),
|
| 802 |
+
static_cast<__half2 *>(x_imag.data_ptr()),
|
| 803 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 804 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 805 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 806 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 807 |
+
out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
|
| 808 |
+
B,
|
| 809 |
+
H,
|
| 810 |
+
fft_size);
|
| 811 |
+
break;
|
| 812 |
+
|
| 813 |
+
case 4:
|
| 814 |
+
cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 4>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
|
| 815 |
+
|
| 816 |
+
butterfly_ifft_padded_cuda_kernel_128<TILE_H, 4><<<gridDim, blockDim, 65536 * 2>>>(
|
| 817 |
+
static_cast<__half2 *>(x_real.data_ptr()),
|
| 818 |
+
static_cast<__half2 *>(x_imag.data_ptr()),
|
| 819 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 820 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 821 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 822 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 823 |
+
out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
|
| 824 |
+
B,
|
| 825 |
+
H,
|
| 826 |
+
fft_size);
|
| 827 |
+
break;
|
| 828 |
+
|
| 829 |
+
case 5:
|
| 830 |
+
cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 5>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
|
| 831 |
+
|
| 832 |
+
butterfly_ifft_padded_cuda_kernel_128<TILE_H, 5><<<gridDim, blockDim, 65536 * 2>>>(
|
| 833 |
+
static_cast<__half2 *>(x_real.data_ptr()),
|
| 834 |
+
static_cast<__half2 *>(x_imag.data_ptr()),
|
| 835 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 836 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 837 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 838 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 839 |
+
out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
|
| 840 |
+
B,
|
| 841 |
+
H,
|
| 842 |
+
fft_size);
|
| 843 |
+
break;
|
| 844 |
+
|
| 845 |
+
case 6:
|
| 846 |
+
cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 6>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
|
| 847 |
+
|
| 848 |
+
butterfly_ifft_padded_cuda_kernel_128<TILE_H, 6><<<gridDim, blockDim, 65536 * 2>>>(
|
| 849 |
+
static_cast<__half2 *>(x_real.data_ptr()),
|
| 850 |
+
static_cast<__half2 *>(x_imag.data_ptr()),
|
| 851 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 852 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 853 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 854 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 855 |
+
out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
|
| 856 |
+
B,
|
| 857 |
+
H,
|
| 858 |
+
fft_size);
|
| 859 |
+
break;
|
| 860 |
+
|
| 861 |
+
case 7:
|
| 862 |
+
cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 7>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
|
| 863 |
+
|
| 864 |
+
butterfly_ifft_padded_cuda_kernel_128<TILE_H, 7><<<gridDim, blockDim, 65536 * 2>>>(
|
| 865 |
+
static_cast<__half2 *>(x_real.data_ptr()),
|
| 866 |
+
static_cast<__half2 *>(x_imag.data_ptr()),
|
| 867 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 868 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 869 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 870 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 871 |
+
out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
|
| 872 |
+
B,
|
| 873 |
+
H,
|
| 874 |
+
fft_size);
|
| 875 |
+
break;
|
| 876 |
+
|
| 877 |
+
case 8:
|
| 878 |
+
cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
|
| 879 |
+
|
| 880 |
+
butterfly_ifft_padded_cuda_kernel_128<TILE_H, 8><<<gridDim, blockDim, 65536 * 2>>>(
|
| 881 |
+
static_cast<__half2 *>(x_real.data_ptr()),
|
| 882 |
+
static_cast<__half2 *>(x_imag.data_ptr()),
|
| 883 |
+
static_cast<complex_half_t *>(d_f.data_ptr()),
|
| 884 |
+
static_cast<__half2 *>(twiddle_factors_real.data_ptr()),
|
| 885 |
+
static_cast<__half2 *>(twiddle_factors_imag.data_ptr()),
|
| 886 |
+
static_cast<__half2 *>(out_real.data_ptr()),
|
| 887 |
+
out_gate ? static_cast<__half2 *>(out_gate.value().data_ptr()) : nullptr,
|
| 888 |
+
B,
|
| 889 |
+
H,
|
| 890 |
+
fft_size);
|
| 891 |
+
break;
|
| 892 |
+
|
| 893 |
+
default:
|
| 894 |
+
printf("Invalid K: %d\n", K);
|
| 895 |
+
break;
|
| 896 |
+
}
|
| 897 |
+
break;
|
| 898 |
+
|
| 899 |
+
default:
|
| 900 |
+
printf("Invalid d_f_size: %d\n", d_f_size);
|
| 901 |
+
break;
|
| 902 |
+
}
|
| 903 |
+
|
| 904 |
+
return out_real;
|
| 905 |
+
}
|
overlay/kernels/cuda/flashfftconv/csrc/butterfly/butterfly_padded_ifft_cuda_bf16.cu
ADDED
|
@@ -0,0 +1,917 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) 2023 Dan Fu, Hermann Kumbong
|
| 2 |
+
|
| 3 |
+
#include <torch/extension.h>
|
| 4 |
+
|
| 5 |
+
#include <vector>
|
| 6 |
+
#include <stdio.h>
|
| 7 |
+
#include <mma.h>
|
| 8 |
+
#include <cuda_fp16.h>
|
| 9 |
+
#include <cuda_bf16.h>
|
| 10 |
+
#include "shared.h"
|
| 11 |
+
|
| 12 |
+
using namespace nvcuda;
|
| 13 |
+
|
| 14 |
+
template <int TILE_H, int K>
|
| 15 |
+
__global__ void butterfly_ifft_padded_cuda_kernel_64(
|
| 16 |
+
const __nv_bfloat162 *__restrict__ x_real,
|
| 17 |
+
const __nv_bfloat162 *__restrict__ x_imag,
|
| 18 |
+
const __nv_bfloat162 *__restrict__ d_f_real,
|
| 19 |
+
const __nv_bfloat162 *__restrict__ d_f_imag,
|
| 20 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_real,
|
| 21 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
|
| 22 |
+
__nv_bfloat162 *__restrict__ out_real,
|
| 23 |
+
__nv_bfloat162 *__restrict__ out_gate,
|
| 24 |
+
uint B,
|
| 25 |
+
uint H,
|
| 26 |
+
int M)
|
| 27 |
+
{
|
| 28 |
+
const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
|
| 29 |
+
const int out_offset = blockIdx.y * H * M/2 + blockIdx.z * TILE_H * M/2;
|
| 30 |
+
const int in_offset = blockIdx.y * H * 64 * 32 * gridDim.x + blockIdx.z * TILE_H * 64 * 32 * gridDim.x;
|
| 31 |
+
int idx;
|
| 32 |
+
int t_offset;
|
| 33 |
+
int out_t_offset;
|
| 34 |
+
int shared_offset;
|
| 35 |
+
const int N = 64;
|
| 36 |
+
|
| 37 |
+
extern __shared__ __nv_bfloat16 x_real_shared[];
|
| 38 |
+
__nv_bfloat16 *x_imag_shared = &x_real_shared[N * N];
|
| 39 |
+
__nv_bfloat16 *d_f_real_shared = &x_imag_shared[N * N];
|
| 40 |
+
__nv_bfloat16 *d_f_imag_shared = &d_f_real_shared[N * N];
|
| 41 |
+
__nv_bfloat16 *twiddles_real_shared = &d_f_imag_shared[N * N];
|
| 42 |
+
__nv_bfloat16 *twiddles_imag_shared = &twiddles_real_shared[N * N];
|
| 43 |
+
float *out_real_shared = reinterpret_cast<float*>(&twiddles_imag_shared[N * N]);
|
| 44 |
+
|
| 45 |
+
__nv_bfloat16 tmp_real, tmp_imag;
|
| 46 |
+
|
| 47 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[K][4];
|
| 48 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[K][4];
|
| 49 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[4];
|
| 50 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[4];
|
| 51 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_real[4];
|
| 52 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_imag[4];
|
| 53 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[K];
|
| 54 |
+
|
| 55 |
+
// #pragma unroll
|
| 56 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 57 |
+
{
|
| 58 |
+
idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 59 |
+
shared_offset = i * 32 + threadIdx.x;
|
| 60 |
+
reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[idx];
|
| 61 |
+
reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[idx];
|
| 62 |
+
|
| 63 |
+
// #pragma unroll
|
| 64 |
+
shared_offset = i * 32 + threadIdx.x;
|
| 65 |
+
reinterpret_cast<__nv_bfloat162 *>(d_f_real_shared)[shared_offset] = d_f_real[shared_offset];
|
| 66 |
+
reinterpret_cast<__nv_bfloat162 *>(d_f_imag_shared)[shared_offset] = d_f_imag[shared_offset];
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
__syncthreads();
|
| 70 |
+
|
| 71 |
+
for (int i = 0; i < 4; i++)
|
| 72 |
+
{
|
| 73 |
+
if(i < K){
|
| 74 |
+
#pragma unroll
|
| 75 |
+
for (int j = 0; j < 4; j++)
|
| 76 |
+
{
|
| 77 |
+
wmma::load_matrix_sync(a_frag_real[i][j], d_f_real_shared + j * N * 16 + i * 16, N);
|
| 78 |
+
wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag_shared + j * N * 16 + i * 16, N);
|
| 79 |
+
}
|
| 80 |
+
}
|
| 81 |
+
wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + i * N * 16 + threadIdx.y * 16, N);
|
| 82 |
+
wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + i * N * 16 + threadIdx.y * 16, N);
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
for (int t = 0; t < TILE_H; t++)
|
| 86 |
+
{
|
| 87 |
+
|
| 88 |
+
out_t_offset = t * M/2;
|
| 89 |
+
t_offset = t * 64 * 32 * gridDim.x;
|
| 90 |
+
|
| 91 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 92 |
+
{
|
| 93 |
+
idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 94 |
+
shared_offset = i * 32 + threadIdx.x;
|
| 95 |
+
reinterpret_cast<__nv_bfloat162 *>(x_real_shared)[shared_offset] = x_real[idx + in_offset + t_offset];
|
| 96 |
+
reinterpret_cast<__nv_bfloat162 *>(x_imag_shared)[shared_offset] = x_imag[idx + in_offset + t_offset];
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
__syncthreads();
|
| 100 |
+
|
| 101 |
+
for (int i = 0; i < 4; i++)
|
| 102 |
+
{
|
| 103 |
+
wmma::load_matrix_sync(b_frag_real[i], x_real_shared + i * N * 16 + threadIdx.y * 16, N);
|
| 104 |
+
wmma::load_matrix_sync(b_frag_imag[i], x_imag_shared + i * N * 16 + threadIdx.y * 16, N);
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
for (int j = 0; j < 4; j++)
|
| 108 |
+
{
|
| 109 |
+
for (int k = 0; k < tw_frag_real[j].num_elements; k++)
|
| 110 |
+
{
|
| 111 |
+
tmp_real = __hsub(__hmul(tw_frag_real[j].x[k], b_frag_real[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_imag[j].x[k]));
|
| 112 |
+
tmp_imag = __hadd(__hmul(tw_frag_real[j].x[k], b_frag_imag[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_real[j].x[k]));
|
| 113 |
+
b_frag_real[j].x[k] = tmp_real;
|
| 114 |
+
b_frag_imag[j].x[k] = tmp_imag;
|
| 115 |
+
}
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
for (int i = 0; i < K; i++)
|
| 119 |
+
{
|
| 120 |
+
wmma::fill_fragment(acc_frag_real[i], 0.0f);
|
| 121 |
+
|
| 122 |
+
// bd
|
| 123 |
+
#pragma unroll
|
| 124 |
+
for (int k = 0; k < 4; k++)
|
| 125 |
+
{
|
| 126 |
+
wmma::mma_sync(acc_frag_real[i], a_frag_imag[i][k], b_frag_imag[k], acc_frag_real[i]);
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
for (int k = 0; k < acc_frag_real[i].num_elements; k++)
|
| 130 |
+
{
|
| 131 |
+
acc_frag_real[i].x[k] = - acc_frag_real[i].x[k];
|
| 132 |
+
}
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
for (int i = 0; i < K; i++)
|
| 136 |
+
{
|
| 137 |
+
// ac - bd
|
| 138 |
+
#pragma unroll
|
| 139 |
+
for (int k = 0; k < 4; k++)
|
| 140 |
+
{
|
| 141 |
+
wmma::mma_sync(acc_frag_real[i], a_frag_real[i][k], b_frag_real[k], acc_frag_real[i]);
|
| 142 |
+
}
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
#pragma unroll
|
| 146 |
+
for (int i = 0; i < K; i++)
|
| 147 |
+
{
|
| 148 |
+
wmma::store_matrix_sync(out_real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
__syncthreads();
|
| 152 |
+
|
| 153 |
+
#pragma unroll
|
| 154 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 155 |
+
{
|
| 156 |
+
idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 157 |
+
shared_offset = i * 32 + threadIdx.x;
|
| 158 |
+
|
| 159 |
+
if(idx < max_idx){
|
| 160 |
+
if(out_gate != nullptr)
|
| 161 |
+
out_real[out_offset + out_t_offset + idx] = __hmul2(__float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[shared_offset]), out_gate[out_offset + out_t_offset + idx]);
|
| 162 |
+
else
|
| 163 |
+
out_real[out_offset + out_t_offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[shared_offset]);
|
| 164 |
+
}
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
__syncthreads();
|
| 168 |
+
}
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
template <int K>
|
| 173 |
+
__global__ void butterfly_ifft_padded_cuda_kernel_32(
|
| 174 |
+
const __nv_bfloat162 *__restrict__ x_real,
|
| 175 |
+
const __nv_bfloat162 *__restrict__ x_imag,
|
| 176 |
+
const __nv_bfloat16 *__restrict__ d_f_real,
|
| 177 |
+
const __nv_bfloat16 *__restrict__ d_f_imag,
|
| 178 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_real,
|
| 179 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
|
| 180 |
+
__nv_bfloat162 *__restrict__ out_real,
|
| 181 |
+
__nv_bfloat162 *__restrict__ out_gate,
|
| 182 |
+
uint B,
|
| 183 |
+
uint H,
|
| 184 |
+
int M)
|
| 185 |
+
{
|
| 186 |
+
const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
|
| 187 |
+
const int N = 32;
|
| 188 |
+
int idx;
|
| 189 |
+
int shared_offset;
|
| 190 |
+
|
| 191 |
+
const int out_offset = blockIdx.y * H * M / 2 + blockIdx.z * M / 2;
|
| 192 |
+
const int in_offset = blockIdx.y * H * 32 * 32 * gridDim.x + blockIdx.z * 32 * 32 * gridDim.x;
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
__shared__ __nv_bfloat16 x_real_shared[32 * 64];
|
| 196 |
+
__shared__ __nv_bfloat16 x_imag_shared[32 * 64];
|
| 197 |
+
__shared__ __nv_bfloat16 d_f_real_shared[32 * 32];
|
| 198 |
+
__shared__ __nv_bfloat16 d_f_imag_shared[32 * 32];
|
| 199 |
+
__shared__ __nv_bfloat16 twiddles_real_shared[32 * 64];
|
| 200 |
+
__shared__ __nv_bfloat16 twiddles_imag_shared[32 * 64];
|
| 201 |
+
__shared__ float out_real_shared[32 * 64];
|
| 202 |
+
|
| 203 |
+
// #pragma unroll
|
| 204 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 205 |
+
{
|
| 206 |
+
idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 207 |
+
int shared_offset = i * 32 + threadIdx.x;
|
| 208 |
+
|
| 209 |
+
reinterpret_cast<__nv_bfloat162 *>(x_real_shared)[shared_offset] = x_real[in_offset + idx];
|
| 210 |
+
reinterpret_cast<__nv_bfloat162 *>(x_imag_shared)[shared_offset] = x_imag[in_offset + idx];
|
| 211 |
+
reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[idx];
|
| 212 |
+
reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[idx];
|
| 213 |
+
|
| 214 |
+
// #pragma unroll
|
| 215 |
+
shared_offset = i * 32 + threadIdx.x;
|
| 216 |
+
d_f_real_shared[shared_offset] = d_f_real[shared_offset];
|
| 217 |
+
d_f_imag_shared[shared_offset] = d_f_imag[shared_offset];
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
__syncthreads();
|
| 221 |
+
|
| 222 |
+
if (threadIdx.y < N/16)
|
| 223 |
+
{
|
| 224 |
+
__nv_bfloat16 tmp_real, tmp_imag;
|
| 225 |
+
|
| 226 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real[K][2];
|
| 227 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag[K][2];
|
| 228 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[2][2];
|
| 229 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[2][2];
|
| 230 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_real[2][2];
|
| 231 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_imag[2][2];
|
| 232 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[K][2];
|
| 233 |
+
|
| 234 |
+
int t = threadIdx.y * 32;
|
| 235 |
+
|
| 236 |
+
for (int i = 0; i < 2; i++)
|
| 237 |
+
{
|
| 238 |
+
for (int j = 0; j < 2; j++)
|
| 239 |
+
{
|
| 240 |
+
if(i < K){
|
| 241 |
+
wmma::load_matrix_sync(a_frag_real[i][j], d_f_real_shared + j * N * 16 + i * 16, N);
|
| 242 |
+
wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag_shared + j * N * 16 + i * 16, N);
|
| 243 |
+
}
|
| 244 |
+
wmma::load_matrix_sync(b_frag_real[i][j], x_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 245 |
+
wmma::load_matrix_sync(b_frag_imag[i][j], x_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 246 |
+
wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 247 |
+
wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
|
| 248 |
+
}
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
for (int i = 0; i < 2; i++)
|
| 252 |
+
{
|
| 253 |
+
for (int j = 0; j < 2; j++)
|
| 254 |
+
{
|
| 255 |
+
for (int k = 0; k < tw_frag_real[i][j].num_elements; k++)
|
| 256 |
+
{
|
| 257 |
+
tmp_real = __hsub(__hmul(tw_frag_real[i][j].x[k], b_frag_real[i][j].x[k]), __hmul(tw_frag_imag[i][j].x[k], b_frag_imag[i][j].x[k]));
|
| 258 |
+
tmp_imag = __hadd(__hmul(tw_frag_real[i][j].x[k], b_frag_imag[i][j].x[k]), __hmul(tw_frag_imag[i][j].x[k], b_frag_real[i][j].x[k]));
|
| 259 |
+
b_frag_real[i][j].x[k] = tmp_real;
|
| 260 |
+
b_frag_imag[i][j].x[k] = tmp_imag;
|
| 261 |
+
}
|
| 262 |
+
}
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
for (int i = 0; i < K; i++)
|
| 266 |
+
{
|
| 267 |
+
for (int j = 0; j < 2; j++)
|
| 268 |
+
{
|
| 269 |
+
wmma::fill_fragment(acc_frag_real[i][j], 0.0f);
|
| 270 |
+
|
| 271 |
+
// bd
|
| 272 |
+
for (int k = 0; k < 2; k++)
|
| 273 |
+
{
|
| 274 |
+
wmma::mma_sync(acc_frag_real[i][j], a_frag_imag[i][k], b_frag_imag[k][j], acc_frag_real[i][j]);
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
for (int k = 0; k < acc_frag_real[i][j].num_elements; k++)
|
| 278 |
+
{
|
| 279 |
+
acc_frag_real[i][j].x[k] = - acc_frag_real[i][j].x[k];
|
| 280 |
+
}
|
| 281 |
+
}
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
for (int i = 0; i < K; i++)
|
| 285 |
+
{
|
| 286 |
+
for (int j = 0; j < 2; j++)
|
| 287 |
+
{
|
| 288 |
+
// ac - bd
|
| 289 |
+
for (int k = 0; k < 2; k++)
|
| 290 |
+
{
|
| 291 |
+
wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag_real[k][j], acc_frag_real[i][j]);
|
| 292 |
+
}
|
| 293 |
+
}
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
for (int i = 0; i < K; i++)
|
| 297 |
+
{
|
| 298 |
+
for (int j = 0; j < 2; j++)
|
| 299 |
+
{
|
| 300 |
+
wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
|
| 301 |
+
}
|
| 302 |
+
}
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
__syncthreads();
|
| 306 |
+
|
| 307 |
+
#pragma unroll
|
| 308 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 309 |
+
{
|
| 310 |
+
idx = i * 32 * gridDim.x + blockIdx.x * 32 + threadIdx.x;
|
| 311 |
+
shared_offset = i * 32 + threadIdx.x;
|
| 312 |
+
|
| 313 |
+
if(idx < max_idx){
|
| 314 |
+
if(out_gate != nullptr){
|
| 315 |
+
out_real[idx + out_offset] = __hmul2(__float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[shared_offset]), out_gate[idx + out_offset]);
|
| 316 |
+
}else{
|
| 317 |
+
out_real[idx + out_offset] = __float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[shared_offset]);
|
| 318 |
+
}
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
}
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
template <int TILE_H, int K>
|
| 326 |
+
__global__ void butterfly_ifft_padded_cuda_kernel_128(
|
| 327 |
+
const __nv_bfloat162 *__restrict__ x_real,
|
| 328 |
+
const __nv_bfloat162 *__restrict__ x_imag,
|
| 329 |
+
const __nv_bfloat162 *__restrict__ d_f_real,
|
| 330 |
+
const __nv_bfloat162 *__restrict__ d_f_imag,
|
| 331 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_real,
|
| 332 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
|
| 333 |
+
__nv_bfloat162 *__restrict__ out_real,
|
| 334 |
+
__nv_bfloat162 *__restrict__ out_gate,
|
| 335 |
+
uint B,
|
| 336 |
+
uint H,
|
| 337 |
+
int M)
|
| 338 |
+
{
|
| 339 |
+
const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
|
| 340 |
+
const int out_offset = blockIdx.y * H * M/2 + blockIdx.z * TILE_H * M/2;
|
| 341 |
+
const int in_offset = blockIdx.y * H * 128 * 32 * 2 * gridDim.x + blockIdx.z * TILE_H * 128 * 32 * 2 * gridDim.x;
|
| 342 |
+
const int N = 128;
|
| 343 |
+
int idx;
|
| 344 |
+
int t_offset;
|
| 345 |
+
int out_t_offset;
|
| 346 |
+
int shared_offset;
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
extern __shared__ __nv_bfloat16 real_shared[];
|
| 350 |
+
__nv_bfloat16 *imag_shared = &real_shared[128 * 128];
|
| 351 |
+
__nv_bfloat16 *real_shared_2 = &imag_shared[128 * 128];
|
| 352 |
+
__nv_bfloat16 *imag_shared_2 = &real_shared_2[128 * 128];
|
| 353 |
+
|
| 354 |
+
__nv_bfloat16 tmp_real, tmp_imag;
|
| 355 |
+
|
| 356 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag[K][8];
|
| 357 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real[8];
|
| 358 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag[8];
|
| 359 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_real[8];
|
| 360 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_imag[8];
|
| 361 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real[K];
|
| 362 |
+
|
| 363 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 364 |
+
{
|
| 365 |
+
for(int j=0; j< 2; j++){
|
| 366 |
+
shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
|
| 367 |
+
reinterpret_cast<__nv_bfloat162*>(real_shared_2)[shared_offset] = d_f_real[shared_offset];
|
| 368 |
+
reinterpret_cast<__nv_bfloat162*>(imag_shared_2)[shared_offset] = d_f_imag[shared_offset];
|
| 369 |
+
}
|
| 370 |
+
}
|
| 371 |
+
|
| 372 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 373 |
+
{
|
| 374 |
+
for(int j=0; j< 2; j++){
|
| 375 |
+
idx = i * 32 * 2 * gridDim.x + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
|
| 376 |
+
shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
|
| 377 |
+
reinterpret_cast<__nv_bfloat162*>(real_shared)[shared_offset] = twiddle_factors_real[idx];
|
| 378 |
+
reinterpret_cast<__nv_bfloat162*>(imag_shared)[shared_offset] = twiddle_factors_imag[idx];
|
| 379 |
+
}
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
__syncthreads();
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
for (int i = 0; i < 8; i++){
|
| 386 |
+
wmma::load_matrix_sync(tw_frag_real[i], real_shared + i * 128 * 16 + threadIdx.y * 16, 128);
|
| 387 |
+
wmma::load_matrix_sync(tw_frag_imag[i], imag_shared + i * 128 * 16 + threadIdx.y * 16, 128);
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
for (int t = 0; t < TILE_H; t++)
|
| 392 |
+
{
|
| 393 |
+
|
| 394 |
+
out_t_offset = t * M/2;
|
| 395 |
+
t_offset = t * 128 * 32 * 2 * gridDim.x;
|
| 396 |
+
|
| 397 |
+
for (int i = 0; i < K; i++){
|
| 398 |
+
for (int j = 0; j < 8; j++){
|
| 399 |
+
wmma::load_matrix_sync(a_frag[i][j], imag_shared_2 + j * 128 * 16 + i * 16, 128);
|
| 400 |
+
}
|
| 401 |
+
}
|
| 402 |
+
|
| 403 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 404 |
+
{
|
| 405 |
+
for(int j=0; j< 2; j++){
|
| 406 |
+
idx = i * 32 * 2 * gridDim.x + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
|
| 407 |
+
shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
|
| 408 |
+
reinterpret_cast<__nv_bfloat162*>(real_shared)[shared_offset] = x_real[idx + in_offset + t_offset];
|
| 409 |
+
reinterpret_cast<__nv_bfloat162*>(imag_shared)[shared_offset] = x_imag[idx + in_offset + t_offset];
|
| 410 |
+
}
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
__syncthreads();
|
| 414 |
+
|
| 415 |
+
for (int i = 0; i < 8; i++)
|
| 416 |
+
{
|
| 417 |
+
wmma::load_matrix_sync(b_frag_real[i], real_shared + i * N * 16 + threadIdx.y * 16, N);
|
| 418 |
+
wmma::load_matrix_sync(b_frag_imag[i], imag_shared + i * N * 16 + threadIdx.y * 16, N);
|
| 419 |
+
}
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
__syncthreads();
|
| 423 |
+
|
| 424 |
+
for (int j = 0; j < 8; j++)
|
| 425 |
+
{
|
| 426 |
+
for (int k = 0; k < tw_frag_real[j].num_elements; k++)
|
| 427 |
+
{
|
| 428 |
+
tmp_real = __hsub(__hmul(tw_frag_real[j].x[k], b_frag_real[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_imag[j].x[k]));
|
| 429 |
+
tmp_imag = __hadd(__hmul(tw_frag_real[j].x[k], b_frag_imag[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_real[j].x[k]));
|
| 430 |
+
b_frag_real[j].x[k] = tmp_real;
|
| 431 |
+
b_frag_imag[j].x[k] = tmp_imag;
|
| 432 |
+
}
|
| 433 |
+
}
|
| 434 |
+
|
| 435 |
+
__syncthreads();
|
| 436 |
+
|
| 437 |
+
for (int i = 0; i < K; i++)
|
| 438 |
+
{
|
| 439 |
+
wmma::fill_fragment(acc_frag_real[i], 0.0f);
|
| 440 |
+
|
| 441 |
+
// bd
|
| 442 |
+
#pragma unroll
|
| 443 |
+
for (int k = 0; k < 8; k++)
|
| 444 |
+
{
|
| 445 |
+
wmma::mma_sync(acc_frag_real[i], a_frag[i][k], b_frag_imag[k], acc_frag_real[i]);
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
for (int k = 0; k < acc_frag_real[i].num_elements; k++)
|
| 449 |
+
{
|
| 450 |
+
acc_frag_real[i].x[k] = -acc_frag_real[i].x[k];
|
| 451 |
+
}
|
| 452 |
+
}
|
| 453 |
+
|
| 454 |
+
for (int i = 0; i < K; i++){
|
| 455 |
+
for (int j = 0; j < 8; j++){
|
| 456 |
+
wmma::load_matrix_sync(a_frag[i][j], real_shared_2 + j * 128 * 16 + i * 16, 128);
|
| 457 |
+
}
|
| 458 |
+
}
|
| 459 |
+
|
| 460 |
+
for (int i = 0; i < K; i++)
|
| 461 |
+
{
|
| 462 |
+
// ac - bd
|
| 463 |
+
#pragma unroll
|
| 464 |
+
for (int k = 0; k < 8; k++)
|
| 465 |
+
{
|
| 466 |
+
wmma::mma_sync(acc_frag_real[i], a_frag[i][k], b_frag_real[k], acc_frag_real[i]);
|
| 467 |
+
}
|
| 468 |
+
}
|
| 469 |
+
|
| 470 |
+
__syncthreads();
|
| 471 |
+
|
| 472 |
+
#pragma unroll
|
| 473 |
+
for (int i = 0; i < K; i++)
|
| 474 |
+
{
|
| 475 |
+
//wmma::store_matrix_sync(real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
|
| 476 |
+
wmma::store_matrix_sync(reinterpret_cast<float*>(real_shared) + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
|
| 477 |
+
}
|
| 478 |
+
|
| 479 |
+
__syncthreads();
|
| 480 |
+
|
| 481 |
+
#pragma unroll
|
| 482 |
+
for (int i = threadIdx.y; i < N; i+=blockDim.y)
|
| 483 |
+
{
|
| 484 |
+
for(int j=0; j< 2; j++){
|
| 485 |
+
idx = i * 32 * 2 * gridDim.x + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
|
| 486 |
+
shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
|
| 487 |
+
if(idx < max_idx){
|
| 488 |
+
if(out_gate != nullptr){
|
| 489 |
+
out_real[idx + out_offset + out_t_offset] = __hmul2(__float22bfloat162_rn(reinterpret_cast<float2*>(real_shared)[shared_offset]), out_gate[idx + out_offset + out_t_offset]);
|
| 490 |
+
}else{
|
| 491 |
+
out_real[idx + out_offset + out_t_offset] = __float22bfloat162_rn(reinterpret_cast<float2*>(real_shared)[shared_offset]);
|
| 492 |
+
}
|
| 493 |
+
}
|
| 494 |
+
}
|
| 495 |
+
}
|
| 496 |
+
|
| 497 |
+
__syncthreads();
|
| 498 |
+
}
|
| 499 |
+
}
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
__global__ void butterfly_ifft_padded_cuda_kernel_16(
|
| 503 |
+
const __nv_bfloat162 *__restrict__ x_real,
|
| 504 |
+
const __nv_bfloat162 *__restrict__ x_imag,
|
| 505 |
+
const __nv_bfloat16 *__restrict__ d_f_real,
|
| 506 |
+
const __nv_bfloat16 *__restrict__ d_f_imag,
|
| 507 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_real,
|
| 508 |
+
const __nv_bfloat162 *__restrict__ twiddle_factors_imag,
|
| 509 |
+
__nv_bfloat162 *__restrict__ out_real,
|
| 510 |
+
__nv_bfloat162 *__restrict__ out_gate,
|
| 511 |
+
uint B,
|
| 512 |
+
uint H,
|
| 513 |
+
int M)
|
| 514 |
+
{
|
| 515 |
+
const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
|
| 516 |
+
const int N = 16;
|
| 517 |
+
const int out_offset = blockIdx.y * H * M / 2 + blockIdx.z * M / 2;
|
| 518 |
+
const int offset = blockIdx.y * H * N * blockDim.x * gridDim.x + blockIdx.z * N * blockDim.x * gridDim.x;
|
| 519 |
+
|
| 520 |
+
__shared__ __nv_bfloat16 x_real_shared[N * 64];
|
| 521 |
+
__shared__ __nv_bfloat16 x_imag_shared[N * 64];
|
| 522 |
+
__shared__ __nv_bfloat16 twiddles_real_shared[N * 64];
|
| 523 |
+
__shared__ __nv_bfloat16 twiddles_imag_shared[N * 64];
|
| 524 |
+
__shared__ float out_real_shared[N * 64];
|
| 525 |
+
|
| 526 |
+
// #pragma unroll
|
| 527 |
+
for (int i = threadIdx.y; i < N; i++)
|
| 528 |
+
{
|
| 529 |
+
int idx = i * blockDim.x * gridDim.x + blockIdx.x * blockDim.x + threadIdx.x;
|
| 530 |
+
int shared_offset = i * blockDim.x + threadIdx.x;
|
| 531 |
+
reinterpret_cast<__nv_bfloat162 *>(x_real_shared)[shared_offset] = x_real[idx + offset];
|
| 532 |
+
reinterpret_cast<__nv_bfloat162 *>(x_imag_shared)[shared_offset] = x_imag[idx + offset];
|
| 533 |
+
reinterpret_cast<__nv_bfloat162 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[idx];
|
| 534 |
+
reinterpret_cast<__nv_bfloat162 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[idx];
|
| 535 |
+
}
|
| 536 |
+
|
| 537 |
+
__syncthreads();
|
| 538 |
+
|
| 539 |
+
if (threadIdx.y < 4)
|
| 540 |
+
{
|
| 541 |
+
__nv_bfloat16 tmp_real, tmp_imag;
|
| 542 |
+
|
| 543 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_real;
|
| 544 |
+
wmma::fragment<wmma::matrix_a, 16, 16, 16, __nv_bfloat16, wmma::col_major> a_frag_imag;
|
| 545 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_real;
|
| 546 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> tw_frag_imag;
|
| 547 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_real;
|
| 548 |
+
wmma::fragment<wmma::matrix_b, 16, 16, 16, __nv_bfloat16, wmma::row_major> b_frag_imag;
|
| 549 |
+
wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag_real;
|
| 550 |
+
|
| 551 |
+
wmma::load_matrix_sync(a_frag_real, d_f_real, N);
|
| 552 |
+
wmma::load_matrix_sync(a_frag_imag, d_f_imag, N);
|
| 553 |
+
wmma::load_matrix_sync(b_frag_real, x_real_shared + threadIdx.y * 16, 64);
|
| 554 |
+
wmma::load_matrix_sync(b_frag_imag, x_imag_shared + threadIdx.y * 16, 64);
|
| 555 |
+
wmma::load_matrix_sync(tw_frag_real, twiddles_real_shared + threadIdx.y * 16, 64);
|
| 556 |
+
wmma::load_matrix_sync(tw_frag_imag, twiddles_imag_shared + threadIdx.y * 16, 64);
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
for (int k = 0; k < tw_frag_real.num_elements; k++)
|
| 560 |
+
{
|
| 561 |
+
tmp_real = __hsub(__hmul(tw_frag_real.x[k], b_frag_real.x[k]), __hmul(tw_frag_imag.x[k], b_frag_imag.x[k]));
|
| 562 |
+
tmp_imag = __hadd(__hmul(tw_frag_real.x[k], b_frag_imag.x[k]), __hmul(tw_frag_imag.x[k], b_frag_real.x[k]));
|
| 563 |
+
b_frag_real.x[k] = tmp_real;
|
| 564 |
+
b_frag_imag.x[k] = tmp_imag;
|
| 565 |
+
}
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
wmma::fill_fragment(acc_frag_real, 0.0f);
|
| 570 |
+
|
| 571 |
+
wmma::mma_sync(acc_frag_real, a_frag_imag, b_frag_imag, acc_frag_real);
|
| 572 |
+
|
| 573 |
+
for(int k=0; k< acc_frag_real.num_elements; k++){
|
| 574 |
+
acc_frag_real.x[k] = - acc_frag_real.x[k];
|
| 575 |
+
}
|
| 576 |
+
|
| 577 |
+
wmma::mma_sync(acc_frag_real, a_frag_real, b_frag_real, acc_frag_real);
|
| 578 |
+
|
| 579 |
+
wmma::store_matrix_sync(out_real_shared + threadIdx.y * 16, acc_frag_real, 64, wmma::mem_row_major);
|
| 580 |
+
|
| 581 |
+
}
|
| 582 |
+
|
| 583 |
+
__syncthreads();
|
| 584 |
+
|
| 585 |
+
#pragma unroll
|
| 586 |
+
for (int i = threadIdx.y; i < N; i++)
|
| 587 |
+
{
|
| 588 |
+
int idx = i * blockDim.x * gridDim.x + blockIdx.x * blockDim.x + threadIdx.x;
|
| 589 |
+
if(idx < max_idx){
|
| 590 |
+
if(out_gate != nullptr){
|
| 591 |
+
out_real[out_offset + idx] = __hmul2(__float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[i * 32 + threadIdx.x]), out_gate[out_offset + idx]);
|
| 592 |
+
}else{
|
| 593 |
+
out_real[out_offset + idx] = __float22bfloat162_rn(reinterpret_cast<float2 *>(out_real_shared)[i * 32 + threadIdx.x]);
|
| 594 |
+
}
|
| 595 |
+
}
|
| 596 |
+
}
|
| 597 |
+
}
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
torch::Tensor butterfly_ifft_padded_bf16_cuda(
|
| 601 |
+
torch::Tensor x_real,
|
| 602 |
+
torch::Tensor x_imag,
|
| 603 |
+
torch::Tensor d_f_real,
|
| 604 |
+
torch::Tensor d_f_imag,
|
| 605 |
+
torch::Tensor twiddle_factors_real,
|
| 606 |
+
torch::Tensor twiddle_factors_imag,
|
| 607 |
+
int fft_size,
|
| 608 |
+
std::optional<at::Tensor> out_gate = std::nullopt
|
| 609 |
+
)
|
| 610 |
+
{
|
| 611 |
+
|
| 612 |
+
uint B = x_real.size(0);
|
| 613 |
+
uint H = x_real.size(1);
|
| 614 |
+
uint N_M = x_real.size(2);
|
| 615 |
+
const int d_f_size = d_f_real.size(0);
|
| 616 |
+
// const int TILE_SIZE = 16;
|
| 617 |
+
|
| 618 |
+
dim3 gridDim;
|
| 619 |
+
dim3 blockDim;
|
| 620 |
+
|
| 621 |
+
// uint N = x_real.size(2);
|
| 622 |
+
gridDim.y = B;
|
| 623 |
+
|
| 624 |
+
blockDim.x = 32;
|
| 625 |
+
blockDim.y = 4;
|
| 626 |
+
gridDim.x = 512 / (32 * 1024/ (N_M / d_f_size));
|
| 627 |
+
gridDim.z = H;
|
| 628 |
+
|
| 629 |
+
const int TILE_H = 16;
|
| 630 |
+
torch::Tensor out_real = torch::empty({B, H, fft_size}, x_real.options());
|
| 631 |
+
const int K = ceil(fft_size / (1.0 * 16 * (N_M / d_f_size)));
|
| 632 |
+
|
| 633 |
+
switch(d_f_size){
|
| 634 |
+
case 16:
|
| 635 |
+
butterfly_ifft_padded_cuda_kernel_16<<<gridDim, blockDim>>>(
|
| 636 |
+
static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
|
| 637 |
+
static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
|
| 638 |
+
static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
|
| 639 |
+
static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
|
| 640 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 641 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 642 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 643 |
+
out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
|
| 644 |
+
B,
|
| 645 |
+
H,
|
| 646 |
+
fft_size
|
| 647 |
+
);
|
| 648 |
+
break;
|
| 649 |
+
case 32:
|
| 650 |
+
switch (K)
|
| 651 |
+
{
|
| 652 |
+
case 1:
|
| 653 |
+
butterfly_ifft_padded_cuda_kernel_32<1><<<gridDim, blockDim>>>(
|
| 654 |
+
static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
|
| 655 |
+
static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
|
| 656 |
+
static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
|
| 657 |
+
static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
|
| 658 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 659 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 660 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 661 |
+
out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
|
| 662 |
+
B,
|
| 663 |
+
H,
|
| 664 |
+
fft_size
|
| 665 |
+
);
|
| 666 |
+
break;
|
| 667 |
+
case 2:
|
| 668 |
+
butterfly_ifft_padded_cuda_kernel_32<2><<<gridDim, blockDim>>>(
|
| 669 |
+
static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
|
| 670 |
+
static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
|
| 671 |
+
static_cast<__nv_bfloat16 *>(d_f_real.data_ptr()),
|
| 672 |
+
static_cast<__nv_bfloat16 *>(d_f_imag.data_ptr()),
|
| 673 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 674 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 675 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 676 |
+
out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
|
| 677 |
+
B,
|
| 678 |
+
H,
|
| 679 |
+
fft_size
|
| 680 |
+
);
|
| 681 |
+
break;
|
| 682 |
+
default:
|
| 683 |
+
printf("Invalid K: %d\n", K);
|
| 684 |
+
break;
|
| 685 |
+
}
|
| 686 |
+
break;
|
| 687 |
+
|
| 688 |
+
case 64:
|
| 689 |
+
gridDim.z = H / TILE_H;
|
| 690 |
+
switch (K)
|
| 691 |
+
{
|
| 692 |
+
case 1:
|
| 693 |
+
cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_64<TILE_H, 1>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 694 |
+
butterfly_ifft_padded_cuda_kernel_64<TILE_H, 1><<<gridDim, blockDim, 65536>>>(
|
| 695 |
+
static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
|
| 696 |
+
static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
|
| 697 |
+
static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
|
| 698 |
+
static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
|
| 699 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 700 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 701 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 702 |
+
out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
|
| 703 |
+
B,
|
| 704 |
+
H,
|
| 705 |
+
fft_size);
|
| 706 |
+
break;
|
| 707 |
+
|
| 708 |
+
case 2:
|
| 709 |
+
cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_64<TILE_H, 2>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 710 |
+
butterfly_ifft_padded_cuda_kernel_64<TILE_H, 2><<<gridDim, blockDim, 65536>>>(
|
| 711 |
+
static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
|
| 712 |
+
static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
|
| 713 |
+
static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
|
| 714 |
+
static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
|
| 715 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 716 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 717 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 718 |
+
out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
|
| 719 |
+
B,
|
| 720 |
+
H,
|
| 721 |
+
fft_size);
|
| 722 |
+
break;
|
| 723 |
+
|
| 724 |
+
case 3:
|
| 725 |
+
cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_64<TILE_H, 3>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 726 |
+
butterfly_ifft_padded_cuda_kernel_64<TILE_H, 3><<<gridDim, blockDim, 65536>>>(
|
| 727 |
+
static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
|
| 728 |
+
static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
|
| 729 |
+
static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
|
| 730 |
+
static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
|
| 731 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 732 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 733 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 734 |
+
out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
|
| 735 |
+
B,
|
| 736 |
+
H,
|
| 737 |
+
fft_size);
|
| 738 |
+
break;
|
| 739 |
+
|
| 740 |
+
case 4:
|
| 741 |
+
cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_64<TILE_H, 4>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536);
|
| 742 |
+
butterfly_ifft_padded_cuda_kernel_64<TILE_H, 4><<<gridDim, blockDim, 65536>>>(
|
| 743 |
+
static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
|
| 744 |
+
static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
|
| 745 |
+
static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
|
| 746 |
+
static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
|
| 747 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 748 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 749 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 750 |
+
out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
|
| 751 |
+
B,
|
| 752 |
+
H,
|
| 753 |
+
fft_size);
|
| 754 |
+
break;
|
| 755 |
+
|
| 756 |
+
default:
|
| 757 |
+
break;
|
| 758 |
+
}
|
| 759 |
+
|
| 760 |
+
break;
|
| 761 |
+
case 128:
|
| 762 |
+
blockDim.x = 32;
|
| 763 |
+
blockDim.y = 8;
|
| 764 |
+
gridDim.x = 256 / (32 * 1024/ (N_M / d_f_size));
|
| 765 |
+
gridDim.z = H / TILE_H;
|
| 766 |
+
|
| 767 |
+
switch (K)
|
| 768 |
+
{
|
| 769 |
+
case 1:
|
| 770 |
+
cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 1>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
|
| 771 |
+
|
| 772 |
+
butterfly_ifft_padded_cuda_kernel_128<TILE_H, 1><<<gridDim, blockDim, 65536 * 2>>>(
|
| 773 |
+
static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
|
| 774 |
+
static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
|
| 775 |
+
static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
|
| 776 |
+
static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
|
| 777 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 778 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 779 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 780 |
+
out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
|
| 781 |
+
B,
|
| 782 |
+
H,
|
| 783 |
+
fft_size);
|
| 784 |
+
break;
|
| 785 |
+
|
| 786 |
+
case 2:
|
| 787 |
+
cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 2>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
|
| 788 |
+
|
| 789 |
+
butterfly_ifft_padded_cuda_kernel_128<TILE_H, 2><<<gridDim, blockDim, 65536 * 2>>>(
|
| 790 |
+
static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
|
| 791 |
+
static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
|
| 792 |
+
static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
|
| 793 |
+
static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
|
| 794 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 795 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 796 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 797 |
+
out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
|
| 798 |
+
B,
|
| 799 |
+
H,
|
| 800 |
+
fft_size);
|
| 801 |
+
break;
|
| 802 |
+
|
| 803 |
+
case 3:
|
| 804 |
+
cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 3>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
|
| 805 |
+
|
| 806 |
+
butterfly_ifft_padded_cuda_kernel_128<TILE_H, 3><<<gridDim, blockDim, 65536 * 2>>>(
|
| 807 |
+
static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
|
| 808 |
+
static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
|
| 809 |
+
static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
|
| 810 |
+
static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
|
| 811 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 812 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 813 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 814 |
+
out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
|
| 815 |
+
B,
|
| 816 |
+
H,
|
| 817 |
+
fft_size);
|
| 818 |
+
break;
|
| 819 |
+
|
| 820 |
+
case 4:
|
| 821 |
+
cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 4>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
|
| 822 |
+
|
| 823 |
+
butterfly_ifft_padded_cuda_kernel_128<TILE_H, 4><<<gridDim, blockDim, 65536 * 2>>>(
|
| 824 |
+
static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
|
| 825 |
+
static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
|
| 826 |
+
static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
|
| 827 |
+
static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
|
| 828 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 829 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 830 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 831 |
+
out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
|
| 832 |
+
B,
|
| 833 |
+
H,
|
| 834 |
+
fft_size);
|
| 835 |
+
break;
|
| 836 |
+
|
| 837 |
+
case 5:
|
| 838 |
+
cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 5>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
|
| 839 |
+
|
| 840 |
+
butterfly_ifft_padded_cuda_kernel_128<TILE_H, 5><<<gridDim, blockDim, 65536 * 2>>>(
|
| 841 |
+
static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
|
| 842 |
+
static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
|
| 843 |
+
static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
|
| 844 |
+
static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
|
| 845 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 846 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 847 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 848 |
+
out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
|
| 849 |
+
B,
|
| 850 |
+
H,
|
| 851 |
+
fft_size);
|
| 852 |
+
break;
|
| 853 |
+
|
| 854 |
+
case 6:
|
| 855 |
+
cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 6>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
|
| 856 |
+
|
| 857 |
+
butterfly_ifft_padded_cuda_kernel_128<TILE_H, 6><<<gridDim, blockDim, 65536 * 2>>>(
|
| 858 |
+
static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
|
| 859 |
+
static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
|
| 860 |
+
static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
|
| 861 |
+
static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
|
| 862 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 863 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 864 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 865 |
+
out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
|
| 866 |
+
B,
|
| 867 |
+
H,
|
| 868 |
+
fft_size);
|
| 869 |
+
break;
|
| 870 |
+
|
| 871 |
+
case 7:
|
| 872 |
+
cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 7>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
|
| 873 |
+
|
| 874 |
+
butterfly_ifft_padded_cuda_kernel_128<TILE_H, 7><<<gridDim, blockDim, 65536 * 2>>>(
|
| 875 |
+
static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
|
| 876 |
+
static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
|
| 877 |
+
static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
|
| 878 |
+
static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
|
| 879 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 880 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 881 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 882 |
+
out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
|
| 883 |
+
B,
|
| 884 |
+
H,
|
| 885 |
+
fft_size);
|
| 886 |
+
break;
|
| 887 |
+
|
| 888 |
+
case 8:
|
| 889 |
+
cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536 * 2);
|
| 890 |
+
|
| 891 |
+
butterfly_ifft_padded_cuda_kernel_128<TILE_H, 8><<<gridDim, blockDim, 65536 * 2>>>(
|
| 892 |
+
static_cast<__nv_bfloat162 *>(x_real.data_ptr()),
|
| 893 |
+
static_cast<__nv_bfloat162 *>(x_imag.data_ptr()),
|
| 894 |
+
static_cast<__nv_bfloat162 *>(d_f_real.data_ptr()),
|
| 895 |
+
static_cast<__nv_bfloat162 *>(d_f_imag.data_ptr()),
|
| 896 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_real.data_ptr()),
|
| 897 |
+
static_cast<__nv_bfloat162 *>(twiddle_factors_imag.data_ptr()),
|
| 898 |
+
static_cast<__nv_bfloat162 *>(out_real.data_ptr()),
|
| 899 |
+
out_gate ? static_cast<__nv_bfloat162 *>(out_gate.value().data_ptr()) : nullptr,
|
| 900 |
+
B,
|
| 901 |
+
H,
|
| 902 |
+
fft_size);
|
| 903 |
+
break;
|
| 904 |
+
|
| 905 |
+
default:
|
| 906 |
+
printf("Invalid K: %d\n", K);
|
| 907 |
+
break;
|
| 908 |
+
}
|
| 909 |
+
break;
|
| 910 |
+
|
| 911 |
+
default:
|
| 912 |
+
printf("Invalid d_f_size: %d\n", d_f_size);
|
| 913 |
+
break;
|
| 914 |
+
}
|
| 915 |
+
|
| 916 |
+
return out_real;
|
| 917 |
+
}
|
overlay/kernels/cuda/flashfftconv/csrc/butterfly/shared.h
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) 2023 Dan Fu, Hermann Kumbong
|
| 2 |
+
|
| 3 |
+
#include <torch/extension.h>
|
| 4 |
+
|
| 5 |
+
#include <vector>
|
| 6 |
+
#include <stdio.h>
|
| 7 |
+
#include <mma.h>
|
| 8 |
+
#include <cuda_fp16.h>
|
| 9 |
+
#include <cub/block/block_load.cuh>
|
| 10 |
+
#include <cub/block/block_store.cuh>
|
| 11 |
+
using namespace nvcuda;
|
| 12 |
+
|
| 13 |
+
using complex_half_t = typename c10::complex<at::Half>;
|
| 14 |
+
using complex_bhalf_t = typename c10::complex<at::BFloat16>;
|
| 15 |
+
|
| 16 |
+
#define WMMA_M 16
|
| 17 |
+
#define WMMA_N 16
|
| 18 |
+
#define WMMA_K 16
|
| 19 |
+
#define WARP_SIZE 32
|
| 20 |
+
|
| 21 |
+
#ifndef MONARCH_CUDA_H_
|
| 22 |
+
#define MONARCH_CUDA_H_
|
| 23 |
+
|
| 24 |
+
__device__ __forceinline__ float2
|
| 25 |
+
|
| 26 |
+
operator+( float2 lhs, float2 rhs)
|
| 27 |
+
|
| 28 |
+
{
|
| 29 |
+
|
| 30 |
+
float2 res = { lhs.x + rhs.x , lhs.y + rhs.y };
|
| 31 |
+
|
| 32 |
+
return res;
|
| 33 |
+
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
__device__ __forceinline__ float2
|
| 38 |
+
|
| 39 |
+
operator-( float2 lhs, float2 rhs)
|
| 40 |
+
|
| 41 |
+
{
|
| 42 |
+
|
| 43 |
+
float2 res = { lhs.x - rhs.x , lhs.y - rhs.y };
|
| 44 |
+
|
| 45 |
+
return res;
|
| 46 |
+
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
__device__ __forceinline__ float2
|
| 50 |
+
|
| 51 |
+
operator*( float2 lhs, float2 rhs)
|
| 52 |
+
|
| 53 |
+
{
|
| 54 |
+
|
| 55 |
+
float2 res = { lhs.x * rhs.x , lhs.y * rhs.y };
|
| 56 |
+
|
| 57 |
+
return res;
|
| 58 |
+
|
| 59 |
+
}
|
| 60 |
+
#endif
|
overlay/kernels/cuda/flashfftconv/csrc/conv1d/conv1d.h
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) 2023 Dan Fu, Hermann Kumbong
|
| 2 |
+
|
| 3 |
+
#include <torch/extension.h>
|
| 4 |
+
|
| 5 |
+
#include <vector>
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
| 9 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
| 10 |
+
#define CHECK_IS_HALF_OR_BFLOAT_OR_FLOAT(x) TORCH_CHECK(x.dtype() == torch::kFloat16 || x.dtype() == torch::kBFloat16 || x.dtype() == torch::kFloat32, #x " must be float16 or bfloat16 or float32")
|
| 11 |
+
#define CHECK_SAME_TYPE(x, y) TORCH_CHECK(x.dtype() == y.dtype(), #x " and " #y " must have the same dtype")
|
| 12 |
+
|
| 13 |
+
#define CHECK_INPUT(x) \
|
| 14 |
+
CHECK_CUDA(x); \
|
| 15 |
+
CHECK_CONTIGUOUS(x); \
|
| 16 |
+
CHECK_IS_HALF_OR_BFLOAT_OR_FLOAT(x)
|
| 17 |
+
|
| 18 |
+
torch::Tensor conv1d_cuda_bhl(
|
| 19 |
+
torch::Tensor u,
|
| 20 |
+
torch::Tensor weight,
|
| 21 |
+
torch::Tensor bias,
|
| 22 |
+
uint padding);
|
| 23 |
+
|
| 24 |
+
torch::Tensor conv1d_cuda_blh(
|
| 25 |
+
torch::Tensor u,
|
| 26 |
+
torch::Tensor weight,
|
| 27 |
+
torch::Tensor bias,
|
| 28 |
+
uint padding);
|
| 29 |
+
|
| 30 |
+
std::vector<torch::Tensor> conv1d_backward_bhl_cuda(
|
| 31 |
+
torch::Tensor dout,
|
| 32 |
+
torch::Tensor input,
|
| 33 |
+
torch::Tensor weight,
|
| 34 |
+
torch::Tensor bias,
|
| 35 |
+
uint padding
|
| 36 |
+
);
|
| 37 |
+
|
| 38 |
+
std::vector<torch::Tensor> conv1d_backward_blh_cuda(
|
| 39 |
+
torch::Tensor dout,
|
| 40 |
+
torch::Tensor input,
|
| 41 |
+
torch::Tensor weight,
|
| 42 |
+
torch::Tensor bias,
|
| 43 |
+
uint padding
|
| 44 |
+
);
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
torch::Tensor conv1d_fwd(
|
| 48 |
+
torch::Tensor u,
|
| 49 |
+
torch::Tensor weight,
|
| 50 |
+
torch::Tensor bias,
|
| 51 |
+
uint padding,
|
| 52 |
+
bool is_bhl)
|
| 53 |
+
{
|
| 54 |
+
CHECK_INPUT(u);
|
| 55 |
+
CHECK_INPUT(weight);
|
| 56 |
+
CHECK_INPUT(bias);
|
| 57 |
+
CHECK_SAME_TYPE(weight, bias);
|
| 58 |
+
|
| 59 |
+
int k;
|
| 60 |
+
|
| 61 |
+
if(is_bhl){
|
| 62 |
+
k = weight.size(1);
|
| 63 |
+
}else{
|
| 64 |
+
k = weight.size(0);
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
TORCH_CHECK(k % 2 == 1, "Filter size must be odd number");
|
| 68 |
+
|
| 69 |
+
if(is_bhl){
|
| 70 |
+
return conv1d_cuda_bhl(u, weight, bias, padding);
|
| 71 |
+
}else{
|
| 72 |
+
return conv1d_cuda_blh(u, weight, bias, padding);
|
| 73 |
+
}
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
std::vector<torch::Tensor> conv1d_bwd(
|
| 77 |
+
torch::Tensor dout,
|
| 78 |
+
torch::Tensor input,
|
| 79 |
+
torch::Tensor weight,
|
| 80 |
+
torch::Tensor bias,
|
| 81 |
+
uint padding,
|
| 82 |
+
bool is_bhl)
|
| 83 |
+
{
|
| 84 |
+
CHECK_INPUT(dout);
|
| 85 |
+
CHECK_INPUT(input);
|
| 86 |
+
CHECK_INPUT(weight);
|
| 87 |
+
CHECK_INPUT(bias);
|
| 88 |
+
CHECK_SAME_TYPE(weight, bias);
|
| 89 |
+
CHECK_SAME_TYPE(dout, input);
|
| 90 |
+
|
| 91 |
+
if(is_bhl){
|
| 92 |
+
return conv1d_backward_bhl_cuda(dout, input, weight, bias, padding);
|
| 93 |
+
} else{
|
| 94 |
+
return conv1d_backward_blh_cuda(dout, input, weight, bias, padding);
|
| 95 |
+
}
|
| 96 |
+
}
|
overlay/kernels/cuda/flashfftconv/csrc/conv1d/conv1d_bhl.cu
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) 2023 Dan Fu, Hermann Kumbong
|
| 2 |
+
|
| 3 |
+
// Simple 1D depthwise convolution implementation with dilation and stride = 1
|
| 4 |
+
#include "shared.h"
|
| 5 |
+
|
| 6 |
+
const uint BX = 256;
|
| 7 |
+
const uint BY = 1;
|
| 8 |
+
const uint BZ = 1;
|
| 9 |
+
|
| 10 |
+
const uint TILE_SIZE_L = 4;
|
| 11 |
+
const uint TILE_SIZE_D = 1;
|
| 12 |
+
|
| 13 |
+
template<typename T, typename U>
|
| 14 |
+
__forceinline__ __device__ T _conv1d_k_3(const T* u, const U* weights, const U* bias, uint padding, uint l, uint d, uint L, uint D, uint K)
|
| 15 |
+
{
|
| 16 |
+
T tmp;
|
| 17 |
+
T weight;
|
| 18 |
+
|
| 19 |
+
set_value(&tmp, bias[d]);
|
| 20 |
+
|
| 21 |
+
int idx = l - padding;
|
| 22 |
+
|
| 23 |
+
if(idx >= 0 && idx < L){
|
| 24 |
+
set_value(&weight, weights[0]);
|
| 25 |
+
tmp = __hfma(u[d * L + idx], weight, tmp);
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
idx++;
|
| 29 |
+
if(idx >= 0 && idx < L){
|
| 30 |
+
set_value(&weight, weights[1]);
|
| 31 |
+
tmp = __hfma(u[d * L + idx], weight, tmp);
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
idx++;
|
| 35 |
+
if(idx >= 0 && idx < L){
|
| 36 |
+
set_value(&weight, weights[2]);
|
| 37 |
+
tmp = __hfma(u[d * L + idx], weight, tmp);
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
return tmp;
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
template<typename T, typename U>
|
| 44 |
+
__global__ void conv1d_kernel(
|
| 45 |
+
const T *__restrict__ u,
|
| 46 |
+
const U *__restrict__ weights,
|
| 47 |
+
const U *__restrict__ bias,
|
| 48 |
+
T *__restrict__ out,
|
| 49 |
+
uint padding,
|
| 50 |
+
uint B,
|
| 51 |
+
uint L,
|
| 52 |
+
uint D,
|
| 53 |
+
uint K,
|
| 54 |
+
uint L_out
|
| 55 |
+
)
|
| 56 |
+
{
|
| 57 |
+
const int b = blockIdx.z * blockDim.z + threadIdx.z;
|
| 58 |
+
const int d = blockIdx.y * blockDim.y * TILE_SIZE_D + threadIdx.y;
|
| 59 |
+
const int l_offset = blockIdx.x * blockDim.x * TILE_SIZE_L + threadIdx.x;
|
| 60 |
+
|
| 61 |
+
T tmp;
|
| 62 |
+
T weight;
|
| 63 |
+
|
| 64 |
+
int idx;
|
| 65 |
+
int l;
|
| 66 |
+
|
| 67 |
+
for(int l_tile = 0; l_tile < TILE_SIZE_L; l_tile++){
|
| 68 |
+
l = l_offset + l_tile * blockDim.x;
|
| 69 |
+
|
| 70 |
+
set_value(&tmp, bias[d]);
|
| 71 |
+
|
| 72 |
+
if(d < D && l < L_out && b < B){
|
| 73 |
+
if(K == 3){
|
| 74 |
+
out[b * L_out * D + d * L_out + l] = _conv1d_k_3(u + b * L * D, weights + d * K, bias, padding, l, d, L, D, K);
|
| 75 |
+
} else{
|
| 76 |
+
for(int k = 0; k < K; k++){
|
| 77 |
+
idx = l - padding + k;
|
| 78 |
+
if(idx >= 0 && idx < L){
|
| 79 |
+
set_value(&weight, weights[d * K + k]);
|
| 80 |
+
tmp = __hfma(u[b * L_out * D + d * L + idx], weight, tmp);
|
| 81 |
+
}
|
| 82 |
+
}
|
| 83 |
+
out[b * L_out * D + d * L_out + l] = tmp;
|
| 84 |
+
|
| 85 |
+
}
|
| 86 |
+
}
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
torch::Tensor conv1d_cuda_bhl(
|
| 92 |
+
torch::Tensor u,
|
| 93 |
+
torch::Tensor weight,
|
| 94 |
+
torch::Tensor bias,
|
| 95 |
+
uint padding)
|
| 96 |
+
{
|
| 97 |
+
const uint b = u.size(0);
|
| 98 |
+
const uint d = u.size(1);
|
| 99 |
+
const uint l = u.size(2);
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
const uint k = weight.size(1);
|
| 103 |
+
|
| 104 |
+
uint l_out = (l + 2 * padding - k + 1);
|
| 105 |
+
|
| 106 |
+
dim3 blockDims(BX, BY, BZ);
|
| 107 |
+
|
| 108 |
+
dim3 gridDims(ceil(l_out * 1.0 / (BX * TILE_SIZE_L) ), ceil((d * 1.0) / (BY * TILE_SIZE_D)), ceil((b * 1.0) / BZ));
|
| 109 |
+
|
| 110 |
+
torch::Tensor out = torch::empty({b, d, l_out}, u.options());
|
| 111 |
+
|
| 112 |
+
DISPATCH_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), weight.scalar_type(),
|
| 113 |
+
"depthwise conv 1d fwd bhl",
|
| 114 |
+
([&]
|
| 115 |
+
{ conv1d_kernel<input_t, weight_t><<<gridDims, blockDims>>>(
|
| 116 |
+
static_cast<input_t *>(u.data_ptr()),
|
| 117 |
+
static_cast<weight_t *>(weight.data_ptr()),
|
| 118 |
+
static_cast<weight_t *>(bias.data_ptr()),
|
| 119 |
+
static_cast<input_t *>(out.data_ptr()),
|
| 120 |
+
padding,
|
| 121 |
+
b,
|
| 122 |
+
l,
|
| 123 |
+
d,
|
| 124 |
+
k,
|
| 125 |
+
l_out
|
| 126 |
+
);
|
| 127 |
+
}
|
| 128 |
+
)
|
| 129 |
+
);
|
| 130 |
+
|
| 131 |
+
return out;
|
| 132 |
+
}
|
overlay/kernels/cuda/flashfftconv/csrc/conv1d/conv1d_blh.cu
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) 2023 Dan Fu, Hermann Kumbong
|
| 2 |
+
|
| 3 |
+
// Simple 1D depthwise convolution implementation with dilation and stride = 1
|
| 4 |
+
|
| 5 |
+
#include "shared.h"
|
| 6 |
+
|
| 7 |
+
//For max perf, tune for your GPU and batch size, and datatype etc
|
| 8 |
+
const uint BX = 512;
|
| 9 |
+
const uint BY = 1;
|
| 10 |
+
const uint BZ = 1;
|
| 11 |
+
|
| 12 |
+
const uint TILE_SIZE_Y = 4;
|
| 13 |
+
const uint TILE_SIZE_X = 2;
|
| 14 |
+
|
| 15 |
+
// Trick to do padding in place without actually creating a new tensor
|
| 16 |
+
__forceinline__ __device__ __half2 get_u(const __half2 *__restrict__ u, uint L_eff, uint l, uint p, uint b, uint k, uint d, uint L, uint D, uint K)
|
| 17 |
+
{
|
| 18 |
+
return l + k < p || l + k > L_eff - (p + 1) ? __float2half2_rn(0.0f) : u[b * L * D + (l + k - p) * D + d];
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
__forceinline__ __device__ __nv_bfloat162 get_u(const __nv_bfloat162 *__restrict__ u, uint L_eff, uint l, uint p, uint b, uint k, uint d, uint L, uint D, uint K)
|
| 23 |
+
{
|
| 24 |
+
return l + k < p || l + k > L_eff - (p + 1) ? __float2bfloat162_rn(0.0f) : u[b * L * D + (l + k - p) * D + d];
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
__forceinline__ __device__ float2 get_u(const float2 *__restrict__ u, uint L_eff, uint l, uint p, uint b, uint k, uint d, uint L, uint D, uint K)
|
| 28 |
+
{
|
| 29 |
+
return l + k < p || l + k > L_eff - (p + 1) ? make_float2(0.0f, 0.0f) : u[b * L * D + (l + k - p) * D + d];
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
//manually unrolling loop for k = 3 leads to good perf, can easily extend for other values of k if need be
|
| 34 |
+
template<typename T, typename U>
|
| 35 |
+
__forceinline__ __device__ T _conv1d_k_3(const T* u, const U* weights, const U* bias, T* out, uint padding, uint b, uint l, uint d, uint t, uint L, uint D, uint K, uint L_eff, uint L_out)
|
| 36 |
+
{
|
| 37 |
+
|
| 38 |
+
T tmp;
|
| 39 |
+
T weight;
|
| 40 |
+
set_value(&tmp, bias[d]);
|
| 41 |
+
|
| 42 |
+
set_value(&weight, weights[0 * D + d]);
|
| 43 |
+
tmp = __hfma2(get_u(u, L_eff, l + t, padding, b, 0, d, L, D, K), weight, tmp);
|
| 44 |
+
|
| 45 |
+
set_value(&weight, weights[1 * D + d]);
|
| 46 |
+
tmp = __hfma2(get_u(u, L_eff, l + t, padding, b, 1, d, L, D, K), weight, tmp);
|
| 47 |
+
|
| 48 |
+
set_value(&weight, weights[2 * D + d]);
|
| 49 |
+
out[b * D * L_out + (l + t) * D + d] = __hfma2(get_u(u, L_eff, l + t, padding, b, 2, d, L, D, K), weight, tmp);
|
| 50 |
+
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
template<typename T, typename U>
|
| 54 |
+
__global__ void conv1d_kernel_k_3(
|
| 55 |
+
const T *__restrict__ u,
|
| 56 |
+
const U *__restrict__ weights,
|
| 57 |
+
const U *__restrict__ bias,
|
| 58 |
+
T *__restrict__ out,
|
| 59 |
+
uint padding,
|
| 60 |
+
uint B,
|
| 61 |
+
uint L,
|
| 62 |
+
uint L_out,
|
| 63 |
+
uint L_eff,
|
| 64 |
+
uint D,
|
| 65 |
+
uint K)
|
| 66 |
+
{
|
| 67 |
+
const int d_block = blockIdx.x * blockDim.x * TILE_SIZE_X;
|
| 68 |
+
const int l = blockIdx.y * blockDim.y * TILE_SIZE_Y + threadIdx.y * TILE_SIZE_Y;
|
| 69 |
+
const int b = blockIdx.z * blockDim.z + threadIdx.z;
|
| 70 |
+
|
| 71 |
+
int d;
|
| 72 |
+
|
| 73 |
+
#pragma unroll
|
| 74 |
+
for (int i = 0; i < TILE_SIZE_X; i++)
|
| 75 |
+
{
|
| 76 |
+
d = d_block + threadIdx.x + i * BX;
|
| 77 |
+
|
| 78 |
+
if (d < D && b < B){
|
| 79 |
+
#pragma unroll
|
| 80 |
+
for (int t = 0; t < TILE_SIZE_Y; t++){
|
| 81 |
+
if (l + t < L_eff - K + 1)
|
| 82 |
+
{
|
| 83 |
+
_conv1d_k_3(u, weights, bias, out, padding, b, l, d, t, L, D, K, L_eff, L_out);
|
| 84 |
+
}
|
| 85 |
+
}
|
| 86 |
+
}
|
| 87 |
+
}
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
template<typename T, typename U>
|
| 91 |
+
__global__ void conv1d_kernel(
|
| 92 |
+
const T *__restrict__ u,
|
| 93 |
+
const U *__restrict__ weights,
|
| 94 |
+
const U *__restrict__ bias,
|
| 95 |
+
T *__restrict__ out,
|
| 96 |
+
uint padding,
|
| 97 |
+
uint B,
|
| 98 |
+
uint L,
|
| 99 |
+
uint L_out,
|
| 100 |
+
uint L_eff,
|
| 101 |
+
uint D,
|
| 102 |
+
uint K)
|
| 103 |
+
{
|
| 104 |
+
const int d_block = blockIdx.x * blockDim.x * TILE_SIZE_X;
|
| 105 |
+
const int l = blockIdx.y * blockDim.y * TILE_SIZE_Y + threadIdx.y * TILE_SIZE_Y;
|
| 106 |
+
const int b = blockIdx.z * blockDim.z + threadIdx.z;
|
| 107 |
+
|
| 108 |
+
int d;
|
| 109 |
+
T tmp;
|
| 110 |
+
T weight;
|
| 111 |
+
|
| 112 |
+
#pragma unroll
|
| 113 |
+
for (int i = 0; i < TILE_SIZE_X; i++)
|
| 114 |
+
{
|
| 115 |
+
d = d_block + threadIdx.x + i * BX;
|
| 116 |
+
|
| 117 |
+
if (d < D && b < B){
|
| 118 |
+
#pragma unroll
|
| 119 |
+
for (int t = 0; t < TILE_SIZE_Y; t++){
|
| 120 |
+
if (l + t < L_eff - K + 1)
|
| 121 |
+
{
|
| 122 |
+
set_value(&tmp, bias[d]);
|
| 123 |
+
|
| 124 |
+
for(int k = 0; k < K; k++){
|
| 125 |
+
set_value(&weight, weights[k * D + d]);
|
| 126 |
+
|
| 127 |
+
tmp = __hfma2(get_u(u, L_eff, l + t, padding, b, k, d, L, D, K), weight, tmp);
|
| 128 |
+
}
|
| 129 |
+
out[b * D * L_out + (l + t) * D + d] = tmp;
|
| 130 |
+
}
|
| 131 |
+
}
|
| 132 |
+
}
|
| 133 |
+
}
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
torch::Tensor conv1d_cuda_blh(
|
| 137 |
+
torch::Tensor u,
|
| 138 |
+
torch::Tensor weight,
|
| 139 |
+
torch::Tensor bias,
|
| 140 |
+
uint padding)
|
| 141 |
+
{
|
| 142 |
+
const uint b = u.size(0);
|
| 143 |
+
const uint l = u.size(1);
|
| 144 |
+
const uint d = u.size(2);
|
| 145 |
+
|
| 146 |
+
const uint k = weight.size(0);
|
| 147 |
+
|
| 148 |
+
uint l_eff = l + 2 * padding;
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
dim3 blockDims(BX, BY, BZ);
|
| 153 |
+
|
| 154 |
+
dim3 gridDims(ceil(d * 1.0 / (BX * TILE_SIZE_X * 2) ), ceil((l_eff - k + 1) * 1.0 / (BY * TILE_SIZE_Y)), ceil(b * 1.0 / BZ));
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
uint l_out = (l + 2 * padding - k + 1);
|
| 158 |
+
|
| 159 |
+
torch::Tensor out = torch::empty({b, l_out, d}, u.options());
|
| 160 |
+
|
| 161 |
+
//calling seperate kernels for k=3 and k!=3 leads to better perf
|
| 162 |
+
if(k==3){
|
| 163 |
+
DISPATCH_FLOAT2_AND_HALF2_AND_BF162(u.scalar_type(), weight.scalar_type(),
|
| 164 |
+
"depthwise conv 1d fwd blh",
|
| 165 |
+
([&]
|
| 166 |
+
{ conv1d_kernel_k_3<input_t, weight_t><<<gridDims, blockDims>>>(
|
| 167 |
+
static_cast<input_t *>(u.data_ptr()),
|
| 168 |
+
static_cast<weight_t *>(weight.data_ptr()),
|
| 169 |
+
static_cast<weight_t *>(bias.data_ptr()),
|
| 170 |
+
static_cast<input_t *>(out.data_ptr()),
|
| 171 |
+
padding,
|
| 172 |
+
b,
|
| 173 |
+
l,
|
| 174 |
+
l_out,
|
| 175 |
+
l_eff,
|
| 176 |
+
ceil(d/2),
|
| 177 |
+
k);
|
| 178 |
+
}
|
| 179 |
+
)
|
| 180 |
+
);
|
| 181 |
+
}else{
|
| 182 |
+
DISPATCH_FLOAT2_AND_HALF2_AND_BF162(u.scalar_type(), weight.scalar_type(),
|
| 183 |
+
"depthwise conv 1d fwd blh",
|
| 184 |
+
([&]
|
| 185 |
+
{ conv1d_kernel<input_t, weight_t><<<gridDims, blockDims>>>(
|
| 186 |
+
static_cast<input_t *>(u.data_ptr()),
|
| 187 |
+
static_cast<weight_t *>(weight.data_ptr()),
|
| 188 |
+
static_cast<weight_t *>(bias.data_ptr()),
|
| 189 |
+
static_cast<input_t *>(out.data_ptr()),
|
| 190 |
+
padding,
|
| 191 |
+
b,
|
| 192 |
+
l,
|
| 193 |
+
l_out,
|
| 194 |
+
l_eff,
|
| 195 |
+
ceil(d/2),
|
| 196 |
+
k);
|
| 197 |
+
}
|
| 198 |
+
)
|
| 199 |
+
);
|
| 200 |
+
}
|
| 201 |
+
return out;
|
| 202 |
+
}
|
overlay/kernels/cuda/flashfftconv/csrc/conv1d/conv1d_bwd_cuda_bhl.cu
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) 2023 Dan Fu, Hermann Kumbong
|
| 2 |
+
#include "shared.h"
|
| 3 |
+
|
| 4 |
+
const uint BX = 128;
|
| 5 |
+
const uint BY = 1;
|
| 6 |
+
const uint BZ = 1;
|
| 7 |
+
|
| 8 |
+
const uint TILE_SIZE = 4;
|
| 9 |
+
|
| 10 |
+
template <typename input_t, typename weight_t>
|
| 11 |
+
__global__ void conv1d_backward_kernel(
|
| 12 |
+
const input_t* __restrict__ dout,
|
| 13 |
+
const input_t* __restrict__ u,
|
| 14 |
+
const weight_t* __restrict__ weights,
|
| 15 |
+
input_t* __restrict__ du,
|
| 16 |
+
input_t* __restrict__ dk,
|
| 17 |
+
uint B,
|
| 18 |
+
uint L,
|
| 19 |
+
uint D,
|
| 20 |
+
uint K,
|
| 21 |
+
uint P
|
| 22 |
+
)
|
| 23 |
+
{
|
| 24 |
+
const int b = blockIdx.z;
|
| 25 |
+
const int d = blockIdx.y;
|
| 26 |
+
const int l = blockIdx.x;
|
| 27 |
+
|
| 28 |
+
//construct the du matrix
|
| 29 |
+
if(b < B && d < D && l == 0){
|
| 30 |
+
for(int j = threadIdx.x; j < L; j += blockDim.x)
|
| 31 |
+
{
|
| 32 |
+
input_t sum;
|
| 33 |
+
set_value(&sum, 0.0f);
|
| 34 |
+
input_t weight;
|
| 35 |
+
|
| 36 |
+
for(int k = 0; k < K ; k++)
|
| 37 |
+
{
|
| 38 |
+
int idx = - P + k + j;
|
| 39 |
+
|
| 40 |
+
if(idx >= 0 && idx < L){
|
| 41 |
+
set_value(&weight, weights[d * K + K - (k +1)]);
|
| 42 |
+
sum = __hfma(dout[b * D * L + d * L + idx], weight, sum);
|
| 43 |
+
}
|
| 44 |
+
}
|
| 45 |
+
du[b * D * L + d * L + j] = sum;
|
| 46 |
+
}
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
const int k = blockIdx.x;
|
| 50 |
+
input_t tmp;
|
| 51 |
+
//construct the dk matrix
|
| 52 |
+
if(b < B && d < D && k < K)
|
| 53 |
+
{
|
| 54 |
+
for(int j = threadIdx.x; j < L; j += blockDim.x)
|
| 55 |
+
{
|
| 56 |
+
if(k - P + j < 0 || k - P + j >= L){
|
| 57 |
+
set_value(&dk[b * D * K * L + d * K * L + k * L + j], 0.0f);
|
| 58 |
+
|
| 59 |
+
}else{
|
| 60 |
+
set_value(&dk[b * D * K * L + d * K * L + k * L + j], u[b * D * L + d * L + k - P + j]);
|
| 61 |
+
}
|
| 62 |
+
}
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
std::vector<torch::Tensor> conv1d_backward_bhl_cuda(
|
| 68 |
+
torch::Tensor dout,
|
| 69 |
+
torch::Tensor u,
|
| 70 |
+
torch::Tensor weight,
|
| 71 |
+
torch::Tensor bias,
|
| 72 |
+
uint padding)
|
| 73 |
+
{
|
| 74 |
+
const uint b = u.size(0);
|
| 75 |
+
const uint d = u.size(1);
|
| 76 |
+
const uint l = u.size(2);
|
| 77 |
+
|
| 78 |
+
const uint k = weight.squeeze().size(1);
|
| 79 |
+
|
| 80 |
+
dim3 blockDims(BX, 1, 1);
|
| 81 |
+
|
| 82 |
+
dim3 gridDims(l, d, b);
|
| 83 |
+
|
| 84 |
+
torch::Tensor du = torch::empty({b, d, l}, u.options());
|
| 85 |
+
torch::Tensor dk = torch::empty({b, d, k, l}, dout.options());
|
| 86 |
+
torch::Tensor dbias = dout.sum(-1).sum(0);
|
| 87 |
+
|
| 88 |
+
DISPATCH_FLOAT_AND_HALF_AND_BF16(dout.scalar_type(), weight.scalar_type(),
|
| 89 |
+
"depthwise conv 1d backward bhl",
|
| 90 |
+
([&]
|
| 91 |
+
{ conv1d_backward_kernel<input_t, weight_t><<<gridDims, blockDims>>>(
|
| 92 |
+
static_cast<input_t *>(dout.data_ptr()),
|
| 93 |
+
static_cast<input_t *>(u.data_ptr()),
|
| 94 |
+
static_cast<weight_t *>(weight.data_ptr()),
|
| 95 |
+
static_cast<input_t *>(du.data_ptr()),
|
| 96 |
+
static_cast<input_t *>(dk.data_ptr()),
|
| 97 |
+
b,
|
| 98 |
+
l,
|
| 99 |
+
d,
|
| 100 |
+
k,
|
| 101 |
+
padding);
|
| 102 |
+
}
|
| 103 |
+
)
|
| 104 |
+
);
|
| 105 |
+
return {du, torch::matmul(dk, dout.unsqueeze(-1)).squeeze(-1).sum(0).to(weight.type()), dbias};
|
| 106 |
+
}
|