File size: 18,439 Bytes
6f0b660 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 |
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import is_accelerate_available, is_torch_available, logging
if is_torch_available():
import torch
from torch import nn
if is_accelerate_available():
from accelerate import init_empty_weights
import re
from contextlib import contextmanager
logger = logging.get_logger(__name__)
# FP4 (E2M1) code points indexed by the 4-bit nibble value: indices 0-7 are the non-negative
# magnitudes and indices 8-15 are the same magnitudes negated (including a signed -0.0).
FP4_VALUES = [sign * magnitude for sign in (1.0, -1.0) for magnitude in (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)]
@contextmanager
def on_device(dev):
    """Temporarily make ``dev`` the current CUDA/XPU device for the ``with`` body.

    ``dev`` may be a tensor (its ``.device`` is used), a device string, or a ``torch.device``.
    A CUDA device enters ``torch.cuda.device``; an XPU device enters ``torch.xpu.device`` when
    the installed torch exposes ``torch.xpu``. Any other device (e.g. CPU), or a missing torch
    install, runs the body without switching devices.
    """
    if is_torch_available():
        import torch

        if isinstance(dev, torch.Tensor):
            dev = dev.device
        elif isinstance(dev, str):
            dev = torch.device(dev)

        device_ctx = None
        kind = getattr(dev, "type", None)
        if kind == "cuda":
            device_ctx = torch.cuda.device(dev)
        elif kind == "xpu" and hasattr(torch, "xpu"):
            device_ctx = torch.xpu.device(dev)

        if device_ctx is not None:
            with device_ctx:
                yield
            return

    # Fallback (e.g. CPU): nothing to switch.
    yield
# Copied from GPT_OSS repo and vllm
def quantize_to_mxfp4(w, triton_kernels_hub):
    """Quantize ``w`` to MXFP4 with the hub's torch reference routine.

    ``w`` is cast to bfloat16 and handed to
    ``triton_kernels_hub.numerics_details.mxfp.downcast_to_mxfp_torch`` with ``torch.uint8``
    storage and ``axis=1``.

    Returns:
        The ``(quantized_weights, scales)`` pair produced by that routine.
    """
    mxfp = triton_kernels_hub.numerics_details.mxfp
    packed, scale = mxfp.downcast_to_mxfp_torch(w.to(torch.bfloat16), torch.uint8, axis=1)
    return packed, scale
def swizzle_mxfp4(w, w_scale, triton_kernels_hub):
    """Wrap MXFP4 weights and scales as triton tensors laid out for the matmul kernels.

    ``w`` is wrapped with the hub's ``FP4`` dtype and converted to the layout returned by
    ``make_default_matmul_mxfp4_w_layout(mx_axis=1)`` (presumably hardware dependent; the
    choice is made inside the hub). ``w_scale`` is wrapped without an explicit dtype and
    converted to ``StridedLayout``.

    Returns:
        ``(w, w_scale)`` as converted triton tensors.
    """
    hub_tensor = triton_kernels_hub.tensor
    layouts = triton_kernels_hub.tensor_details.layout

    value_layout, value_layout_opts = layouts.make_default_matmul_mxfp4_w_layout(mx_axis=1)
    wrapped_w = hub_tensor.wrap_torch_tensor(w, dtype=hub_tensor.FP4)
    swizzled_w = hub_tensor.convert_layout(wrapped_w, value_layout, **value_layout_opts)

    wrapped_scale = hub_tensor.wrap_torch_tensor(w_scale)
    strided_scale = hub_tensor.convert_layout(wrapped_scale, layouts.StridedLayout)
    return swizzled_w, strided_scale
