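# CloverLM / modeling_cloverlm.py
# Hugging Face Hub page header captured together with the file:
#   uploader: mansaripo
#   commit message: "Upload folder using huggingface_hub"
#   commit: e5f90f7 (verified)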
from math import sqrt
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_cloverlm import CloverLMConfig
from .fake_quartet import FakeQuartetLinear
# ── NVFP4 dequantization for checkpoint loading ─────────────────────────────
def _dequant_nvfp4_state_dict(raw_sd, dtype=torch.bfloat16):
    """Dequantize NVFP4-packed tensors using quartet2's _dq_fp4 on GPU.

    A tensor stored under ``key`` is treated as packed when the checkpoint
    also contains ``f"{key}_scale_2"``; its micro-scales are read from
    ``f"{key}_scale"``. The micro-scales are stored in cuBLAS blocked layout;
    quartet2's _dq_fp4 handles the unblocking correctly.

    Args:
        raw_sd: Mapping from parameter name to tensor, as loaded from the
            checkpoint.
        dtype: Floating-point dtype for the returned tensors.

    Returns:
        A new dict without the ``*_scale`` / ``*_scale_2`` entries. Packed
        tensors are dequantized, cast to ``dtype`` and moved back to CPU;
        other floating-point tensors are cast to ``dtype``; non-floating
        tensors are passed through unchanged.

    Raises:
        KeyError: If ``*_scale_2`` is present without its packed tensor, or a
            packed tensor has no matching ``*_scale`` entry.
        ImportError: If packed tensors are present but quartet2 is missing.
    """
    packed_keys = {
        k.removesuffix("_scale_2") for k in raw_sd if k.endswith("_scale_2")
    }

    # A global scale whose packed tensor is absent points at a truncated or
    # mis-keyed checkpoint; fail here rather than emit a partial state dict
    # that the caller goes on to load with strict=False.
    for base in packed_keys:
        if base not in raw_sd:
            raise KeyError(
                f"NVFP4 scale '{base}_scale_2' has no packed tensor '{base}'"
            )

    dq_fp4 = None
    if packed_keys:
        # Only require quartet2 (and a GPU) when there is something to
        # dequantize, so plain checkpoints pass through without it.
        from quartet2.linear import _dq_fp4 as dq_fp4

    result = {}
    for key, tensor in raw_sd.items():
        # Scale tensors are consumed together with their packed tensor below.
        if key.endswith(("_scale", "_scale_2")):
            continue
        if key in packed_keys:
            scale_key = f"{key}_scale"
            if scale_key not in raw_sd:
                raise KeyError(
                    f"NVFP4 tensor '{key}' has '{key}_scale_2' but no '{scale_key}'"
                )
            fp4 = tensor.cuda()
            scales = raw_sd[scale_key].cuda()
            ts = raw_sd[f"{key}_scale_2"].float().item()
            result[key] = dq_fp4(fp4, scales, ts).to(dtype).cpu()
        else:
            result[key] = tensor.to(dtype) if tensor.is_floating_point() else tensor
    return result
def _sphere_norm(X, dim=-1):
    """Return ``X`` rescaled to unit L2 norm along ``dim`` (via ``F.normalize``)."""
    # p and eps are spelled out at F.normalize's documented defaults.
    return F.normalize(X, p=2.0, dim=dim, eps=1e-12)
class _ReLU2(nn.Module):
    """Squared-ReLU activation: ``relu(x)`` raised to the power 2, elementwise."""

    def forward(self, x):
        rectified = F.relu(x)
        return rectified.pow(2)
def _make_linear(in_f, out_f, bias, quartet_2_impl):
    """Build the linear layer selected by ``quartet_2_impl``.

    ``"pseudoquant"`` gives a ``FakeQuartetLinear``, ``"quartet2"`` gives a
    ``Quartet_II_linear`` (imported on demand), and ``"bf16"``, ``None`` or
    ``""`` give a plain ``nn.Linear``.

    Raises:
        ImportError: If ``"quartet2"`` is requested but quartet2 cannot be
            imported (an installation hint is attached as a note).
        ValueError: For any other ``quartet_2_impl`` value.
    """
    if quartet_2_impl in ("bf16", None, ""):
        return nn.Linear(in_f, out_f, bias=bias)

    if quartet_2_impl == "pseudoquant":
        return FakeQuartetLinear(in_f, out_f, bias)

    if quartet_2_impl != "quartet2":
        raise ValueError(f"Unsupported quartet_2_impl: {quartet_2_impl}")

    try:
        from quartet2.linear import Quartet_II_linear
    except ImportError as exc:
        exc.add_note("Quartet_II_linear import failed. Install the latest quartet2 from https://github.com/IST-DASLab/Quartet-II")
        raise exc
    return Quartet_II_linear(in_f, out_f, bias)
def _build_rope(context, d_head, device):
    """Build the rotary-embedding cos/sin tables.

    Returns a float32 tensor of shape ``(2, context, 2 * (d_head // 2))``:
    index 0 holds the cosines and index 1 the sines of ``position * theta_j``,
    with ``theta_j = 1024 ** (-2j / d_head)``. Each angle is repeated for two
    adjacent feature columns.
    """
    positions = torch.arange(context, device=device, dtype=torch.float32)
    pair_index = torch.arange(d_head // 2, device=device, dtype=torch.float32)
    inv_freq = 1.0 / (1024.0 ** (2.0 * pair_index / d_head))
    # Outer product positions x frequencies, expressed as a (context, 1) @ (1, pairs) matmul.
    angles = torch.matmul(positions.unsqueeze(1), inv_freq.unsqueeze(0))
    tables = [
        table.repeat_interleave(2, dim=1)
        for table in (torch.cos(angles), torch.sin(angles))
    ]
    return torch.stack(tables)
def _apply_rope(X, rope):
    """Rotate adjacent feature pairs of ``X`` by the angles encoded in ``rope``.

    ``rope[0]`` is the cosine table and ``rope[1]`` the sine table. Each pair
    ``(x0, x1)`` on the last axis is combined with its quarter-turn
    ``(-x1, x0)``, and the result is cast back to ``X``'s dtype.
    """
    cos_table, sin_table = rope[0], rope[1]
    # Interleave (-odd, even) columns to form the quarter-turn partner of X.
    quarter_turn = torch.stack((-X[..., 1::2], X[..., 0::2]), dim=-1).flatten(-2)
    rotated = X * cos_table + quarter_turn * sin_table
    return rotated.to(X.dtype)
class _MLP(nn.Module):
    """Bias-free feed-forward block: linear -> squared ReLU -> linear."""

    def __init__(self, d, d_hidden, quartet_2_impl):
        super().__init__()
        expand = _make_linear(d, d_hidden, False, quartet_2_impl)
        # Kept as a Sequential so the first projection stays at index 0 of ``l1``.
        self.l1 = nn.Sequential(expand, _ReLU2())
        self.l2 = _make_linear(d_hidden, d, False, quartet_2_impl)

    def forward(self, x):
        hidden = self.l1(x)
        return self.l2(hidden)
class MHSA(nn.Module):
    """Causal multi-head self-attention with grouped key/value heads.

    Queries use ``heads`` heads and keys/values use ``heads // ratio`` heads,
    each of width ``d_head``. Queries and keys get rotary embeddings, are
    normalized to unit length per head, and queries are multiplied by a
    learnable per-head ``scale`` (initialized to ``sqrt(d_head)``). Attention
    is therefore run with a softmax scale of 1.0.
    """

    # Flash-attention backends and the module that provides ``flash_attn_func``.
    _FLASH_MODULES = {
        "flash2": "flash_attn",
        "flash3": "flash_attn_interface",
        "flash4": "flash_attn.cute",
    }

    def __init__(self, heads, d_head, ratio, quartet_2_impl):
        super().__init__()
        self.heads = heads
        self.d_head = d_head
        self.d = heads * d_head
        self.groups = heads // ratio
        d_kv = self.groups * d_head
        self.lq = _make_linear(self.d, self.d, False, quartet_2_impl)
        self.lk = _make_linear(self.d, d_kv, False, quartet_2_impl)
        self.lv = _make_linear(self.d, d_kv, False, quartet_2_impl)
        self.lo = _make_linear(self.d, self.d, False, quartet_2_impl)
        self.scale = nn.Parameter(torch.full((1, heads, 1, 1), sqrt(d_head)))

    def forward(self, X, rope, attn_backend):
        """Apply causal self-attention to ``X``.

        Args:
            X: Input whose last two dims are (sequence, ``heads * d_head``).
            rope: Stacked cos/sin tables as produced by ``_build_rope``.
            attn_backend: ``"pytorch"``, ``"flash2"``, ``"flash3"`` or
                ``"flash4"``.

        Returns:
            The output projection ``lo`` applied to the attention result.

        Raises:
            ValueError: If ``attn_backend`` is not one of the supported names.
            ImportError: If the requested flash backend cannot be imported.
        """
        # Fail fast: previously an unknown backend ran all projections and
        # then crashed with UnboundLocalError on ``Y``.
        if attn_backend != "pytorch" and attn_backend not in self._FLASH_MODULES:
            raise ValueError(f"Unsupported attn_backend: {attn_backend}")

        ctx = X.shape[-2]
        # Split features into heads and move the head axis before the sequence axis.
        Q = self.lq(X).unflatten(-1, (self.heads, self.d_head)).movedim(-3, -2)
        K = self.lk(X).unflatten(-1, (self.groups, self.d_head)).movedim(-3, -2)
        V = self.lv(X).unflatten(-1, (self.groups, self.d_head)).movedim(-3, -2)

        Q = _sphere_norm(_apply_rope(Q, rope))
        K = _sphere_norm(_apply_rope(K, rope))

        # ``scale`` is (1, heads, 1, 1); broadcasting can prepend a dim when Q
        # has no batch axis, so restore Q's original shape afterwards.
        Q = (self.scale * Q).reshape(Q.shape)

        if attn_backend == "pytorch":
            # Expand grouped K/V heads so every query head has a partner.
            repeats = self.heads // self.groups
            K = K.repeat_interleave(repeats, dim=-3)
            V = V.repeat_interleave(repeats, dim=-3)
            Y = F.scaled_dot_product_attention(Q, K, V, is_causal=True, scale=1.0)
            Y = Y.movedim(-3, -2).flatten(-2, -1)
            return self.lo(Y)

        import importlib

        module_name = self._FLASH_MODULES[attn_backend]
        try:
            flash_module = importlib.import_module(module_name)
        except ImportError as e:
            e.add_note(f"Can't run `attn_backend={attn_backend}` because can't import {module_name}")
            raise

        # The flash calls below are given (batch, sequence, heads, d_head) tensors.
        Q = Q.movedim(-3, -2).reshape(-1, ctx, self.heads, self.d_head)
        K = K.movedim(-3, -2).reshape(-1, ctx, self.groups, self.d_head)
        V = V.movedim(-3, -2).reshape(-1, ctx, self.groups, self.d_head)
        # Inputs are cast to half precision for the kernel (bf16 unless already fp16/bf16).
        dtype = Q.dtype if Q.dtype in (torch.bfloat16, torch.float16) else torch.bfloat16
        Y = flash_module.flash_attn_func(
            Q.to(dtype), K.to(dtype), V.to(dtype), causal=True, softmax_scale=1.0,
        )
        # The flash4 entry point returns a sequence whose first item is the
        # output; unwrap only when a sequence is actually returned.
        if isinstance(Y, (tuple, list)):
            Y = Y[0]
        Y = Y.to(Q.dtype).flatten(-2, -1)
        return self.lo(Y)
class _Block(nn.Module):
    """Transformer block with RMSNorm applied to each sub-layer's output.

    Computes ``Y = X + norm(attention(X))`` followed by ``Y + norm(mlp(Y))``.
    The MLP hidden width is four times the model width.
    """

    def __init__(self, heads, d_head, ratio, quartet_2_impl):
        super().__init__()
        width = heads * d_head
        self.mhsa = MHSA(heads, d_head, ratio, quartet_2_impl)
        self.out_att_norm = nn.RMSNorm(width, elementwise_affine=True)
        self.mlp = _MLP(width, 4 * width, quartet_2_impl)
        self.out_mlp_norm = nn.RMSNorm(width, elementwise_affine=True)

    def forward(self, X, rope, attn_backend):
        attn_out = self.mhsa(X, rope, attn_backend)
        after_attn = X + self.out_att_norm(attn_out)
        mlp_out = self.mlp(after_attn)
        return after_attn + self.out_mlp_norm(mlp_out)
class _Transformer(nn.Module):
    """Decoder stack: embedding -> ``num_blocks`` blocks -> RMSNorm -> vocab projection.

    ``max_context`` is accepted but not used by this class; rotary tables are
    rebuilt for the actual sequence length on every forward pass.
    """

    def __init__(self, vocab_size, num_blocks, heads, d_head, ratio,
                 max_context, std, quartet_2_impl, weight_tying, attn_backend):
        super().__init__()
        self.d_head = d_head
        self.attn_backend = attn_backend
        d = heads * d_head

        self.emb = nn.Embedding(vocab_size, d)
        layers = [
            _Block(heads, d_head, ratio, quartet_2_impl) for _ in range(num_blocks)
        ]
        self.blocks = nn.Sequential(*layers)
        self.out_norm = nn.RMSNorm(d, elementwise_affine=True)
        self.linear = nn.Linear(d, vocab_size, bias=False)

        if weight_tying:
            # Share one Parameter between the embedding and the output projection.
            self.emb.weight = self.linear.weight

        # Re-initialize by owning-module type: normal(0, std) for Linear and
        # Embedding weights, ones for RMSNorm weights, sqrt(d_head) for any
        # other 4-D parameter (the attention scale has that rank).
        scale_init = sqrt(d_head)
        for name, param in self.named_parameters():
            owner_path, _, leaf = name.rpartition(".")
            owner = self.get_submodule(owner_path)
            is_weight = leaf == "weight"
            if is_weight and isinstance(owner, (nn.Linear, nn.Embedding)):
                nn.init.normal_(param, 0, std)
            elif is_weight and isinstance(owner, nn.RMSNorm):
                nn.init.ones_(param)
            elif param.ndim == 4:
                nn.init.constant_(param, scale_init)

        if quartet_2_impl:
            # Any truthy impl name casts norm and embedding modules to bfloat16.
            bf16_module_types = (nn.LayerNorm, nn.RMSNorm, nn.Embedding)
            for module in self.modules():
                if isinstance(module, bf16_module_types):
                    module.to(torch.bfloat16)

    def forward(self, ids):
        """Map token ids to vocabulary logits."""
        seq_len = ids.shape[-1]
        rope = _build_rope(seq_len, self.d_head, device=ids.device)
        hidden = self.emb(ids)
        for layer in self.blocks:
            hidden = layer(hidden, rope, self.attn_backend)
        return self.linear(self.out_norm(hidden))
class CloverLMForCausalLM(PreTrainedModel, GenerationMixin):
    """Hugging Face causal-LM wrapper around ``_Transformer``.

    ``from_pretrained`` is overridden so that checkpoints containing
    NVFP4-packed tensors (detected by keys ending in ``_scale_2``) are
    dequantized before being loaded; all other checkpoints go through the
    stock ``PreTrainedModel.from_pretrained``.
    """

    config_class = CloverLMConfig
    supports_gradient_checkpointing = False
    _no_split_modules = ["_Block"]
    _tied_weights_keys = {"transformer.linear.weight": "transformer.emb.weight"}
    _tp_plan = {}

    # Hub-related kwargs forwarded to the file download and the config load.
    _HUB_KWARGS = ("token", "revision", "cache_dir", "local_files_only")

    def __init__(self, config: CloverLMConfig):
        super().__init__(config)
        self.transformer = _Transformer(
            vocab_size=config.vocab_size,
            num_blocks=config.num_blocks,
            heads=config.heads,
            d_head=config.d_head,
            ratio=config.ratio,
            max_context=config.max_context,
            std=0.02,
            quartet_2_impl=config.quartet_2_impl,
            weight_tying=config.weight_tying,
            attn_backend=config.attn_backend,
        )
        self.post_init()

    @classmethod
    def _resolve_safetensors(cls, pretrained_model_name_or_path, **kwargs):
        """Locate model.safetensors for a local dir or Hub repo ID.

        Returns the file path, or ``None`` when the file cannot be found or
        downloaded (the caller then falls back to the stock loader).
        """
        import os

        path = str(pretrained_model_name_or_path)
        local = os.path.join(path, "model.safetensors")
        if os.path.exists(local):
            return local

        hub_kwargs = {k: kwargs[k] for k in cls._HUB_KWARGS if k in kwargs}
        try:
            from huggingface_hub import hf_hub_download

            return hf_hub_download(
                repo_id=path,
                filename="model.safetensors",
                **hub_kwargs,
            )
        except Exception:
            # Deliberately best-effort: any failure here (missing package,
            # missing file, network error) defers to the default loader.
            return None

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load a model, dequantizing NVFP4-packed checkpoints when present.

        For NVFP4 checkpoints the recognised kwargs are ``config``,
        ``torch_dtype`` / ``dtype`` (default bfloat16), ``device_map``, the
        hub kwargs in ``_HUB_KWARGS``, and any kwarg naming an existing config
        attribute (applied as a config override). The returned model is in
        eval mode.
        """
        from safetensors import safe_open

        st_path = cls._resolve_safetensors(pretrained_model_name_or_path, **kwargs)
        if st_path is None:
            return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

        # Only inspect key names; close the file before doing anything else.
        with safe_open(st_path, framework="pt") as f:
            is_nvfp4 = any(k.endswith("_scale_2") for k in f.keys())
        if not is_nvfp4:
            return super().from_pretrained(
                pretrained_model_name_or_path, *args, **kwargs,
            )

        from safetensors.torch import load_file

        # Pop loader-only options BEFORE the config-override loop below.
        # That loop consumes every kwarg the config has an attribute for; if
        # the config exposes ``torch_dtype`` / ``dtype``, the requested dtype
        # would be swallowed and replaced by the bfloat16 default.
        target_dtype = kwargs.pop("torch_dtype", None)
        dtype_alias = kwargs.pop("dtype", None)
        if target_dtype is None:
            target_dtype = dtype_alias
        device_map = kwargs.pop("device_map", None)
        kwargs.pop("trust_remote_code", None)
        hub_kwargs = {k: kwargs[k] for k in cls._HUB_KWARGS if k in kwargs}

        config = kwargs.pop("config", None)
        if config is None:
            config = cls.config_class.from_pretrained(
                pretrained_model_name_or_path, trust_remote_code=True, **hub_kwargs,
            )
        # Apply config overrides from kwargs (e.g. attn_backend, quartet_2_impl)
        for key in list(kwargs.keys()):
            if hasattr(config, key):
                setattr(config, key, kwargs.pop(key))

        if target_dtype is None or (isinstance(target_dtype, str) and target_dtype == "auto"):
            target_dtype = torch.bfloat16
        elif isinstance(target_dtype, str):
            # Accept both "float16" and "torch.float16".
            target_dtype = getattr(torch, target_dtype.removeprefix("torch."))

        raw = load_file(st_path)
        state_dict = _dequant_nvfp4_state_dict(raw, target_dtype)
        model = cls(config)
        # strict=False: keys absent from the checkpoint keep their fresh init.
        model.load_state_dict(state_dict, strict=False)
        model = model.to(target_dtype)

        if isinstance(device_map, str) and device_map == "auto":
            from accelerate import dispatch_model, infer_auto_device_map

            device_map_computed = infer_auto_device_map(model)
            model = dispatch_model(model, device_map=device_map_computed)
        elif isinstance(device_map, dict):
            # The whole model is placed on the first device listed in the map.
            if device_map:
                model = model.to(next(iter(device_map.values())))
        elif device_map is not None:
            # Device given as a string, int index or torch.device.
            model = model.to(device_map)

        model.eval()
        return model

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Compute logits and, when ``labels`` is given, the next-token loss.

        ``attention_mask`` and extra kwargs are accepted but not used. The
        loss is cross-entropy between logits at position ``t`` and labels at
        position ``t + 1``.
        """
        logits = self.transformer(input_ids)
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Pass only ``input_ids`` through; the other generation kwargs are dropped."""
        return {"input_ids": input_ids}

    def _supports_default_dynamic_cache(self):
        return False