File size: 1,144 Bytes
7298fd0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# /// script
# requires-python = ">=3.10"
# dependencies = [
#   "torch",
#   "triton",
#   "kernels",
# ]
# ///

import os
from pathlib import Path
import sys

import torch
from kernels import get_kernel, get_local_kernel


def load_hydra_kernel():
    """Load the Hydra kernel module from the Hub, a local build, or source.

    Resolution order:

    1. If the ``HYDRA_USE_HUB`` environment variable is exactly ``"1"``,
       fetch ``Frosty40/hydra`` from the Hub with ``get_kernel``.
    2. Otherwise, if ``build/metadata.json`` or ``metadata.json`` exists in
       the directory containing this file, load that directory with
       ``get_local_kernel``.
    3. Otherwise, prepend ``<this dir>/torch-ext`` to ``sys.path`` (once)
       and import the ``hydra`` package from there.

    Returns:
        The loaded module. ``main`` in this script calls its ``hydra``
        attribute.
    """
    if os.environ.get("HYDRA_USE_HUB") == "1":
        return get_kernel("Frosty40/hydra")

    root = Path(__file__).resolve().parent
    for variant in (root / "build", root):
        if (variant / "metadata.json").exists():
            # kernels' get_local_kernel takes (repo_path, package_name); the
            # package name is needed to locate the importable "hydra" package
            # inside the build tree. Calling it with only the path raises
            # TypeError.
            return get_local_kernel(variant, "hydra")

    # Source-tree fallback. Only prepend the directory once so repeated calls
    # do not keep growing sys.path with duplicate entries.
    ext_dir = str(root / "torch-ext")
    if ext_dir not in sys.path:
        sys.path.insert(0, ext_dir)
    # NOTE(review): if an unrelated package named ``hydra`` (e.g. hydra-core)
    # is already in sys.modules, this import returns that cached module
    # instead of the torch-ext one — confirm this cannot happen in practice.
    import hydra

    return hydra


def main() -> None:
    """Run one Hydra decode call on random bfloat16 CUDA tensors and report shapes.

    The query is shaped (1, 32, 1, 128) and the key/value tensors are shaped
    (1, 8, 8192, 128). Presumably the layout is (batch, heads, seq, dim), with
    32 query heads sharing 8 KV heads — TODO confirm against the kernel's docs.

    Raises:
        SystemExit: when no CUDA device is available.
    """
    cuda_ready = torch.cuda.is_available()
    if not cuda_ready:
        raise SystemExit("Hydra requires CUDA for this example")

    kernel = load_hydra_kernel()

    def _bf16_cuda(*shape):
        # Random normal tensor allocated directly on the GPU in bfloat16.
        return torch.randn(*shape, device="cuda", dtype=torch.bfloat16)

    # Order matters for RNG reproducibility: query, then key, then value.
    query = _bf16_cuda(1, 32, 1, 128)
    key = _bf16_cuda(1, 8, 8192, 128)
    value = _bf16_cuda(1, 8, 8192, 128)

    result = kernel.hydra(query, key, value)

    q_shape, k_shape, out_shape = (tuple(t.shape) for t in (query, key, result))
    print(f"Hydra decode: {q_shape} x {k_shape} -> {out_shape}")


# Entry point: run the demo only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()