# Copied from GPT_OSS repo
# TODO: Add absolute link when the repo is public
def convert_moe_packed_tensors(
    blocks,
    scales,
    *,
    dtype: torch.dtype = torch.bfloat16,
    rows_per_chunk: int = 32768 * 1024,  # rows dequantized per loop iteration (bounds temporaries)
) -> torch.Tensor:
    """
    Dequantize MXFP4-packed MoE weights into a dense ``dtype`` tensor for the GPT_OSS forward pass.

    Each byte of ``blocks`` holds two FP4 codes (low nibble first) that are looked up in
    ``FP4_VALUES``. Every row ``blocks[..., g, :]`` shares one power-of-two exponent from
    ``scales``, stored with a bias of 127 and applied with ``torch.ldexp``.

    Args:
        blocks: packed byte tensor of shape ``(*prefix, G, B)`` (uint8 in this module).
        scales: integer tensor of biased exponents whose shape equals ``blocks.shape[:-1]``.
        dtype: dtype of the dequantized output.
        rows_per_chunk: number of ``G`` rows unpacked per iteration, to bound temporary memory.

    Returns:
        For 4-D ``blocks`` of shape ``(E, R, G, B)``, a contiguous tensor of shape
        ``(E, 2 * G * B, R)``: the last two axes are swapped by ``transpose(1, 2)``. When
        ``blocks`` is on CPU and CUDA is available, the work (and the result) moves to CUDA.
    """
    import math

    # Dequantize on the GPU when possible, and always keep `scales` on the same device as
    # `blocks`; otherwise a CUDA `blocks` with CPU `scales` fails inside `torch.ldexp`.
    if not blocks.is_cuda and torch.cuda.is_available():
        blocks = blocks.cuda()
    scales = scales.to(blocks.device)

    # Scales are biased exponents (E8M0, bias 127): the multiplier is 2 ** (scale - 127).
    scales = scales.to(torch.int32) - 127

    assert blocks.shape[:-1] == scales.shape, f"{blocks.shape[:-1]=} does not match {scales.shape=}"

    lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)

    *prefix_shape, G, B = blocks.shape
    rows_total = math.prod(prefix_shape) * G

    blocks = blocks.reshape(rows_total, B)
    scales = scales.reshape(rows_total, 1)

    out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device)

    for r0 in range(0, rows_total, rows_per_chunk):
        r1 = min(r0 + rows_per_chunk, rows_total)

        blk = blocks[r0:r1]
        exp = scales[r0:r1]

        # nibble indices -> int64 (required for LUT indexing)
        idx_lo = (blk & 0x0F).to(torch.long)
        idx_hi = (blk >> 4).to(torch.long)

        # Low nibble goes to the even output column, high nibble to the odd one.
        sub = out[r0:r1]
        sub[:, 0::2] = lut[idx_lo]
        sub[:, 1::2] = lut[idx_hi]

        torch.ldexp(sub, exp, out=sub)
        del idx_lo, idx_hi, blk, exp, sub

    out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2)
    # Drop the packed inputs before the transpose copy to lower peak memory.
    del blocks, scales, lut
    return out.transpose(1, 2).contiguous()
class Mxfp4GptOssExperts(nn.Module):
    """MXFP4 replacement for ``GptOssExperts`` that runs through triton ``matmul_ogs`` kernels.

    The constructor only allocates zero-filled, frozen placeholders for the packed checkpoint
    tensors: ``*_blocks`` (uint8, trailing dimension of 16 bytes per group of 32 inputs),
    ``*_scales`` (uint8, one entry per group of 32 inputs), and float32 biases. The
    ``gate_up_proj`` / ``down_proj`` weights and the ``*_precision_config`` objects read by
    :meth:`forward` are attached later by the loading code (``load_and_swizzle_mxfp4`` sets them).
    """

    def __init__(self, config):
        super().__init__()

        self.num_experts = config.num_local_experts
        self.intermediate_size = config.intermediate_size
        self.hidden_size = config.hidden_size

        self.gate_up_proj_blocks = nn.Parameter(
            torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16, dtype=torch.uint8),
            requires_grad=False,
        )
        self.gate_up_proj_scales = nn.Parameter(
            torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, dtype=torch.uint8),
            requires_grad=False,
        )
        self.gate_up_proj_bias = nn.Parameter(
            torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), requires_grad=False
        )

        self.down_proj_blocks = nn.Parameter(
            torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, 16, dtype=torch.uint8),
            requires_grad=False,
        )
        self.down_proj_scales = nn.Parameter(
            torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8),
            requires_grad=False,
        )
        self.down_proj_bias = nn.Parameter(
            torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), requires_grad=False
        )

        # Parameters of the fused SwiGLU activation used in forward().
        self.alpha = 1.702
        self.limit = getattr(config, "swiglu_limit", 7.0)

        # Replaced once the real weights have been loaded and swizzled.
        self.gate_up_proj_precision_config = None
        self.down_proj_precision_config = None

    def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor:
        """Apply the gate/up projection with fused SwiGLU, then the down projection.

        ``routing_data``, ``gather_idx`` and ``scatter_idx`` come from the triton routing helpers;
        ``routing_data.gate_scal`` is passed as ``gammas`` to the down projection.
        """
        FnSpecs, FusedActivation, matmul_ogs = (
            triton_kernels_hub.matmul_ogs.FnSpecs,
            triton_kernels_hub.matmul_ogs.FusedActivation,
            triton_kernels_hub.matmul_ogs.matmul_ogs,
        )
        swiglu_fn = triton_kernels_hub.swiglu.swiglu_fn

        with on_device(hidden_states.device):
            act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, self.limit), 2)

            intermediate_cache1 = matmul_ogs(
                hidden_states,
                self.gate_up_proj,
                self.gate_up_proj_bias.to(torch.float32),
                routing_data,
                gather_indx=gather_idx,
                precision_config=self.gate_up_proj_precision_config,
                gammas=None,
                fused_activation=act,
            )

            intermediate_cache3 = matmul_ogs(
                intermediate_cache1,
                self.down_proj,
                self.down_proj_bias.to(torch.float32),
                routing_data,
                scatter_indx=scatter_idx,
                precision_config=self.down_proj_precision_config,
                gammas=routing_data.gate_scal,
            )
        return intermediate_cache3
