Fix absolute imports
Browse files- flake.lock +7 -6
- flake.nix +1 -1
- torch-ext/quantization/utils/marlin_utils.py +3 -4
- torch-ext/quantization/utils/marlin_utils_fp4.py +6 -7
- torch-ext/quantization/utils/marlin_utils_fp8.py +3 -3
- torch-ext/quantization/utils/marlin_utils_test.py +1 -2
- torch-ext/quantization/utils/marlin_utils_test_24.py +1 -2
- torch-ext/quantization/utils/quant_utils.py +1 -1
flake.lock
CHANGED
|
@@ -73,11 +73,11 @@
|
|
| 73 |
"nixpkgs": "nixpkgs"
|
| 74 |
},
|
| 75 |
"locked": {
|
| 76 |
-
"lastModified":
|
| 77 |
-
"narHash": "sha256-
|
| 78 |
"owner": "huggingface",
|
| 79 |
"repo": "hf-nix",
|
| 80 |
-
"rev": "
|
| 81 |
"type": "github"
|
| 82 |
},
|
| 83 |
"original": {
|
|
@@ -98,15 +98,16 @@
|
|
| 98 |
]
|
| 99 |
},
|
| 100 |
"locked": {
|
| 101 |
-
"lastModified":
|
| 102 |
-
"narHash": "sha256-
|
| 103 |
"owner": "huggingface",
|
| 104 |
"repo": "kernel-builder",
|
| 105 |
-
"rev": "
|
| 106 |
"type": "github"
|
| 107 |
},
|
| 108 |
"original": {
|
| 109 |
"owner": "huggingface",
|
|
|
|
| 110 |
"repo": "kernel-builder",
|
| 111 |
"type": "github"
|
| 112 |
}
|
|
|
|
| 73 |
"nixpkgs": "nixpkgs"
|
| 74 |
},
|
| 75 |
"locked": {
|
| 76 |
+
"lastModified": 1751968576,
|
| 77 |
+
"narHash": "sha256-cmKrlWpNTG/hq1bCaHXfbdm9T+Y6V+5//EHAVc1TLBE=",
|
| 78 |
"owner": "huggingface",
|
| 79 |
"repo": "hf-nix",
|
| 80 |
+
"rev": "3fcd1e1b46da91b6691261640ffd6b7123d0cb9e",
|
| 81 |
"type": "github"
|
| 82 |
},
|
| 83 |
"original": {
|
|
|
|
| 98 |
]
|
| 99 |
},
|
| 100 |
"locked": {
|
| 101 |
+
"lastModified": 1751968677,
|
| 102 |
+
"narHash": "sha256-5gtVPN6uk+H3yq2gJRDjSTcaVSgGJZjbMALlO6TBcT8=",
|
| 103 |
"owner": "huggingface",
|
| 104 |
"repo": "kernel-builder",
|
| 105 |
+
"rev": "54eea2ce49889202e7018792f407046e36f89bc5",
|
| 106 |
"type": "github"
|
| 107 |
},
|
| 108 |
"original": {
|
| 109 |
"owner": "huggingface",
|
| 110 |
+
"ref": "get-kernel-check",
|
| 111 |
"repo": "kernel-builder",
|
| 112 |
"type": "github"
|
| 113 |
}
|
flake.nix
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
description = "Flake for quantization kernels";
|
| 3 |
|
| 4 |
inputs = {
|
| 5 |
-
kernel-builder.url = "github:huggingface/kernel-builder";
|
| 6 |
};
|
| 7 |
|
| 8 |
outputs =
|
|
|
|
| 2 |
description = "Flake for quantization kernels";
|
| 3 |
|
| 4 |
inputs = {
|
| 5 |
+
kernel-builder.url = "github:huggingface/kernel-builder/get-kernel-check";
|
| 6 |
};
|
| 7 |
|
| 8 |
outputs =
|
torch-ext/quantization/utils/marlin_utils.py
CHANGED
|
@@ -6,8 +6,7 @@ from typing import Optional
|
|
| 6 |
import numpy
|
| 7 |
import torch
|
| 8 |
|
| 9 |
-
import
|
| 10 |
-
from quantization.scalar_type import ScalarType, scalar_types
|
| 11 |
|
| 12 |
from .quant_utils import pack_cols, unpack_cols
|
| 13 |
|
|
@@ -383,7 +382,7 @@ def apply_gptq_marlin_linear(
|
|
| 383 |
device=input.device,
|
| 384 |
dtype=input.dtype)
|
| 385 |
|
| 386 |
-
output =
|
| 387 |
None,
|
| 388 |
weight,
|
| 389 |
weight_scale,
|
|
@@ -429,7 +428,7 @@ def apply_awq_marlin_linear(
|
|
| 429 |
device=input.device,
|
| 430 |
dtype=input.dtype)
|
| 431 |
|
| 432 |
-
output =
|
| 433 |
None,
|
| 434 |
weight,
|
| 435 |
weight_scale,
|
|
|
|
| 6 |
import numpy
|
| 7 |
import torch
|
| 8 |
|
| 9 |
+
from .. import ScalarType, gptq_marlin_gemm, scalar_types
|
|
|
|
| 10 |
|
| 11 |
from .quant_utils import pack_cols, unpack_cols
|
| 12 |
|
|
|
|
| 382 |
device=input.device,
|
| 383 |
dtype=input.dtype)
|
| 384 |
|
| 385 |
+
output = gptq_marlin_gemm(reshaped_x,
|
| 386 |
None,
|
| 387 |
weight,
|
| 388 |
weight_scale,
|
|
|
|
| 428 |
device=input.device,
|
| 429 |
dtype=input.dtype)
|
| 430 |
|
| 431 |
+
output = gptq_marlin_gemm(reshaped_x,
|
| 432 |
None,
|
| 433 |
weight,
|
| 434 |
weight_scale,
|
torch-ext/quantization/utils/marlin_utils_fp4.py
CHANGED
|
@@ -5,12 +5,11 @@ from typing import Optional
|
|
| 5 |
|
| 6 |
import torch
|
| 7 |
|
| 8 |
-
import
|
| 9 |
-
|
| 10 |
from .marlin_utils import (
|
| 11 |
USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales,
|
| 12 |
should_use_atomic_add_reduce)
|
| 13 |
-
from
|
| 14 |
|
| 15 |
FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16]
|
| 16 |
|
|
@@ -90,7 +89,7 @@ def apply_fp4_marlin_linear(
|
|
| 90 |
device=input.device,
|
| 91 |
dtype=input.dtype)
|
| 92 |
|
| 93 |
-
output =
|
| 94 |
c=None,
|
| 95 |
b_q_weight=weight,
|
| 96 |
b_scales=weight_scale,
|
|
@@ -135,7 +134,7 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
|
|
| 135 |
perm = torch.empty(0, dtype=torch.int, device=device)
|
| 136 |
qweight = layer.weight.view(torch.int32).T.contiguous()
|
| 137 |
|
| 138 |
-
marlin_qweight =
|
| 139 |
perm=perm,
|
| 140 |
size_k=part_size_k,
|
| 141 |
size_n=part_size_n,
|
|
@@ -192,7 +191,7 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
|
|
| 192 |
for i in range(e):
|
| 193 |
qweight = weight[i].view(torch.int32).T.contiguous()
|
| 194 |
|
| 195 |
-
marlin_qweight =
|
| 196 |
perm=perm,
|
| 197 |
size_k=size_k,
|
| 198 |
size_n=size_n,
|
|
@@ -263,7 +262,7 @@ def rand_marlin_weight_fp4_like(weight, group_size):
|
|
| 263 |
weight_ref = weight_ref * global_scale.to(weight.dtype) * \
|
| 264 |
scales.repeat_interleave(group_size, 1).to(weight.dtype)
|
| 265 |
|
| 266 |
-
marlin_qweight =
|
| 267 |
b_q_weight=fp4_weight.view(torch.int32).T.contiguous(),
|
| 268 |
perm=torch.empty(0, dtype=torch.int, device=device),
|
| 269 |
size_k=size_k,
|
|
|
|
| 5 |
|
| 6 |
import torch
|
| 7 |
|
| 8 |
+
from .. import gptq_marlin_gemm, gptq_marlin_repack
|
|
|
|
| 9 |
from .marlin_utils import (
|
| 10 |
USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales,
|
| 11 |
should_use_atomic_add_reduce)
|
| 12 |
+
from ..scalar_type import scalar_types
|
| 13 |
|
| 14 |
FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16]
|
| 15 |
|
|
|
|
| 89 |
device=input.device,
|
| 90 |
dtype=input.dtype)
|
| 91 |
|
| 92 |
+
output = gptq_marlin_gemm(a=reshaped_x,
|
| 93 |
c=None,
|
| 94 |
b_q_weight=weight,
|
| 95 |
b_scales=weight_scale,
|
|
|
|
| 134 |
perm = torch.empty(0, dtype=torch.int, device=device)
|
| 135 |
qweight = layer.weight.view(torch.int32).T.contiguous()
|
| 136 |
|
| 137 |
+
marlin_qweight = gptq_marlin_repack(b_q_weight=qweight,
|
| 138 |
perm=perm,
|
| 139 |
size_k=part_size_k,
|
| 140 |
size_n=part_size_n,
|
|
|
|
| 191 |
for i in range(e):
|
| 192 |
qweight = weight[i].view(torch.int32).T.contiguous()
|
| 193 |
|
| 194 |
+
marlin_qweight = gptq_marlin_repack(b_q_weight=qweight,
|
| 195 |
perm=perm,
|
| 196 |
size_k=size_k,
|
| 197 |
size_n=size_n,
|
|
|
|
| 262 |
weight_ref = weight_ref * global_scale.to(weight.dtype) * \
|
| 263 |
scales.repeat_interleave(group_size, 1).to(weight.dtype)
|
| 264 |
|
| 265 |
+
marlin_qweight = gptq_marlin_repack(
|
| 266 |
b_q_weight=fp4_weight.view(torch.int32).T.contiguous(),
|
| 267 |
perm=torch.empty(0, dtype=torch.int, device=device),
|
| 268 |
size_k=size_k,
|
torch-ext/quantization/utils/marlin_utils_fp8.py
CHANGED
|
@@ -5,7 +5,7 @@ from typing import Optional
|
|
| 5 |
|
| 6 |
import torch
|
| 7 |
|
| 8 |
-
import
|
| 9 |
|
| 10 |
from .marlin_utils import USE_FP32_REDUCE_DEFAULT, marlin_make_workspace, marlin_permute_scales
|
| 11 |
|
|
@@ -51,7 +51,7 @@ def apply_fp8_marlin_linear(
|
|
| 51 |
device=input.device,
|
| 52 |
dtype=input.dtype)
|
| 53 |
|
| 54 |
-
output =
|
| 55 |
c=None,
|
| 56 |
b_q_weight=weight,
|
| 57 |
b_scales=weight_scale,
|
|
@@ -104,7 +104,7 @@ def marlin_quant_fp8_torch(weight, group_size):
|
|
| 104 |
weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
|
| 105 |
|
| 106 |
packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous()
|
| 107 |
-
marlin_qweight =
|
| 108 |
b_q_weight=packed_weight,
|
| 109 |
perm=torch.empty(0, dtype=torch.int, device=device),
|
| 110 |
size_k=size_k,
|
|
|
|
| 5 |
|
| 6 |
import torch
|
| 7 |
|
| 8 |
+
from .. import gptq_marlin_gemm, gptq_marlin_repack
|
| 9 |
|
| 10 |
from .marlin_utils import USE_FP32_REDUCE_DEFAULT, marlin_make_workspace, marlin_permute_scales
|
| 11 |
|
|
|
|
| 51 |
device=input.device,
|
| 52 |
dtype=input.dtype)
|
| 53 |
|
| 54 |
+
output = gptq_marlin_gemm(a=reshaped_x,
|
| 55 |
c=None,
|
| 56 |
b_q_weight=weight,
|
| 57 |
b_scales=weight_scale,
|
|
|
|
| 104 |
weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
|
| 105 |
|
| 106 |
packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous()
|
| 107 |
+
marlin_qweight = gptq_marlin_repack(
|
| 108 |
b_q_weight=packed_weight,
|
| 109 |
perm=torch.empty(0, dtype=torch.int, device=device),
|
| 110 |
size_k=size_k,
|
torch-ext/quantization/utils/marlin_utils_test.py
CHANGED
|
@@ -5,8 +5,7 @@ from typing import List, Optional
|
|
| 5 |
import numpy as np
|
| 6 |
import torch
|
| 7 |
|
| 8 |
-
from
|
| 9 |
-
|
| 10 |
from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points
|
| 11 |
from .quant_utils import (
|
| 12 |
get_pack_factor,
|
|
|
|
| 5 |
import numpy as np
|
| 6 |
import torch
|
| 7 |
|
| 8 |
+
from ..scalar_type import ScalarType
|
|
|
|
| 9 |
from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points
|
| 10 |
from .quant_utils import (
|
| 11 |
get_pack_factor,
|
torch-ext/quantization/utils/marlin_utils_test_24.py
CHANGED
|
@@ -6,8 +6,7 @@ from typing import List
|
|
| 6 |
import numpy
|
| 7 |
import torch
|
| 8 |
|
| 9 |
-
from
|
| 10 |
-
|
| 11 |
from .marlin_utils_test import marlin_weights
|
| 12 |
from .quant_utils import gptq_quantize_weights
|
| 13 |
|
|
|
|
| 6 |
import numpy
|
| 7 |
import torch
|
| 8 |
|
| 9 |
+
from ..scalar_type import ScalarType
|
|
|
|
| 10 |
from .marlin_utils_test import marlin_weights
|
| 11 |
from .quant_utils import gptq_quantize_weights
|
| 12 |
|
torch-ext/quantization/utils/quant_utils.py
CHANGED
|
@@ -5,7 +5,7 @@ from typing import List, Optional
|
|
| 5 |
import numpy
|
| 6 |
import torch
|
| 7 |
|
| 8 |
-
from
|
| 9 |
|
| 10 |
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
|
| 11 |
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
|
|
|
| 5 |
import numpy
|
| 6 |
import torch
|
| 7 |
|
| 8 |
+
from ..scalar_type import ScalarType, scalar_types
|
| 9 |
|
| 10 |
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
|
| 11 |
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|