mansaripo committed on
Commit
99063a5
·
verified ·
1 Parent(s): 12303b2

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. config.json +1 -0
  2. hf_quant_config.json +15 -0
  3. model.safetensors +2 -2
  4. modeling_cloverlm.py +83 -0
config.json CHANGED
@@ -12,6 +12,7 @@
12
  ]
13
  },
14
  "d_head": 128,
 
15
  "head_dim": 128,
16
  "heads": 28,
17
  "hidden_size": 3584,
 
12
  ]
13
  },
14
  "d_head": 128,
15
+ "dtype": "bfloat16",
16
  "head_dim": 128,
17
  "heads": 28,
18
  "hidden_size": 3584,
hf_quant_config.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "producer": {
3
+ "name": "cloverlm_converter",
4
+ "version": "1.0"
5
+ },
6
+ "quantization": {
7
+ "quant_algo": "NVFP4",
8
+ "kv_cache_quant_algo": null,
9
+ "group_size": 16,
10
+ "exclude_modules": [
11
+ "emb",
12
+ "linear"
13
+ ]
14
+ }
15
+ }
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5802c11b6b024033386dba4cdff8665d48de19850e0e63c31686f44430ca870f
3
- size 16563661264
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e0bbc8b129f5affb348c8526847bcba635c16d9e65700d775a0a941bd2e73533
3
+ size 2659361496
modeling_cloverlm.py CHANGED
@@ -11,6 +11,31 @@ from .configuration_cloverlm import CloverLMConfig
11
  from .fake_quartet import FakeQuartetLinear
12
 
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def _sphere_norm(X, dim=-1):
16
  return F.normalize(X, dim=dim)
@@ -230,6 +255,64 @@ class CloverLMForCausalLM(PreTrainedModel, GenerationMixin):
230
  )
231
  self.post_init()
232
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
  def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
234
  logits = self.transformer(input_ids)
235
 
 
11
  from .fake_quartet import FakeQuartetLinear
12
 
13
 
14
# ── NVFP4 dequantization for checkpoint loading ─────────────────────────────

def _dequant_nvfp4_state_dict(raw_sd, dtype=torch.bfloat16):
    """Expand NVFP4-packed tensors in *raw_sd* into dense tensors of *dtype*.

    quartet2's ``_dq_fp4`` performs the dequantization on GPU; the
    micro-scales are stored in cuBLAS blocked layout, and ``_dq_fp4``
    handles the unblocking correctly.
    """
    from quartet2.linear import _dq_fp4

    # A base weight is NVFP4-packed iff a companion "<name>_scale_2"
    # (per-tensor global scale) entry exists in the checkpoint.
    packed_bases = {
        name.removesuffix("_scale_2")
        for name in raw_sd
        if name.endswith("_scale_2")
    }

    dense = {}
    for name, value in raw_sd.items():
        # Scale tensors are consumed alongside their base weight below;
        # they never appear in the output state dict.
        if name.endswith(("_scale", "_scale_2")):
            continue
        if name not in packed_bases:
            # Unquantized entry: cast floats to the target dtype, pass
            # integer/bool tensors through untouched.
            dense[name] = value.to(dtype) if value.is_floating_point() else value
            continue
        micro_scales = raw_sd[f"{name}_scale"].cuda()
        global_scale = raw_sd[f"{name}_scale_2"].float().item()
        dense[name] = _dq_fp4(value.cuda(), micro_scales, global_scale).to(dtype).cpu()
    return dense
37
+
38
+
39
 
40
  def _sphere_norm(X, dim=-1):
41
  return F.normalize(X, dim=dim)
 
255
  )
256
  self.post_init()
257
 
258
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
    """Load the model, transparently dequantizing an NVFP4 checkpoint.

    If ``pretrained_model_name_or_path`` is a local directory containing a
    ``model.safetensors`` whose tensor names include ``*_scale_2`` entries
    (the NVFP4 marker), the state dict is rebuilt densely via
    ``_dequant_nvfp4_state_dict`` and loaded manually; otherwise the call
    falls through to the stock ``from_pretrained``.

    Recognized kwargs (popped before use): ``config``, ``torch_dtype``
    (defaults to bfloat16), ``device_map``, ``trust_remote_code``, plus any
    kwarg matching an existing config attribute (applied as an override).
    """
    import os
    from safetensors import safe_open

    # Only a local directory checkpoint is probed here; anything else
    # (e.g. a hub repo id) takes the parent implementation.
    st_path = os.path.join(str(pretrained_model_name_or_path), "model.safetensors")
    if not os.path.exists(st_path):
        return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

    # Peek at tensor names without loading any data: "_scale_2" entries
    # mark NVFP4 packing. A plain checkpoint uses the normal path.
    with safe_open(st_path, framework="pt") as f:
        if not any(k.endswith("_scale_2") for k in f.keys()):
            return super().from_pretrained(
                pretrained_model_name_or_path, *args, **kwargs,
            )

    from safetensors.torch import load_file

    # NOTE: pop order matters below — `config` must be popped before the
    # override loop so it is not treated as a config attribute override.
    config = kwargs.pop("config", None)
    if config is None:
        config = cls.config_class.from_pretrained(
            pretrained_model_name_or_path, trust_remote_code=True,
        )

    # Apply config overrides from kwargs (e.g. attn_backend, quartet_2_impl)
    for key in list(kwargs.keys()):
        if hasattr(config, key):
            setattr(config, key, kwargs.pop(key))
    kwargs.pop("trust_remote_code", None)  # consumed; not forwarded anywhere

    # Resolve the target dtype; accepts either a torch.dtype or its string
    # name (e.g. "bfloat16"), defaulting to bfloat16.
    target_dtype = kwargs.pop("torch_dtype", None)
    if target_dtype is None:
        target_dtype = torch.bfloat16
    if isinstance(target_dtype, str):
        target_dtype = getattr(torch, target_dtype)

    device_map = kwargs.pop("device_map", None)

    # Materialize the full (dequantized) state dict in host memory, then
    # load it into a freshly constructed model. strict=False tolerates the
    # quantization-only keys that were dropped during dequantization.
    raw = load_file(st_path)
    state_dict = _dequant_nvfp4_state_dict(raw, target_dtype)

    model = cls(config)
    model.load_state_dict(state_dict, strict=False)
    model = model.to(target_dtype)

    # Best-effort device placement. NOTE(review): a torch.device or int
    # device_map silently falls through all three branches — confirm callers
    # only pass str / dict / "auto".
    if device_map is not None:
        if isinstance(device_map, str) and device_map != "auto":
            model = model.to(device_map)
        elif isinstance(device_map, dict):
            # Whole model goes to the first device in the map; per-module
            # sharding from an explicit dict is not honored here.
            device = next(iter(device_map.values()))
            model = model.to(device)
        elif device_map == "auto":
            from accelerate import dispatch_model, infer_auto_device_map
            device_map_computed = infer_auto_device_map(model)
            model = dispatch_model(model, device_map=device_map_computed)

    model.eval()
    return model
315
+
316
  def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
317
  logits = self.transformer(input_ids)
318