# Adapted from GPT_OSS repo
# TODO: Add absolute link when the repo is public
def routing_torch_dist(
    logits,
    n_expts_act,
):
    """Build expert-parallel routing metadata for the experts owned by this rank.

    Picks the top-``n_expts_act`` experts per token from ``logits``, softmaxes their scores,
    and builds the ``RoutingData`` / ``GatherIndx`` / ``ScatterIndx`` objects from
    ``triton_kernels_hub.routing``. Experts are split into ``torch.distributed.get_world_size()``
    contiguous slices. Selections outside this rank's slice are marked with ``-1`` in the
    gather/scatter indices.

    Args:
        logits: router logits indexed as ``(n_tokens, n_experts)``.
        n_expts_act: number of experts selected per token (top-k).

    Returns:
        ``(RoutingData, GatherIndx, ScatterIndx)``.
    """
    import os

    GatherIndx, RoutingData, ScatterIndx, compute_expt_data_torch = (
        triton_kernels_hub.routing.GatherIndx,
        triton_kernels_hub.routing.RoutingData,
        triton_kernels_hub.routing.ScatterIndx,
        triton_kernels_hub.routing.compute_expt_data_torch,
    )

    with on_device(logits.device):
        world_size = torch.distributed.get_world_size()
        # NOTE(review): the expert slice is chosen from LOCAL_RANK while `world_size` is global,
        # so the two only line up on a single node. Confirm before running multi-node.
        rank = int(os.environ.get("LOCAL_RANK", "0"))
        replace_value = -1  # marks routing slots handled by another rank

        n_tokens = logits.shape[0]
        n_expts_tot = logits.shape[1]

        # Contiguous slice of experts owned by this rank. With floor division, any remainder
        # experts belong to no rank; presumably n_expts_tot is divisible by world_size.
        n_local_experts = n_expts_tot // world_size
        local_expert_start = rank * n_local_experts
        local_expert_end = (rank + 1) * n_local_experts

        n_gates_pad = n_tokens * n_expts_act

        def topk(vals, k):
            # Stable descending sort: on ties the lower expert index wins.
            tk_indx = torch.argsort(-vals, dim=1, stable=True)[:, :k]
            tk_indx = tk_indx.long()
            tk_val = torch.take_along_dim(vals, tk_indx, dim=1)
            return tk_val, tk_indx.int()

        expt_scal, expt_indx = topk(logits, n_expts_act)
        expt_scal = torch.softmax(expt_scal, dim=-1)
        # Order each token's selected experts by expert id, carrying the weights along.
        expt_indx, sort_indices = torch.sort(expt_indx, dim=1)
        expt_scal = torch.gather(expt_scal, 1, sort_indices)

        # Flatten and mask for local experts
        expt_scal = expt_scal.reshape(-1)

        # Per-expert selection counts over all experts, then restricted to this rank's slice.
        hist = torch.histc(expt_indx, bins=n_expts_tot, max=n_expts_tot - 1)[local_expert_start:local_expert_end]

        expt_indx = expt_indx.view(-1).to(torch.int32)

        # Replace ids below the local slice with a sentinel so they sort after the local experts.
        # The sentinel only needs to exceed every valid expert id. A fixed 1000 broke for
        # models with more than 1000 experts; `n_expts_tot` works for any expert count and
        # yields the same ordering otherwise.
        var = n_expts_tot
        expt_indx = torch.where(expt_indx < local_expert_start, var, expt_indx)
        topk_indx = torch.argsort(expt_indx, stable=True).to(torch.int32)
        gate_indx = torch.argsort(topk_indx).to(torch.int32)
        # Everything outside [local_expert_start, local_expert_end) becomes `replace_value`.
        expt_indx = torch.where(expt_indx < local_expert_end, expt_indx, replace_value)
        expt_indx = torch.where(local_expert_start <= expt_indx, expt_indx, replace_value)
        gate_indx = torch.where(expt_indx == replace_value, replace_value, gate_indx)
        gate_scal = expt_scal[topk_indx]

        topk_indx = torch.where(gate_indx[topk_indx] == replace_value, replace_value, topk_indx)

        # Routing metadata for local expert computation
        gather_indx = GatherIndx(src_indx=topk_indx.int(), dst_indx=gate_indx.int())
        scatter_indx = ScatterIndx(src_indx=gate_indx.int(), dst_indx=topk_indx.int())
        expt_data = compute_expt_data_torch(hist, n_local_experts, n_gates_pad)
        hit_experts = n_expts_act
        return RoutingData(gate_scal, hist, n_local_experts, hit_experts, expt_data), gather_indx, scatter_indx
