galqiwi committed on
Commit
17fbd3f
·
1 Parent(s): 0154673

Update Nix kernel matrix config and safetensors import

Browse files
Files changed (3) hide show
  1. build.toml +7 -7
  2. flake.lock +117 -0
  3. torch-ext/higgs_kernels/__init__.py +8 -3
build.toml CHANGED
@@ -1,11 +1,11 @@
1
  [general]
2
  name = "higgs-kernels"
3
- universal = false
4
 
5
  [torch]
6
  src = [
7
- "torch-ext/torch_binding.cpp",
8
- "torch-ext/torch_binding.h",
9
  ]
10
 
11
  [kernel.dequant]
@@ -13,12 +13,12 @@ backend = "cuda"
13
  depends = ["torch"]
14
  src = ["csrc/dequant.cu"]
15
 
16
- [kernel.quant_f16]
17
  backend = "cuda"
18
  depends = ["torch"]
19
- src = ["csrc/quant_f16.cu"]
20
 
21
- [kernel.quant_bf16]
22
  backend = "cuda"
23
  depends = ["torch"]
24
- src = ["csrc/quant_bf16.cu"]
 
1
  [general]
2
  name = "higgs-kernels"
3
+ backends = ["cuda"]
4
 
5
  [torch]
6
  src = [
7
+ "torch-ext/torch_binding.cpp",
8
+ "torch-ext/torch_binding.h",
9
  ]
10
 
11
  [kernel.dequant]
 
13
  depends = ["torch"]
14
  src = ["csrc/dequant.cu"]
15
 
16
+ [kernel.quant_bf16]
17
  backend = "cuda"
18
  depends = ["torch"]
19
+ src = ["csrc/quant_bf16.cu"]
20
 
21
+ [kernel.quant_f16]
22
  backend = "cuda"
23
  depends = ["torch"]
24
+ src = ["csrc/quant_f16.cu"]
flake.lock ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nodes": {
3
+ "flake-compat": {
4
+ "locked": {
5
+ "lastModified": 1765121682,
6
+ "narHash": "sha256-4VBOP18BFeiPkyhy9o4ssBNQEvfvv1kXkasAYd0+rrA=",
7
+ "owner": "edolstra",
8
+ "repo": "flake-compat",
9
+ "rev": "65f23138d8d09a92e30f1e5c87611b23ef451bf3",
10
+ "type": "github"
11
+ },
12
+ "original": {
13
+ "owner": "edolstra",
14
+ "repo": "flake-compat",
15
+ "type": "github"
16
+ }
17
+ },
18
+ "flake-utils": {
19
+ "inputs": {
20
+ "systems": "systems"
21
+ },
22
+ "locked": {
23
+ "lastModified": 1731533236,
24
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
25
+ "owner": "numtide",
26
+ "repo": "flake-utils",
27
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
28
+ "type": "github"
29
+ },
30
+ "original": {
31
+ "owner": "numtide",
32
+ "repo": "flake-utils",
33
+ "type": "github"
34
+ }
35
+ },
36
+ "kernel-builder": {
37
+ "inputs": {
38
+ "flake-compat": "flake-compat",
39
+ "flake-utils": "flake-utils",
40
+ "nixpkgs": "nixpkgs",
41
+ "rust-overlay": "rust-overlay"
42
+ },
43
+ "locked": {
44
+ "lastModified": 1771009495,
45
+ "narHash": "sha256-/2nRZQbiIvlBoFShfpnoOD27ZgeYBV26bI6E5WsBJws=",
46
+ "owner": "huggingface",
47
+ "repo": "kernels",
48
+ "rev": "b079fd8c66612177cc8edd13292613abb4de994c",
49
+ "type": "github"
50
+ },
51
+ "original": {
52
+ "owner": "huggingface",
53
+ "repo": "kernels",
54
+ "type": "github"
55
+ }
56
+ },
57
+ "nixpkgs": {
58
+ "locked": {
59
+ "lastModified": 1766341660,
60
+ "narHash": "sha256-4yG6vx7Dddk9/zh45Y2KM82OaRD4jO3HA9r98ORzysA=",
61
+ "owner": "NixOS",
62
+ "repo": "nixpkgs",
63
+ "rev": "26861f5606e3e4d1400771b513cc63e5f70151a6",
64
+ "type": "github"
65
+ },
66
+ "original": {
67
+ "owner": "NixOS",
68
+ "ref": "nixos-unstable-small",
69
+ "repo": "nixpkgs",
70
+ "type": "github"
71
+ }
72
+ },
73
+ "root": {
74
+ "inputs": {
75
+ "kernel-builder": "kernel-builder"
76
+ }
77
+ },
78
+ "rust-overlay": {
79
+ "inputs": {
80
+ "nixpkgs": [
81
+ "kernel-builder",
82
+ "nixpkgs"
83
+ ]
84
+ },
85
+ "locked": {
86
+ "lastModified": 1769050281,
87
+ "narHash": "sha256-1H8DN4UZgEUqPUA5ecHOufLZMscJ4IlcGaEftaPtpBY=",
88
+ "owner": "oxalica",
89
+ "repo": "rust-overlay",
90
+ "rev": "6deef0585c52d9e70f96b6121207e1496d4b0c49",
91
+ "type": "github"
92
+ },
93
+ "original": {
94
+ "owner": "oxalica",
95
+ "repo": "rust-overlay",
96
+ "type": "github"
97
+ }
98
+ },
99
+ "systems": {
100
+ "locked": {
101
+ "lastModified": 1681028828,
102
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
103
+ "owner": "nix-systems",
104
+ "repo": "default",
105
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
106
+ "type": "github"
107
+ },
108
+ "original": {
109
+ "owner": "nix-systems",
110
+ "repo": "default",
111
+ "type": "github"
112
+ }
113
+ }
114
+ },
115
+ "root": "root",
116
+ "version": 7
117
+ }
torch-ext/higgs_kernels/__init__.py CHANGED
@@ -4,14 +4,19 @@ import functools
4
  import torch
5
  from ._ops import ops
6
 
7
- import safetensors.torch
8
-
9
  PKG_PATH = os.path.dirname(os.path.realpath(__file__))
10
 
11
 
12
  @functools.cache
13
  def load_optimal_grid_2_256(device="cpu", dtype=torch.float16):
14
- return safetensors.torch.load_file(
 
 
 
 
 
 
 
15
  os.path.join(PKG_PATH, "grids.safetensors"), device=device
16
  )["2_256"].to(dtype)
17
 
 
4
  import torch
5
  from ._ops import ops
6
 
 
 
7
  PKG_PATH = os.path.dirname(os.path.realpath(__file__))
8
 
9
 
10
  @functools.cache
11
  def load_optimal_grid_2_256(device="cpu", dtype=torch.float16):
12
+ try:
13
+ import safetensors.torch as safetensors_torch
14
+ except ModuleNotFoundError as exc:
15
+ raise ModuleNotFoundError(
16
+ "load_optimal_grid_2_256 requires safetensors"
17
+ ) from exc
18
+
19
+ return safetensors_torch.load_file(
20
  os.path.join(PKG_PATH, "grids.safetensors"), device=device
21
  )["2_256"].to(dtype)
22