|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import re
from contextlib import contextmanager

from ..utils import is_accelerate_available, is_torch_available, logging


if is_torch_available():
    import torch
    from torch import nn

if is_accelerate_available():
    from accelerate import init_empty_weights
|
|
|
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
# Lookup table for the 16 mxfp4 (e2m1) code points, indexed by the 4-bit code. Packed "blocks"
# tensors store two codes per uint8 (low nibble first); the matching "scales" entry is a shared
# power-of-two exponent (bias 127) for each block of 32 values.
FP4_VALUES = [
    +0.0,
    +0.5,
    +1.0,
    +1.5,
    +2.0,
    +3.0,
    +4.0,
    +6.0,
    -0.0,
    -0.5,
    -1.0,
    -1.5,
    -2.0,
    -3.0,
    -4.0,
    -6.0,
]
|
|
|
|
|
|
|
|
@contextmanager |
|
|
def on_device(dev): |
|
|
if is_torch_available(): |
|
|
import torch |
|
|
|
|
|
if isinstance(dev, torch.Tensor): |
|
|
dev = dev.device |
|
|
elif isinstance(dev, str): |
|
|
dev = torch.device(dev) |
|
|
dev_type = getattr(dev, "type", None) |
|
|
if dev_type == "cuda": |
|
|
with torch.cuda.device(dev): |
|
|
yield |
|
|
return |
|
|
if dev_type == "xpu" and hasattr(torch, "xpu"): |
|
|
with torch.xpu.device(dev): |
|
|
yield |
|
|
return |
|
|
|
|
|
yield |
|
|
|
|
|
|
|
|
|
|
|
def quantize_to_mxfp4(w, triton_kernels_hub): |
|
|
downcast_to_mxfp_torch = triton_kernels_hub.numerics_details.mxfp.downcast_to_mxfp_torch |
|
|
w, w_scale = downcast_to_mxfp_torch(w.to(torch.bfloat16), torch.uint8, axis=1) |
|
|
return w, w_scale |
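# Illustrative usage (hypothetical variable names): with `triton_kernels_hub` loaded via
# kernels.get_kernel("kernels-community/triton_kernels"), calling
#   w_q, w_scale = quantize_to_mxfp4(weight, triton_kernels_hub)
# yields uint8 packed values and per-block scales, quantized along axis 1.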
|
|
|
|
|
|
|
|
def swizzle_mxfp4(w, w_scale, triton_kernels_hub): |
|
|
""" |
|
|
Changes the layout of the tensors depending on the hardware |
|
|
""" |
|
|
FP4, convert_layout, wrap_torch_tensor = ( |
|
|
triton_kernels_hub.tensor.FP4, |
|
|
triton_kernels_hub.tensor.convert_layout, |
|
|
triton_kernels_hub.tensor.wrap_torch_tensor, |
|
|
) |
|
|
layout = triton_kernels_hub.tensor_details.layout |
|
|
StridedLayout = triton_kernels_hub.tensor_details.layout.StridedLayout |
|
|
|
|
|
value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1) |
|
|
w = convert_layout(wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts) |
|
|
w_scale = convert_layout(wrap_torch_tensor(w_scale), StridedLayout) |
|
|
return w, w_scale |
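# The swizzled weight is stored on the module and the scale is wrapped in a PrecisionConfig (see
# load_and_swizzle_mxfp4 below); both are then consumed by matmul_ogs in Mxfp4GptOssExperts.forward.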
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_moe_packed_tensors( |
|
|
blocks, |
|
|
scales, |
|
|
*, |
|
|
dtype: torch.dtype = torch.bfloat16, |
|
|
rows_per_chunk: int = 32768 * 1024, |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Convert the mxfp4 weights again, dequantizing and makes them compatible with the forward |
|
|
pass of GPT_OSS. |
|
|
""" |
|
|
import math |
|
|
|
|
|
|
|
|
if not blocks.is_cuda and torch.cuda.is_available(): |
|
|
blocks = blocks.cuda() |
|
|
scales = scales.cuda() |
|
|
|
|
|
scales = scales.to(torch.int32) - 127 |
|
|
|
|
|
assert blocks.shape[:-1] == scales.shape, f"{blocks.shape[:-1]=} does not match {scales.shape=}" |
|
|
|
|
|
lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device) |
|
|
|
|
|
*prefix_shape, G, B = blocks.shape |
|
|
rows_total = math.prod(prefix_shape) * G |
|
|
|
|
|
blocks = blocks.reshape(rows_total, B) |
|
|
scales = scales.reshape(rows_total, 1) |
|
|
|
|
|
out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device) |
|
|
|
|
|
for r0 in range(0, rows_total, rows_per_chunk): |
|
|
r1 = min(r0 + rows_per_chunk, rows_total) |
|
|
|
|
|
blk = blocks[r0:r1] |
|
|
exp = scales[r0:r1] |
|
|
|
|
|
|
|
|
idx_lo = (blk & 0x0F).to(torch.long) |
|
|
idx_hi = (blk >> 4).to(torch.long) |
|
|
|
|
|
sub = out[r0:r1] |
|
|
sub[:, 0::2] = lut[idx_lo] |
|
|
sub[:, 1::2] = lut[idx_hi] |
|
|
|
|
|
torch.ldexp(sub, exp, out=sub) |
|
|
del idx_lo, idx_hi, blk, exp, sub |
|
|
|
|
|
out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) |
|
|
del blocks, scales, lut |
|
|
return out.transpose(1, 2).contiguous() |
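# Shape example: gate_up_proj blocks of shape (num_experts, 2 * intermediate_size, hidden_size // 32, 16)
# with scales of shape (num_experts, 2 * intermediate_size, hidden_size // 32) dequantize to a tensor of
# shape (num_experts, hidden_size, 2 * intermediate_size) after the final transpose.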
|
|
|
|
|
|
|
|
class Mxfp4GptOssExperts(nn.Module): |
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
|
|
|
self.num_experts = config.num_local_experts |
|
|
self.intermediate_size = config.intermediate_size |
|
|
self.hidden_size = config.hidden_size |
|
|
|
|
|
self.gate_up_proj_blocks = nn.Parameter( |
|
|
torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16, dtype=torch.uint8), |
|
|
requires_grad=False, |
|
|
) |
|
|
self.gate_up_proj_scales = nn.Parameter( |
|
|
torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, dtype=torch.uint8), |
|
|
requires_grad=False, |
|
|
) |
|
|
self.gate_up_proj_bias = nn.Parameter( |
|
|
torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), requires_grad=False |
|
|
) |
|
|
|
|
|
self.down_proj_blocks = nn.Parameter( |
|
|
torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8), |
|
|
requires_grad=False, |
|
|
) |
|
|
self.down_proj_scales = nn.Parameter( |
|
|
torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8), |
|
|
requires_grad=False, |
|
|
) |
|
|
self.down_proj_bias = nn.Parameter( |
|
|
torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), requires_grad=False |
|
|
) |
|
|
self.alpha = 1.702 |
|
|
self.limit = getattr(config, "swiglu_limit", 7.0) |
|
|
self.gate_up_proj_precision_config = None |
|
|
self.down_proj_precision_config = None |
|
|
|
|
|
|
|
def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor: |
|
|
FnSpecs, FusedActivation, matmul_ogs = ( |
|
|
triton_kernels_hub.matmul_ogs.FnSpecs, |
|
|
triton_kernels_hub.matmul_ogs.FusedActivation, |
|
|
triton_kernels_hub.matmul_ogs.matmul_ogs, |
|
|
) |
|
|
swiglu_fn = triton_kernels_hub.swiglu.swiglu_fn |
|
|
|
|
|
with on_device(hidden_states.device): |
|
|
act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, self.limit), 2) |
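            # First grouped matmul: project the gathered tokens with gate_up_proj, with the SwiGLU
            # activation (parameterized by alpha and limit) fused into the same kernel call.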
|
|
|
|
|
intermediate_cache1 = matmul_ogs( |
|
|
hidden_states, |
|
|
self.gate_up_proj, |
|
|
self.gate_up_proj_bias.to(torch.float32), |
|
|
routing_data, |
|
|
gather_indx=gather_idx, |
|
|
precision_config=self.gate_up_proj_precision_config, |
|
|
gammas=None, |
|
|
fused_activation=act, |
|
|
) |
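            # Second grouped matmul: project back to the hidden size with down_proj, scattering the
            # per-expert rows back to token order and scaling them by the routing gate values.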
|
|
|
|
|
intermediate_cache3 = matmul_ogs( |
|
|
intermediate_cache1, |
|
|
self.down_proj, |
|
|
self.down_proj_bias.to(torch.float32), |
|
|
routing_data, |
|
|
scatter_indx=scatter_idx, |
|
|
precision_config=self.down_proj_precision_config, |
|
|
gammas=routing_data.gate_scal, |
|
|
) |
|
|
return intermediate_cache3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def routing_torch_dist( |
|
|
logits, |
|
|
n_expts_act, |
|
|
): |
|
|
import os |
|
|
|
|
|
GatherIndx, RoutingData, ScatterIndx, compute_expt_data_torch = ( |
|
|
triton_kernels_hub.routing.GatherIndx, |
|
|
triton_kernels_hub.routing.RoutingData, |
|
|
triton_kernels_hub.routing.ScatterIndx, |
|
|
triton_kernels_hub.routing.compute_expt_data_torch, |
|
|
) |
|
|
|
|
|
with on_device(logits.device): |
|
|
world_size = torch.distributed.get_world_size() |
|
|
rank = int(os.environ.get("LOCAL_RANK", "0")) |
|
|
replace_value = -1 |
|
|
|
|
|
n_tokens = logits.shape[0] |
|
|
n_expts_tot = logits.shape[1] |
|
|
|
|
|
n_local_experts = n_expts_tot // world_size |
|
|
local_expert_start = rank * n_local_experts |
|
|
local_expert_end = (rank + 1) * n_local_experts |
|
|
|
|
|
n_gates_pad = n_tokens * n_expts_act |
|
|
|
|
|
def topk(vals, k): |
|
|
tk_indx = torch.argsort(-vals, dim=1, stable=True)[:, :k] |
|
|
tk_indx = tk_indx.long() |
|
|
tk_val = torch.take_along_dim(vals, tk_indx, dim=1) |
|
|
return tk_val, tk_indx.int() |
|
|
|
|
|
expt_scal, expt_indx = topk(logits, n_expts_act) |
|
|
expt_scal = torch.softmax(expt_scal, dim=-1) |
|
|
expt_indx, sort_indices = torch.sort(expt_indx, dim=1) |
|
|
expt_scal = torch.gather(expt_scal, 1, sort_indices) |
|
|
|
|
|
|
|
|
expt_scal = expt_scal.reshape(-1) |
|
|
|
|
|
hist = torch.histc(expt_indx, bins=n_expts_tot, max=n_expts_tot - 1)[local_expert_start:local_expert_end] |
|
|
|
|
|
expt_indx = expt_indx.view(-1).to(torch.int32) |
|
|
|
|
|
|
|
|
        # Entries for experts below this rank's range are mapped to a sentinel larger than any
        # valid expert index, so the stable argsort below groups every non-local entry after the
        # local ones.
        sentinel = n_expts_tot
        expt_indx = torch.where(expt_indx < local_expert_start, sentinel, expt_indx)
|
|
topk_indx = torch.argsort(expt_indx, stable=True).to(torch.int32) |
|
|
gate_indx = torch.argsort(topk_indx).to(torch.int32) |
|
|
expt_indx = torch.where(expt_indx < local_expert_end, expt_indx, replace_value) |
|
|
expt_indx = torch.where(local_expert_start <= expt_indx, expt_indx, replace_value) |
|
|
|
|
|
gate_indx = torch.where(expt_indx == replace_value, replace_value, gate_indx) |
|
|
gate_scal = expt_scal[topk_indx] |
|
|
|
|
|
topk_indx = torch.where(gate_indx[topk_indx] == replace_value, replace_value, topk_indx) |
|
|
|
|
|
|
|
|
gather_indx = GatherIndx(src_indx=topk_indx.int(), dst_indx=gate_indx.int()) |
|
|
scatter_indx = ScatterIndx(src_indx=gate_indx.int(), dst_indx=topk_indx.int()) |
|
|
|
|
|
expt_data = compute_expt_data_torch(hist, n_local_experts, n_gates_pad) |
|
|
|
|
|
hit_experts = n_expts_act |
|
|
return RoutingData(gate_scal, hist, n_local_experts, hit_experts, expt_data), gather_indx, scatter_indx |
|
|
|
|
|
|
|
|
def mlp_forward(self, hidden_states): |
|
|
import torch.distributed as dist |
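    # Use the torch-native distributed routing path when the experts have been sharded across ranks
    # (signalled by the `_is_hooked` attribute, presumably set by the expert-parallel hook); otherwise
    # fall back to the fused routing kernel from the Hub.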
|
|
|
|
|
if dist.is_available() and dist.is_initialized() and hasattr(self, "_is_hooked"): |
|
|
routing = routing_torch_dist |
|
|
else: |
|
|
routing = triton_kernels_hub.routing.routing |
|
|
|
|
|
batch_size = hidden_states.shape[0] |
|
|
hidden_states = hidden_states.reshape(-1, self.router.hidden_dim) |
|
|
router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias) |
|
|
|
|
|
with on_device(router_logits.device): |
|
|
routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k) |
|
|
|
|
|
routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx) |
|
|
routed_out = routed_out.reshape(batch_size, -1, self.router.hidden_dim) |
|
|
return routed_out, router_logits |
|
|
|
|
|
|
|
|
def should_convert_module(current_key_name, patterns):
    current_key_name_str = ".".join(current_key_name)
    # Convert the module unless its dotted name matches one of the skip patterns.
    return not any(
        re.match(f"{key}\\.", current_key_name_str) or re.match(f"{key}", current_key_name_str) for key in patterns
    )
|
|
|
|
|
|
|
|
def dequantize(module, param_name, param_value, target_device, dq_param_name, **kwargs): |
|
|
from ..integrations.tensor_parallel import shard_and_distribute_module |
|
|
|
|
|
model = kwargs.get("model") |
|
|
empty_param = kwargs.get("empty_param") |
|
|
casting_dtype = kwargs.get("casting_dtype") |
|
|
to_contiguous = kwargs.get("to_contiguous") |
|
|
rank = kwargs.get("rank") |
|
|
device_mesh = kwargs.get("device_mesh") |
|
|
|
|
|
for proj in ["gate_up_proj", "down_proj"]: |
|
|
if proj in param_name: |
|
|
if device_mesh is not None: |
|
|
param_value = shard_and_distribute_module( |
|
|
model, |
|
|
param_value, |
|
|
empty_param, |
|
|
dq_param_name, |
|
|
casting_dtype, |
|
|
to_contiguous, |
|
|
rank, |
|
|
device_mesh, |
|
|
) |
|
|
blocks_attr = f"{proj}_blocks" |
|
|
scales_attr = f"{proj}_scales" |
|
|
setattr(module, param_name.rsplit(".", 1)[1], param_value) |
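            # Only dequantize once both the packed blocks and the scales for this projection have been
            # loaded onto the module; then replace them with a single dense parameter.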
|
|
if hasattr(module, blocks_attr) and hasattr(module, scales_attr): |
|
|
dequantized = convert_moe_packed_tensors(getattr(module, blocks_attr), getattr(module, scales_attr)) |
|
|
if target_device == "cpu" and torch.cuda.is_available(): |
|
|
torch.cuda.empty_cache() |
|
|
setattr(module, proj, torch.nn.Parameter(dequantized.to(target_device))) |
|
|
delattr(module, blocks_attr) |
|
|
delattr(module, scales_attr) |
|
|
|
|
|
|
|
|
def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, triton_kernels_hub, **kwargs): |
|
|
""" |
|
|
This transforms the weights obtained using `convert_gpt_oss.py` to load them into `Mxfp4GptOssExperts`. |
|
|
""" |
|
|
PrecisionConfig, FlexCtx, InFlexData = ( |
|
|
triton_kernels_hub.matmul_ogs.PrecisionConfig, |
|
|
triton_kernels_hub.matmul_ogs.FlexCtx, |
|
|
triton_kernels_hub.matmul_ogs.InFlexData, |
|
|
) |
|
|
from ..integrations.tensor_parallel import shard_and_distribute_module |
|
|
|
|
|
model = kwargs.get("model") |
|
|
empty_param = kwargs.get("empty_param") |
|
|
casting_dtype = kwargs.get("casting_dtype") |
|
|
to_contiguous = kwargs.get("to_contiguous") |
|
|
rank = kwargs.get("rank") |
|
|
device_mesh = kwargs.get("device_mesh") |
|
|
if "blocks" in param_name: |
|
|
proj = param_name.split(".")[-1].split("_blocks")[0] |
|
|
if "scales" in param_name: |
|
|
proj = param_name.split(".")[-1].split("_scales")[0] |
|
|
if device_mesh is not None: |
|
|
shard_and_distribute_module( |
|
|
model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh |
|
|
) |
|
|
else: |
|
|
setattr(module, param_name.rsplit(".", 1)[1], torch.nn.Parameter(param_value, requires_grad=False)) |
|
|
blocks_attr = f"{proj}_blocks" |
|
|
scales_attr = f"{proj}_scales" |
|
|
blocks = getattr(module, blocks_attr) |
|
|
scales = getattr(module, scales_attr) |
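    # Only proceed once both blocks and scales have been materialized (i.e. neither is still on the
    # "meta" device).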
|
|
|
|
|
if blocks.device.type != "meta" and scales.device.type != "meta": |
|
|
local_experts = blocks.size(0) |
|
|
if proj == "gate_up_proj": |
|
|
blocks = blocks.reshape(local_experts, module.intermediate_size * 2, -1) |
|
|
else: |
|
|
blocks = blocks.reshape(local_experts, -1, module.intermediate_size // 2) |
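        # Swizzling is performed on GPU: if the target device is CPU, stage the tensors on CUDA for
        # the layout conversion.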
|
|
if getattr(target_device, "type", target_device) == "cpu": |
|
|
target_device = "cuda" |
|
|
blocks = blocks.to(target_device).contiguous() |
|
|
scales = scales.to(target_device).contiguous() |
|
|
with on_device(target_device): |
|
|
triton_weight_tensor, weight_scale = swizzle_mxfp4( |
|
|
blocks.transpose(-2, -1), scales.transpose(-2, -1), triton_kernels_hub |
|
|
) |
|
|
|
|
|
|
|
|
if proj == "gate_up_proj": |
|
|
triton_weight_tensor.shape = torch.Size([local_experts, module.hidden_size, module.intermediate_size * 2]) |
|
|
else: |
|
|
triton_weight_tensor.shape = torch.Size([local_experts, module.intermediate_size, module.hidden_size]) |
|
|
|
|
|
|
|
|
setattr(module, proj, triton_weight_tensor) |
|
|
setattr( |
|
|
module, |
|
|
f"{proj}_precision_config", |
|
|
PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())), |
|
|
) |
|
|
|
|
|
|
|
|
delattr(module, scales_attr) |
|
|
delattr(module, blocks_attr) |
|
|
del blocks |
|
|
|
|
|
|
|
|
def _replace_with_mxfp4_linear( |
|
|
model, |
|
|
modules_to_not_convert=None, |
|
|
current_key_name=None, |
|
|
quantization_config=None, |
|
|
has_been_replaced=False, |
|
|
config=None, |
|
|
): |
|
|
if current_key_name is None: |
|
|
current_key_name = [] |
|
|
|
|
|
for name, module in model.named_children(): |
|
|
current_key_name.append(name) |
|
|
if not should_convert_module(current_key_name, modules_to_not_convert): |
|
|
current_key_name.pop(-1) |
|
|
continue |
|
|
if module.__class__.__name__ == "GptOssExperts" and not quantization_config.dequantize: |
|
|
with init_empty_weights(): |
|
|
model._modules[name] = Mxfp4GptOssExperts(config) |
|
|
has_been_replaced = True |
|
|
if module.__class__.__name__ == "GptOssMLP" and not quantization_config.dequantize: |
|
|
from types import MethodType |
|
|
|
|
|
module.forward = MethodType(mlp_forward, module) |
|
|
if len(list(module.children())) > 0: |
|
|
_, has_been_replaced = _replace_with_mxfp4_linear( |
|
|
module, |
|
|
modules_to_not_convert, |
|
|
current_key_name, |
|
|
quantization_config, |
|
|
has_been_replaced=has_been_replaced, |
|
|
config=config, |
|
|
) |
|
|
current_key_name.pop(-1) |
|
|
return model, has_been_replaced |
|
|
|
|
|
|
|
|
def replace_with_mxfp4_linear( |
|
|
model, |
|
|
modules_to_not_convert=None, |
|
|
current_key_name=None, |
|
|
quantization_config=None, |
|
|
config=None, |
|
|
): |
|
|
if quantization_config.dequantize: |
|
|
return model |
|
|
else: |
|
|
from kernels import get_kernel |
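        # Fetch the Triton MXFP4 kernels from the Hub once and cache them in a module-level global
        # that the forward and routing helpers above rely on.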
|
|
|
|
|
global triton_kernels_hub |
|
|
triton_kernels_hub = get_kernel("kernels-community/triton_kernels") |
|
|
|
|
|
modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert |
|
|
|
|
|
if quantization_config.modules_to_not_convert is not None: |
|
|
modules_to_not_convert.extend(quantization_config.modules_to_not_convert) |
|
|
modules_to_not_convert = list(set(modules_to_not_convert)) |
|
|
model, has_been_replaced = _replace_with_mxfp4_linear( |
|
|
model, |
|
|
modules_to_not_convert, |
|
|
current_key_name, |
|
|
quantization_config, |
|
|
config=config, |
|
|
) |
|
|
if not has_been_replaced: |
|
|
        logger.warning(
            "You are loading your model with MXFP4 quantization but no compatible modules were found in"
            " your model. Please double check your model architecture, or submit an issue on GitHub if"
            " you think this is a bug."
        )
|
|
|
|
|
return model |
|
|
|