def mlp_forward(self, hidden_states):
    """Replacement ``forward`` bound onto ``GptOssMLP`` modules for the MXFP4 path.

    Computes router logits with ``self.router`` (``weight``, ``bias``, ``hidden_dim``, ``top_k``)
    and builds routing metadata. It uses ``routing_torch_dist`` when torch.distributed is
    initialized and the module carries ``_is_hooked``, and the hub's ``routing`` otherwise.
    It then runs ``self.experts``.

    Returns:
        ``(routed_out, router_logits)``, where ``routed_out`` is reshaped to
        ``(batch_size, -1, router.hidden_dim)``.
    """
    import torch.distributed as dist

    use_dist_routing = dist.is_available() and dist.is_initialized() and hasattr(self, "_is_hooked")
    routing = routing_torch_dist if use_dist_routing else triton_kernels_hub.routing.routing

    router = self.router
    batch_size = hidden_states.shape[0]
    flat_states = hidden_states.reshape(-1, router.hidden_dim)
    router_logits = nn.functional.linear(flat_states, router.weight, router.bias)

    with on_device(router_logits.device):
        routing_data, gather_idx, scatter_idx = routing(router_logits, router.top_k)

    routed_out = self.experts(flat_states, routing_data, gather_idx, scatter_idx)
    return routed_out.reshape(batch_size, -1, router.hidden_dim), router_logits
def should_convert_module(current_key_name, patterns):
    """Return ``False`` if the dotted module path matches any exclusion pattern.

    ``current_key_name`` is the list of path components (joined with ``"."``). Each entry of
    ``patterns`` is treated as a regular expression anchored at the start of the path (via
    ``re.match``), so it excludes that module and everything whose path begins with it.
    """
    dotted_name = ".".join(current_key_name)
    for pattern in patterns:
        if re.match(f"{pattern}\\.", dotted_name) or re.match(f"{pattern}", dotted_name):
            return False
    return True
