fm1320 commited on
Commit
3b2c714
·
verified ·
1 Parent(s): 881e8e8

Rewrite README for the renamed FlashNorm (weightless) repo

Browse files
Files changed (1) hide show
  1. README.md +23 -33
README.md CHANGED
@@ -9,45 +9,39 @@ tags:
9
  pipeline_tag: text-generation
10
  ---
11
 
12
- # SmolLM2-135M-FlashNorm-strict
13
 
14
- **Weightless (strict-mode) FlashNorm checkpoint** of [HuggingFaceTB/SmolLM2-135M](https://huggingface.co/HuggingFaceTB/SmolLM2-135M).
15
 
16
- Mathematically equivalent to the source model. The per-channel normalization weight tensors (`input_layernorm.weight`, `post_attention_layernorm.weight`, `model.norm.weight`) have been folded into the following linear layers and then removed from the state dict entirely.
 
 
17
 
18
- > **This checkpoint does NOT load in stock vLLM today.** vLLM's weight loader raises a `ValueError` because the norm weight tensors are absent. Issue tracking the loader patch: TBD. Use [open-machine/SmolLM2-135M-FlashNorm](https://huggingface.co/open-machine/SmolLM2-135M-FlashNorm) (the compat variant) for a drop-in checkpoint that loads in stock vLLM today.
19
-
20
- This repo exists as a concrete test vector for the upstream patch that would let vLLM accept weightless RMSNorm models.
21
-
22
- ## What is FlashNorm (weightless)?
23
 
24
  An exact reformulation of `RMSNorm -> Linear`:
25
 
26
- - **Fold** the per-channel normalization weight `g` into the following linear layer: `W_star = W @ diag(g)`.
27
- - After folding, the RMSNorm layer has no learnable per-channel scale. It just divides by `rms(x)`.
28
  - The resulting model computes the same output as the original, by Proposition 1 of the FlashNorm paper.
29
 
30
- This repo is a "weightless" variant: the `g` tensor itself is absent from the safetensors, because after the fold the runtime value of `g` is always all-ones (the multiplicative identity). Deleting the tensor saves a small amount of disk space and makes explicit that the runtime never needs to multiply by `g`.
31
-
32
  See the [paper](https://github.com/OpenMachine-ai/transformer-tricks/blob/main/tex/flashNorm.tex) (Section 3.1 and Proposition 1) and the [transformer-tricks](https://github.com/OpenMachine-ai/transformer-tricks) repo for details.
33
 
34
  ## What's different from the source checkpoint
35
 
36
- | Tensor | Source | Compat variant | This (strict) |
37
- |---|---|---|---|
38
- | `model.layers.*.input_layernorm.weight` | learned per-channel `g` | all ones | **absent** |
39
- | `model.layers.*.self_attn.{q,k,v}_proj.weight` | `W` | `W @ diag(g_input_layernorm)` | `W @ diag(g_input_layernorm)` |
40
- | `model.layers.*.post_attention_layernorm.weight` | learned per-channel `g` | all ones | **absent** |
41
- | `model.layers.*.mlp.{gate,up}_proj.weight` | `W` | `W @ diag(g_post_attention_layernorm)` | `W @ diag(g_post_attention_layernorm)` |
42
- | `model.norm.weight` | learned per-channel `g` | all ones | **absent** |
43
 
44
- All dtype conventions match the source (`bfloat16`). Mathematical identity to the source model holds by construction.
45
 
46
  ## Usage
47
 
48
- ### Via `transformer_tricks`
49
-
50
- The `transformer_tricks` package can regenerate this checkpoint locally from the source:
51
 
52
  ```python
53
  import transformer_tricks as tt
@@ -56,30 +50,26 @@ tt.flashify_repo('HuggingFaceTB/SmolLM2-135M', strict=True)
56
 
57
  ### Via HuggingFace Transformers
58
 
59
- HuggingFace Transformers will load this checkpoint with a warning that norm weights were not initialized from the checkpoint, and will default them to the module's init value (ones for `LlamaRMSNorm`). Under this path, the output is correct.
60
-
61
  ```python
62
  from transformers import AutoModelForCausalLM, AutoTokenizer
63
 
64
- tok = AutoTokenizer.from_pretrained('open-machine/SmolLM2-135M-FlashNorm-strict')
65
- model = AutoModelForCausalLM.from_pretrained('open-machine/SmolLM2-135M-FlashNorm-strict')
66
 
67
  ids = tok('Once upon a time there was', return_tensors='pt').input_ids
68
  out = model.generate(ids, max_new_tokens=50, do_sample=False)
69
  print(tok.decode(out[0], skip_special_tokens=True))
70
  ```
71
 
72
- ### Via vLLM
73
-
74
- **Not yet supported.** vLLM's weight loader validates that all declared `nn.Parameter` tensors are present in the safetensors and raises `ValueError` when norm weights are absent.
75
 
76
- Tracking issue for upstream patch: TBD (to be linked once filed).
77
 
78
- Until the patch lands, use [open-machine/SmolLM2-135M-FlashNorm](https://huggingface.co/open-machine/SmolLM2-135M-FlashNorm) (compat variant) which keeps the norm tensors as all-ones and loads in stock vLLM unchanged.
79
 
80
  ## Verification
81
 
82
- Generated from the compat variant by deleting the 61 norm weight tensors (30 layers x 2 norms each + 1 final `model.norm`). All other tensors are byte-identical to the compat checkpoint; inference outputs are therefore identical when the loader defaults absent norm weights to ones.
83
 
84
  ## License
85
 
 
9
  pipeline_tag: text-generation
10
  ---
11
 
12
+ # SmolLM2-135M-FlashNorm
13
 
14
+ FlashNorm-prepared checkpoint of [HuggingFaceTB/SmolLM2-135M](https://huggingface.co/HuggingFaceTB/SmolLM2-135M). Mathematically equivalent to the source model. The per-channel RMSNorm weight tensors (`input_layernorm.weight`, `post_attention_layernorm.weight`, `model.norm.weight`) are folded into the following linear layers and then removed from the state dict entirely.
15
 
16
+ > **Framework support note.** Stock vLLM currently does not load this checkpoint because the norm weight tensors are absent. The upstream patch to accept missing tensors is tracked at: **TBD (vLLM issue link)**. Until the patch lands, use HuggingFace Transformers; it loads this with a warning that norm weights were not initialized and defaults them to ones, which is the correct behavior for FlashNorm.
17
+ >
18
+ > The other two public FlashNorm checkpoints in this org, [Llama-3.2-1B-FlashNorm](https://huggingface.co/open-machine/Llama-3.2-1B-FlashNorm) and [Llama-3.1-8B-FlashNorm](https://huggingface.co/open-machine/Llama-3.1-8B-FlashNorm), are currently still in a compatibility layout where the norm tensors are retained as all-ones. They will be flipped to the same weightless layout as this checkpoint once vLLM's loader supports it.
19
 
20
+ ## What FlashNorm does
 
 
 
 
21
 
22
  An exact reformulation of `RMSNorm -> Linear`:
23
 
24
+ - Fold the per-channel normalization weight `g` into the following linear layer: `W_star = W @ diag(g)`, computed once at checkpoint conversion.
25
+ - After folding, the RMSNorm layer has no learnable per-channel scale. At runtime it simply divides by `rms(x)`.
26
  - The resulting model computes the same output as the original, by Proposition 1 of the FlashNorm paper.
27
 
 
 
28
  See the [paper](https://github.com/OpenMachine-ai/transformer-tricks/blob/main/tex/flashNorm.tex) (Section 3.1 and Proposition 1) and the [transformer-tricks](https://github.com/OpenMachine-ai/transformer-tricks) repo for details.
29
 
30
  ## What's different from the source checkpoint
31
 
32
+ | Tensor | Source | This FlashNorm checkpoint |
33
+ |---|---|---|
34
+ | `model.layers.*.input_layernorm.weight` | learned per-channel `g` | **absent** |
35
+ | `model.layers.*.self_attn.{q,k,v}_proj.weight` | `W` | `W @ diag(g_input_layernorm)` |
36
+ | `model.layers.*.post_attention_layernorm.weight` | learned per-channel `g` | **absent** |
37
+ | `model.layers.*.mlp.{gate,up}_proj.weight` | `W` | `W @ diag(g_post_attention_layernorm)` |
38
+ | `model.norm.weight` | learned per-channel `g` | **absent** |
39
 
40
+ All dtype conventions match the source (`bfloat16`). Mathematical identity to the source holds by construction.
41
 
42
  ## Usage
43
 
44
+ ### Regenerate locally with `transformer_tricks`
 
 
45
 
46
  ```python
47
  import transformer_tricks as tt
 
50
 
51
  ### Via HuggingFace Transformers
52
 
 
 
53
  ```python
54
  from transformers import AutoModelForCausalLM, AutoTokenizer
55
 
56
+ tok = AutoTokenizer.from_pretrained('open-machine/SmolLM2-135M-FlashNorm')
57
+ model = AutoModelForCausalLM.from_pretrained('open-machine/SmolLM2-135M-FlashNorm')
58
 
59
  ids = tok('Once upon a time there was', return_tensors='pt').input_ids
60
  out = model.generate(ids, max_new_tokens=50, do_sample=False)
61
  print(tok.decode(out[0], skip_special_tokens=True))
62
  ```
63
 
64
+ A warning about missing norm weights is expected; Transformers defaults those to ones, which is the correct value for a FlashNorm checkpoint.
 
 
65
 
66
+ ### Via vLLM
67
 
68
+ Not yet supported. See the tracking issue linked above.
69
 
70
  ## Verification
71
 
72
+ Under fp32 inference, greedy generation from this checkpoint is bit-identical to the source SmolLM2-135M model. Under fp16 inference the output is within benchmark noise (see the Quality table in Section 5 of the paper).
73
 
74
  ## License
75