def dequantize(module, param_name, param_value, target_device, dq_param_name, **kwargs):
    """Load one packed MXFP4 expert tensor and dequantize its projection once both halves exist.

    Only acts when ``param_name`` contains ``"gate_up_proj"`` or ``"down_proj"``. The value
    (sharded through ``shard_and_distribute_module`` when ``kwargs["device_mesh"]`` is set) is
    stored on ``module`` under the last component of ``param_name``. When ``module`` then holds
    both ``<proj>_blocks`` and ``<proj>_scales``, they are unpacked with
    ``convert_moe_packed_tensors`` and stored as the ``<proj>`` parameter on ``target_device``.
    The packed attributes are then deleted.

    Recognised ``kwargs``: ``model``, ``empty_param``, ``casting_dtype``, ``to_contiguous``,
    ``rank``, ``device_mesh``.
    """
    from ..integrations.tensor_parallel import shard_and_distribute_module

    model, empty_param, casting_dtype, to_contiguous, rank, device_mesh = (
        kwargs.get(key) for key in ("model", "empty_param", "casting_dtype", "to_contiguous", "rank", "device_mesh")
    )

    for proj in ("gate_up_proj", "down_proj"):
        if proj not in param_name:
            continue

        if device_mesh is not None:
            param_value = shard_and_distribute_module(
                model,
                param_value,
                empty_param,
                dq_param_name,
                casting_dtype,
                to_contiguous,
                rank,
                device_mesh,
            )

        blocks_attr, scales_attr = f"{proj}_blocks", f"{proj}_scales"
        setattr(module, param_name.rsplit(".", 1)[1], param_value)

        # Wait until both the blocks and the scales of this projection have been loaded.
        if not (hasattr(module, blocks_attr) and hasattr(module, scales_attr)):
            continue

        dequantized = convert_moe_packed_tensors(getattr(module, blocks_attr), getattr(module, scales_attr))
        if target_device == "cpu" and torch.cuda.is_available():
            torch.cuda.empty_cache()
        setattr(module, proj, torch.nn.Parameter(dequantized.to(target_device)))
        delattr(module, blocks_attr)
        delattr(module, scales_attr)
def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, triton_kernels_hub, **kwargs):
    """
    This transforms the weights obtained using `convert_gpt_oss.py` to load them into `Mxfp4GptOssExperts`.

    Each call loads one packed tensor (``<proj>_blocks`` or ``<proj>_scales``) onto ``module``.
    The tensor goes through ``shard_and_distribute_module`` when ``kwargs["device_mesh"]`` is
    set, and is otherwise stored as a frozen ``nn.Parameter``.

    Once both halves of ``<proj>`` are off the ``meta`` device, the function:
    - reshapes them and moves them to ``target_device`` (a CPU target is redirected to ``"cuda"``);
    - swizzles them with ``swizzle_mxfp4``;
    - stores the result as ``module.<proj>`` together with ``module.<proj>_precision_config``;
    - deletes the packed attributes.

    Recognised ``kwargs``: ``model``, ``empty_param``, ``casting_dtype``, ``to_contiguous``,
    ``rank``, ``device_mesh``.

    Raises:
        ValueError: if ``param_name`` contains neither ``"blocks"`` nor ``"scales"``.
    """
    PrecisionConfig, FlexCtx, InFlexData = (
        triton_kernels_hub.matmul_ogs.PrecisionConfig,
        triton_kernels_hub.matmul_ogs.FlexCtx,
        triton_kernels_hub.matmul_ogs.InFlexData,
    )
    from ..integrations.tensor_parallel import shard_and_distribute_module

    model = kwargs.get("model")
    empty_param = kwargs.get("empty_param")
    casting_dtype = kwargs.get("casting_dtype")
    to_contiguous = kwargs.get("to_contiguous")
    rank = kwargs.get("rank")
    device_mesh = kwargs.get("device_mesh")

    proj = None
    if "blocks" in param_name:
        proj = param_name.split(".")[-1].split("_blocks")[0]
    if "scales" in param_name:
        proj = param_name.split(".")[-1].split("_scales")[0]

    if device_mesh is not None:
        shard_and_distribute_module(
            model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh
        )
    else:
        setattr(module, param_name.rsplit(".", 1)[1], torch.nn.Parameter(param_value, requires_grad=False))

    if proj is None:
        # Without a blocks/scales marker there is no projection to assemble; fail with a clear
        # message instead of an UnboundLocalError on `proj`.
        raise ValueError(f"Cannot load {param_name!r} as MXFP4: expected a `*_blocks` or `*_scales` parameter.")

    blocks_attr = f"{proj}_blocks"
    scales_attr = f"{proj}_scales"
    blocks = getattr(module, blocks_attr)  # at this point values were loaded from ckpt
    scales = getattr(module, scales_attr)

    # Only proceed once both blocks and scales are materialized (neither on the meta device).
    if blocks.device.type != "meta" and scales.device.type != "meta":
        local_experts = blocks.size(0)
        # Merge the (groups, 16-byte) trailing axes into one packed-byte axis.
        if proj == "gate_up_proj":
            blocks = blocks.reshape(local_experts, module.intermediate_size * 2, -1)
        else:
            blocks = blocks.reshape(local_experts, -1, module.intermediate_size // 2)
        # The triton kernels cannot run on CPU, so a CPU target is redirected to CUDA.
        if getattr(target_device, "type", target_device) == "cpu":
            target_device = "cuda"
        blocks = blocks.to(target_device).contiguous()
        scales = scales.to(target_device).contiguous()
        with on_device(target_device):
            triton_weight_tensor, weight_scale = swizzle_mxfp4(
                blocks.transpose(-2, -1), scales.transpose(-2, -1), triton_kernels_hub
            )

        # need to overwrite the shapes for the kernels
        if proj == "gate_up_proj":
            triton_weight_tensor.shape = torch.Size([local_experts, module.hidden_size, module.intermediate_size * 2])
        else:
            triton_weight_tensor.shape = torch.Size([local_experts, module.intermediate_size, module.hidden_size])

        # triton_weight_tensor is what needs to be passed in oai kernels. It stores the data, the shapes and any more objects. It is like a subtensor
        setattr(module, proj, triton_weight_tensor)
        setattr(
            module,
            f"{proj}_precision_config",
            PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())),
        )

        # delete blocks and scales
        delattr(module, scales_attr)
        delattr(module, blocks_attr)
        del blocks
def _replace_with_mxfp4_linear(
model,
modules_to_not_convert=None,
current_key_name=None,
quantization_config=None,
has_been_replaced=False,
config=None,
):
if current_key_name is None:
current_key_name = []
for name, module in model.named_children():
current_key_name.append(name)
if not should_convert_module(current_key_name, modules_to_not_convert):
current_key_name.pop(-1)
continue
if module.__class__.__name__ == "GptOssExperts" and not quantization_config.dequantize:
with init_empty_weights():
model._modules[name] = Mxfp4GptOssExperts(config)
has_been_replaced = True
if module.__class__.__name__ == "GptOssMLP" and not quantization_config.dequantize:
from types import MethodType
module.forward = MethodType(mlp_forward, module)
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_mxfp4_linear(
module,
modules_to_not_convert,
current_key_name,
quantization_config,
has_been_replaced=has_been_replaced,
config=config,
)
current_key_name.pop(-1)
return model, has_been_replaced
def replace_with_mxfp4_linear(
    model,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    config=None,
):
    """Prepare ``model`` for MXFP4 inference by swapping in the triton-kernel expert modules.

    Returns ``model`` unchanged when ``quantization_config.dequantize`` is set. Otherwise it:
    - loads ``kernels-community/triton_kernels`` into the module-level ``triton_kernels_hub``;
    - merges ``modules_to_not_convert`` (default ``["lm_head"]``) with
      ``quantization_config.modules_to_not_convert``;
    - replaces the eligible modules and logs a warning if none were found.

    The caller's ``modules_to_not_convert`` list is not modified.
    """
    if quantization_config.dequantize:
        return model

    from kernels import get_kernel

    global triton_kernels_hub
    triton_kernels_hub = get_kernel("kernels-community/triton_kernels")

    # Copy so that extend() below never mutates the list passed in by the caller.
    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else list(modules_to_not_convert)

    if quantization_config.modules_to_not_convert is not None:
        modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
    modules_to_not_convert = list(set(modules_to_not_convert))
    model, has_been_replaced = _replace_with_mxfp4_linear(
        model,
        modules_to_not_convert,
        current_key_name,
        quantization_config,
        config=config,
    )
    if not has_been_replaced:
        logger.warning(
            "You are loading your model using mixed-precision FP4 quantization but no linear modules were found in your model."
            " Please double check your model architecture, or submit an issue on github if you think this is"
            " a bug."
        )

    return model
|