Ram07 committed
Commit d0f40b0 · verified · 1 Parent(s): 160bf89

Upload folder using huggingface_hub

README.md ADDED
@@ -0,0 +1,161 @@
1
+ ---
2
+ license: mit
3
+ language:
4
+ - en
5
+ pipeline_tag: text-generation
6
+ tags:
7
+ - bitnet
8
+ - quantization
9
+ - early-exit
10
+ - layer-skipping
11
+ - efficient-transformers
12
+ datasets:
13
+ - roneneldan/TinyStories
14
+ ---
15
+
16
+ # bitskip-v2-earlyexit
17
+
18
+ BitSkip v2 with 4-bit activation quantization, ternary weights, and a Hadamard transform.
19
+
20
+ ## Model Description
21
+
22
+ This model implements a 24-layer transformer with early exit loss and quadratic layer dropout for efficient inference. It was trained on the TinyStories dataset with layer-wise auxiliary supervision to enable flexible speed-quality tradeoffs during inference.
23
+
24
+ ## Architecture Details
25
+
26
+ - **Layers**: 24
27
+ - **Hidden dimension**: 2048
28
+ - **Attention heads**: 32 (64-dimensional each)
29
+ - **Key-Value heads**: 8 (Grouped Query Attention with 4:1 ratio)
30
+ - **FFN intermediate size**: 4096
31
+ - **Position embeddings**: Rotary Position Embeddings (RoPE)
32
+ - **Normalization**: RMSNorm
33
+ - **Activation**: SwiGLU (for MLP)
34
+ - **Parameters**: ~1.06B
35
+
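+ The values above mirror the fields in this repo's `config.json`. A quick way to confirm them (`trust_remote_code=True` is required because the config and model classes ship with the repo; see `auto_map` in `config.json`):
+
+ ```python
+ from transformers import AutoConfig
+
+ config = AutoConfig.from_pretrained("your-username/bitskip-v2-earlyexit", trust_remote_code=True)
+ print(config.num_hidden_layers)        # 24
+ print(config.hidden_size)              # 2048
+ print(config.num_attention_heads)      # 32
+ print(config.num_key_value_heads)      # 8
+ print(config.intermediate_size)        # 4096
+ print(config.max_position_embeddings)  # 2048
+ ```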
36
+ ### Quantization Scheme
37
+
38
+ - **Weights**: Ternary {-1, 0, 1}
39
+ - **Activations**: 4-bit quantization (post-Hadamard)
40
+ - **Hadamard transform**: Fast Walsh-Hadamard Transform (FWHT), applied to activations before quantization and inverted after the linear projection
41
+
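+ The scheme is implemented in `models/h_bitlinear.py`. Stripped of the straight-through estimator used during training, its core reduces to:
+
+ ```python
+ import torch
+
+ def ternary_weights(w: torch.Tensor) -> torch.Tensor:
+     """Ternary {-1, 0, +1} weights with a single absmean scale."""
+     scale = w.abs().mean().clamp(min=1e-5)
+     q = torch.zeros_like(w)
+     q[w > 0.5 * scale] = 1.0
+     q[w < -0.5 * scale] = -1.0
+     return q * scale
+
+ def quantize_activations_4bit(x: torch.Tensor) -> torch.Tensor:
+     """Per-token absmax 4-bit quantization, applied after the Hadamard transform."""
+     scale = x.abs().max(dim=-1, keepdim=True)[0].clamp(min=1e-5)
+     q = (x / scale * 7).round().clamp(-8, 7)  # 4-bit integer range [-8, 7]
+     return q / 7 * scale                      # dequantize: fake quantization (QAT), no low-bit kernels
+ ```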
42
+ ## Training Details
43
+
44
+ ### Dataset
45
+ - **Source**: TinyStories (2.1M stories)
46
+ - **Tokenizer**: GPT-2 BPE (vocab size: 50,257)
47
+ - **Sequence length**: 512 tokens
48
+
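+ A minimal preprocessing sketch consistent with the setup above (assuming the dataset's `text` field and the standard GPT-2 tokenizer; the actual training script is not part of this upload):
+
+ ```python
+ from datasets import load_dataset
+ from transformers import AutoTokenizer
+
+ dataset = load_dataset("roneneldan/TinyStories", split="train")
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")  # BPE, vocab size 50,257
+
+ def encode(batch):
+     return tokenizer(batch["text"], truncation=True, max_length=512)
+
+ dataset = dataset.map(encode, batched=True, remove_columns=dataset.column_names)
+ ```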
49
+ ### Training Techniques
50
+
51
+ **Quadratic Layer Dropout:**
52
+ - Progressive dropout: p_l = 0.5 × (l/(L-1))²
53
+ - Normalized so Σp_l = 1.0
54
+ - The final layer is never dropped
55
+ - Later layers are dropped more often, making earlier-layer representations more robust for early exit (see the sketch below)
56
+
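+ The schedule used by `QuadraticLayerDropout` in `models/model_v2_earlyexit.py` condenses to:
+
+ ```python
+ num_layers, max_p = 24, 0.5
+ probs = [max_p * (i / (num_layers - 1)) ** 2 for i in range(num_layers)]
+ total = sum(probs)
+ probs = [p / total for p in probs]  # normalize so the probabilities sum to 1
+ # During training, layer i is skipped with probability probs[i];
+ # the final layer (i == num_layers - 1) is never dropped.
+ ```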
57
+ **Early Exit Loss:**
58
+ - All layers share the same LM head
59
+ - Loss = main_loss + 0.3 × early_exit_loss
60
+ - Layer-proportional weighting: w_i = (i+1)/L
61
+ - Enables flexible early exit at inference
62
+
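+ A condensed view of `compute_early_exit_loss` in `models/model_v2_earlyexit.py` (variable names here are illustrative):
+
+ ```python
+ import torch.nn.functional as F
+
+ def early_exit_loss(all_layer_hidden_states, lm_head, labels):
+     n = len(all_layer_hidden_states)
+     weights = [(i + 1) / n for i in range(n)]   # layer-proportional: deeper layers weigh more
+     weight_sum = sum(weights)
+     weights = [w / weight_sum for w in weights]
+
+     total = 0.0
+     for w, hidden in zip(weights, all_layer_hidden_states):
+         logits = lm_head(hidden)                # the LM head is shared across all layers
+         shift_logits = logits[..., :-1, :].contiguous()
+         shift_labels = labels[..., 1:].contiguous()
+         total = total + w * F.cross_entropy(
+             shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
+         )
+     return total
+
+ # Total training loss: loss = main_loss + 0.3 * early_exit_loss(...)
+ ```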
63
+ ### Hyperparameters
64
+
65
+ - **Optimizer**: AdamW
66
+ - **Learning rate**: 3e-4
67
+ - **Warmup steps**: 4000
68
+ - **Batch size**: 16 (effective: 64)
69
+ - **Training steps**: 50000
70
+ - **Gradient clipping**: 0.5
71
+
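+ The training script itself is not included in this upload. A minimal loop consistent with the table above (`model` and `dataloader` are assumed to exist; the linear warmup shape and the 4× gradient accumulation implied by 16 → 64 are assumptions):
+
+ ```python
+ import torch
+
+ optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
+ scheduler = torch.optim.lr_scheduler.LambdaLR(
+     optimizer, lambda step: min(1.0, (step + 1) / 4000)  # assumed linear warmup over 4000 steps
+ )
+
+ accum = 4  # assumed: batch 16 × 4 accumulation = effective batch 64
+ for step, batch in enumerate(dataloader):
+     loss = model(**batch, labels=batch["input_ids"]).loss / accum
+     loss.backward()
+     if (step + 1) % accum == 0:
+         torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+         optimizer.step()
+         scheduler.step()
+         optimizer.zero_grad()
+ ```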
72
+ ## Performance
73
+
74
+ ### Perplexity (TinyStories validation)
75
+
76
+ | Exit Layer | Perplexity | Speed (tok/s) |
77
+ |------------|------------|---------------|
78
+ | All layers | TBD | TBD |
79
+ | Layer 18 | TBD | TBD |
80
+ | Layer 12 | TBD | TBD |
81
+ | Layer 6 | TBD | TBD |
82
+
83
+ ### Training Stability
84
+
85
+ - **Gradient norms**: 50-110
86
+ - **Final loss**: TBD
87
+
88
+ ## Usage
89
+
90
+ ### Installation
91
+
92
+ ```bash
93
+ pip install transformers torch
94
+ ```
95
+
96
+ ### Basic Inference
97
+
98
+ ```python
99
+ from transformers import AutoTokenizer, AutoModelForCausalLM
100
+
101
+ # Load model (trust_remote_code is required because the model code ships with the repo)
102
+ model = AutoModelForCausalLM.from_pretrained("your-username/bitskip-v2-earlyexit", trust_remote_code=True)
103
+ tokenizer = AutoTokenizer.from_pretrained("your-username/bitskip-v2-earlyexit")
104
+
105
+ # Generate text
106
+ inputs = tokenizer("Once upon a time", return_tensors="pt")
107
+ outputs = model.generate(**inputs, max_length=100)
108
+ print(tokenizer.decode(outputs[0]))
109
+ ```
110
+
111
+ ### Early Exit Inference
112
+
113
+ ```python
114
+ # Exit at layer 12 for faster inference
115
+ model.set_exit_layer(12)
116
+ outputs = model.generate(**inputs, max_length=100)
117
+ # 1.5-2x faster with minimal quality loss
118
+ ```
119
+
120
+ ### Benchmark Different Exit Layers
121
+
122
+ ```python
123
+ for exit_layer in [6, 12, 18, 24]:
124
+     model.set_exit_layer(exit_layer)
125
+     outputs = model.generate(**inputs, max_length=100)
126
+     print(f"Layer {exit_layer}: {tokenizer.decode(outputs[0])}")
127
+ ```
128
+
129
+ ## Limitations
130
+
131
+ - **Inference speed**: Quantized models use fake quantization (QAT) without specialized kernels, resulting in slower inference than full-precision despite lower bit-width
132
+ - **Training instability**: 4-bit models (v2) exhibit gradient explosion (norms 50-110) requiring careful hyperparameter tuning
133
+ - **Dataset scope**: Trained only on TinyStories; may not generalize to other domains without fine-tuning
134
+
135
+ ## Citation
136
+
137
+ If you use this model, please cite:
138
+
139
+ ```bibtex
140
+ @article{bitnet,
141
+ title={BitNet: Scaling 1-bit Transformers for Large Language Models},
142
+ author={Wang, Hongyu and Ma, Shuming and Dong, Li and others},
143
+ journal={arXiv preprint arXiv:2310.11453},
144
+ year={2023}
145
+ }
146
+
147
+ @article{layerskip,
148
+ title={LayerSkip: Enabling Early Exit Inference and Self-Speculative Decoding},
149
+ author={Elhoushi, Mostafa and Shrivastava, Akshat and Liskovich, Diana and others},
150
+ journal={arXiv preprint arXiv:2404.16710},
151
+ year={2024}
152
+ }
153
+ ```
154
+
155
+ ## License
156
+
157
+ MIT License
158
+
159
+ ## Contact
160
+
161
+ For questions or issues, please open an issue on the model repository.
config.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "architectures": [
3
+ "BitSkipV2ForCausalLMWithEarlyExit"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "model_v2_earlyexit.BitSkipV2EarlyExitConfig",
7
+ "AutoModelForCausalLM": "model_v2_earlyexit.BitSkipV2ForCausalLMWithEarlyExit"
8
+ },
9
+ "early_exit_loss_weight": 0.3,
10
+ "hidden_size": 2048,
11
+ "inference_exit_layer": null,
12
+ "intermediate_size": 4096,
13
+ "max_dropout_prob": 0.5,
14
+ "max_position_embeddings": 2048,
15
+ "model_type": "bitskip_v2_earlyexit",
16
+ "num_attention_heads": 32,
17
+ "num_hidden_layers": 24,
18
+ "num_key_value_heads": 8,
19
+ "rms_norm_eps": 1e-05,
20
+ "rope_theta": 10000.0,
21
+ "torch_dtype": "float32",
22
+ "transformers_version": "4.45.2",
23
+ "vocab_size": 50257
24
+ }
generation_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.45.2"
4
+ }
inference.py ADDED
@@ -0,0 +1,37 @@
1
+ """
2
+ Inference script for bitskip-v2-earlyexit
3
+ """
4
+
5
+ import torch
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM
7
+
8
+ def main():
9
+ # Load from HuggingFace Hub or local path
10
+ model_path = "." # Current directory or specify repo_id
11
+
12
+ print("Loading model...")
13
+ model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
14
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
15
+
16
+ model.eval()
17
+ print("Model loaded!")
18
+
19
+ # Example generation
20
+ prompt = "Once upon a time"
21
+ inputs = tokenizer(prompt, return_tensors="pt")
22
+
23
+ print(f"\nPrompt: {prompt}\n")
24
+
25
+ # Full model
26
+ print("Generating with all layers...")
27
+ outputs = model.generate(**inputs, max_length=100, pad_token_id=tokenizer.eos_token_id)
28
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
29
+
30
+ # Early exit at layer 12
31
+ print("\nGenerating with early exit at layer 12...")
32
+ model.set_exit_layer(12)
33
+ outputs = model.generate(**inputs, max_length=100, pad_token_id=tokenizer.eos_token_id)
34
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
35
+
36
+ if __name__ == "__main__":
37
+ main()
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:780f971cbb9a8460636d3de74c7620d7371f7e6895439face3eab2bfd887ebfc
3
+ size 3837873528
models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Model files for bitskip-v2-earlyexit"""
models/h_bitlinear.py ADDED
@@ -0,0 +1,116 @@
1
+ """
2
+ H-BitLinear layer for BitSkip v2 (4-bit activations WITH Hadamard transform)
3
+ OPTIMIZED: Fast Hadamard transform implementation
4
+ """
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+
12
+ def hadamard_transform(x):
13
+ """
14
+ Fast Walsh-Hadamard Transform (FWHT) - OPTIMIZED VERSION.
15
+
16
+ This vectorized implementation is MUCH faster than the loop version.
17
+ Uses divide-and-conquer butterfly pattern for O(n log n) complexity.
18
+ """
19
+ orig_shape = x.shape
20
+ n = x.shape[-1]
21
+
22
+ # Ensure dimension is power of 2
23
+ assert n & (n - 1) == 0, f"Dimension must be power of 2, got {n}"
24
+
25
+ # Flatten to 2D for transform
26
+ x = x.reshape(-1, n)
27
+
28
+ # Fast Hadamard transform using butterfly pattern
29
+ h = 1
30
+ while h < n:
31
+ # Vectorized butterfly operations (MUCH faster than loops!)
32
+ x = x.reshape(-1, n // (2 * h), 2, h)
33
+ x_even = x[:, :, 0, :] # Even-indexed blocks of length h
34
+ x_odd = x[:, :, 1, :] # Odd-indexed blocks of length h
35
+
36
+ # Butterfly: (a, b) -> (a+b, a-b)
37
+ # Build the result as a new tensor: assigning into x in place overwrites x_even
+ # (a view of x) with a+b before the second write, which would yield (a+b, a)
38
+ x = torch.stack((x_even + x_odd, x_even - x_odd), dim=2)
39
+
40
+ x = x.reshape(-1, n)
41
+ h *= 2
42
+
43
+ # Normalize
44
+ x = x / math.sqrt(n)
45
+
46
+ # Reshape back
47
+ return x.reshape(orig_shape)
48
+
49
+
50
+ class HBitLinear(nn.Module):
51
+ """
52
+ H-BitLinear: Hadamard transform + Ternary weights + 4-bit activations.
53
+
54
+ Flow:
55
+ 1. LayerNorm
56
+ 2. Hadamard transform (key preprocessing step!)
57
+ 3. 4-bit quantization
58
+ 4. Linear operation with ternary weights
59
+ 5. Inverse Hadamard transform
60
+ """
61
+
62
+ def __init__(self, in_features, out_features, bias=False):
63
+ super().__init__()
64
+
65
+ # Ensure power of 2 for Hadamard
66
+ assert in_features & (in_features - 1) == 0, \
67
+ f"in_features must be power of 2 for Hadamard, got {in_features}"
68
+ assert out_features & (out_features - 1) == 0, \
69
+ f"out_features must be power of 2 for Hadamard, got {out_features}"
70
+
71
+ self.in_features = in_features
72
+ self.out_features = out_features
73
+
74
+ # Weight and bias
75
+ self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
76
+ self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
77
+
78
+ # LayerNorm before Hadamard
79
+ self.norm = nn.LayerNorm(in_features)
80
+
81
+ def forward(self, x):
82
+ """
83
+ Forward with Hadamard preprocessing + 4-bit quantization.
84
+ """
85
+ # 1. LayerNorm
86
+ x = self.norm(x)
87
+
88
+ # 2. Hadamard transform (KEY STEP for v2!)
89
+ x_hadamard = hadamard_transform(x)
90
+
91
+ # 3. 4-bit quantization (works better after Hadamard)
92
+ x_scale = x_hadamard.abs().max(dim=-1, keepdim=True)[0].clamp(min=1e-5)
93
+ x_quant = (x_hadamard / x_scale * 7).round().clamp(-8, 7) # 4-bit: -8 to 7
94
+ x_quant = x_quant / 7 * x_scale
95
+
96
+ # STE for gradients
97
+ if self.training:
98
+ x_quant = x_hadamard + (x_quant - x_hadamard).detach()
99
+
100
+ # 4. Ternary weight quantization (same as v1)
101
+ w_scale = self.weight.abs().mean().clamp(min=1e-5)
102
+ w_quant = torch.zeros_like(self.weight)
103
+ w_quant[self.weight > 0.5 * w_scale] = 1.0
104
+ w_quant[self.weight < -0.5 * w_scale] = -1.0
105
+ w_quant = w_quant * w_scale
106
+
107
+ if self.training:
108
+ w_quant = self.weight + (w_quant - self.weight).detach()
109
+
110
+ # 5. Linear operation
111
+ output = F.linear(x_quant, w_quant, self.bias)
112
+
113
+ # 6. Inverse Hadamard transform
114
+ output = hadamard_transform(output)
115
+
116
+ return output
models/model_v2_earlyexit.py ADDED
@@ -0,0 +1,413 @@
1
+ """
2
+ BitSkip v2 with Early Exit Loss and Quadratic Dropout
3
+ - H-BitLinear quantization (4-bit + Hadamard)
4
+ - Quadratic layer dropout (normalized sum=1)
5
+ - Early exit loss from all layers
6
+ - HuggingFace compatible
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import math
13
+ from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
14
+ from transformers.modeling_outputs import CausalLMOutputWithPast
15
+ from typing import Optional, Tuple
16
+
17
+ from .h_bitlinear import HBitLinear
18
+
19
+
20
+ class BitSkipV2EarlyExitConfig(PretrainedConfig):
21
+ model_type = "bitskip_v2_earlyexit"
22
+
23
+ def __init__(
24
+ self,
25
+ vocab_size=50257,
26
+ hidden_size=2048,
27
+ num_hidden_layers=24,
28
+ num_attention_heads=32,
29
+ num_key_value_heads=8,
30
+ intermediate_size=4096,
31
+ max_position_embeddings=2048,
32
+ rms_norm_eps=1e-5,
33
+ rope_theta=10000.0,
34
+ early_exit_loss_weight=0.3,
35
+ max_dropout_prob=0.5,
36
+ inference_exit_layer=None,
37
+ **kwargs
38
+ ):
39
+ self.vocab_size = vocab_size
40
+ self.hidden_size = hidden_size
41
+ self.num_hidden_layers = num_hidden_layers
42
+ self.num_attention_heads = num_attention_heads
43
+ self.num_key_value_heads = num_key_value_heads
44
+ self.intermediate_size = intermediate_size
45
+ self.max_position_embeddings = max_position_embeddings
46
+ self.rms_norm_eps = rms_norm_eps
47
+ self.rope_theta = rope_theta
48
+ self.early_exit_loss_weight = early_exit_loss_weight
49
+ self.max_dropout_prob = max_dropout_prob
50
+ self.inference_exit_layer = inference_exit_layer
51
+ super().__init__(**kwargs)
52
+
53
+
54
+ class QuadraticLayerDropout(nn.Module):
55
+ """Quadratic layer dropout normalized to sum=1."""
56
+
57
+ def __init__(self, num_layers, max_dropout_prob=0.5):
58
+ super().__init__()
59
+ self.num_layers = num_layers
60
+
61
+ dropout_probs = []
62
+ for i in range(num_layers):
63
+ prob = max_dropout_prob * ((i / max(num_layers - 1, 1)) ** 2)
64
+ dropout_probs.append(prob)
65
+
66
+ total_prob = sum(dropout_probs)
67
+ if total_prob > 0:
68
+ dropout_probs = [p / total_prob for p in dropout_probs]
69
+
70
+ self.dropout_probs = dropout_probs
71
+
72
+ def should_drop_layer(self, layer_idx):
73
+ if not self.training or layer_idx >= self.num_layers - 1:
74
+ return False
75
+ return torch.rand(1).item() < self.dropout_probs[layer_idx]
76
+
77
+
78
+ class RMSNorm(nn.Module):
79
+ def __init__(self, hidden_size, eps=1e-6):
80
+ super().__init__()
81
+ self.weight = nn.Parameter(torch.ones(hidden_size))
82
+ self.variance_epsilon = eps
83
+
84
+ def forward(self, hidden_states):
85
+ input_dtype = hidden_states.dtype
86
+ hidden_states = hidden_states.to(torch.float32)
87
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
88
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
89
+ return self.weight * hidden_states.to(input_dtype)
90
+
91
+
92
+ class RotaryEmbedding(nn.Module):
93
+ def __init__(self, dim, max_position_embeddings=2048, base=10000):
94
+ super().__init__()
95
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
96
+ self.register_buffer("inv_freq", inv_freq)
97
+
98
+ def forward(self, x, position_ids):
99
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
100
+ position_ids_expanded = position_ids[:, None, :].float()
101
+ freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
102
+ emb = torch.cat((freqs, freqs), dim=-1)
103
+ return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
104
+
105
+
106
+ def rotate_half(x):
107
+ x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
108
+ return torch.cat((-x2, x1), dim=-1)
109
+
110
+
111
+ def apply_rotary_pos_emb(q, k, cos, sin):
112
+ # cos/sin are (batch, seq_len, head_dim); add a head axis so they broadcast
+ # against (batch, num_heads, seq_len, head_dim) for any batch size
+ cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
113
+ k_embed = (k * cos) + (rotate_half(k) * sin)
114
+ return q_embed, k_embed
115
+
116
+
117
+ class BitSkipV2Attention(nn.Module):
118
+ def __init__(self, config):
119
+ super().__init__()
120
+ self.hidden_size = config.hidden_size
121
+ self.num_heads = config.num_attention_heads
122
+ self.head_dim = self.hidden_size // self.num_heads
123
+ self.num_key_value_heads = config.num_key_value_heads
124
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
125
+
126
+ self.q_proj = HBitLinear(self.hidden_size, self.num_heads * self.head_dim)
127
+ self.k_proj = HBitLinear(self.hidden_size, self.num_key_value_heads * self.head_dim)
128
+ self.v_proj = HBitLinear(self.hidden_size, self.num_key_value_heads * self.head_dim)
129
+ self.o_proj = HBitLinear(self.hidden_size, self.hidden_size)
130
+
131
+ self.rotary_emb = RotaryEmbedding(self.head_dim, config.max_position_embeddings, config.rope_theta)
132
+
133
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
134
+ bsz, q_len, _ = hidden_states.size()
135
+
136
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
137
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
138
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
139
+
140
+ cos, sin = self.rotary_emb(value_states, position_ids)
141
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
142
+
143
+ if past_key_value is not None:
144
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
145
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
146
+
147
+ past_key_value = (key_states, value_states) if use_cache else None
148
+
149
+ key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
150
+ value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
151
+
152
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
153
+ if attention_mask is not None:
154
+ attn_weights = attn_weights + attention_mask
155
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
156
+ attn_output = torch.matmul(attn_weights, value_states)
157
+ attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
158
+ attn_output = self.o_proj(attn_output)
159
+
160
+ return attn_output, None, past_key_value
161
+
162
+
163
+ class BitSkipV2MLP(nn.Module):
164
+ def __init__(self, config):
165
+ super().__init__()
166
+ self.gate_proj = HBitLinear(config.hidden_size, config.intermediate_size)
167
+ self.up_proj = HBitLinear(config.hidden_size, config.intermediate_size)
168
+ self.down_proj = HBitLinear(config.intermediate_size, config.hidden_size)
169
+
170
+ def forward(self, x):
171
+ return self.down_proj(nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))
172
+
173
+
174
+ class BitSkipV2DecoderLayer(nn.Module):
175
+ def __init__(self, config):
176
+ super().__init__()
177
+ self.self_attn = BitSkipV2Attention(config)
178
+ self.mlp = BitSkipV2MLP(config)
179
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
180
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
181
+
182
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
183
+ residual = hidden_states
184
+ hidden_states = self.input_layernorm(hidden_states)
185
+ hidden_states, _, present_key_value = self.self_attn(
186
+ hidden_states, attention_mask, position_ids, past_key_value, use_cache
187
+ )
188
+ hidden_states = residual + hidden_states
189
+
190
+ residual = hidden_states
191
+ hidden_states = self.post_attention_layernorm(hidden_states)
192
+ hidden_states = self.mlp(hidden_states)
193
+ hidden_states = residual + hidden_states
194
+
195
+ return (hidden_states,) + ((present_key_value,) if use_cache else ())
196
+
197
+
198
+ class BitSkipV2PreTrainedModel(PreTrainedModel):
199
+ config_class = BitSkipV2EarlyExitConfig
200
+ base_model_prefix = "model"
201
+ supports_gradient_checkpointing = True
202
+
203
+ def _init_weights(self, module):
204
+ if isinstance(module, (nn.Linear, HBitLinear)):
205
+ if hasattr(module, 'weight'):
206
+ module.weight.data.normal_(mean=0.0, std=0.02)
207
+ if hasattr(module, 'bias') and module.bias is not None:
208
+ module.bias.data.zero_()
209
+ elif isinstance(module, nn.Embedding):
210
+ module.weight.data.normal_(mean=0.0, std=0.02)
211
+
212
+
213
+ class BitSkipV2Model(BitSkipV2PreTrainedModel):
214
+ def __init__(self, config):
215
+ super().__init__(config)
216
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
217
+ self.layers = nn.ModuleList([BitSkipV2DecoderLayer(config) for _ in range(config.num_hidden_layers)])
218
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
219
+ self.gradient_checkpointing = False
220
+ self.layer_dropout = QuadraticLayerDropout(config.num_hidden_layers, config.max_dropout_prob)
221
+ self.post_init()
222
+
223
+ def forward(self, input_ids, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, output_hidden_states=False, return_all_layer_outputs=False):
224
+ hidden_states = self.embed_tokens(input_ids)
225
+
226
+ if position_ids is None:
227
+ position_ids = torch.arange(input_ids.shape[1], dtype=torch.long, device=input_ids.device)
228
+ position_ids = position_ids.unsqueeze(0)
229
+
230
+ next_decoder_cache = () if use_cache else None
231
+ all_layer_hidden_states = []
232
+
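+ # Early exit: when inference_exit_layer is set (via set_exit_layer), only the
+ # first N decoder layers are run; the final RMSNorm and LM head still apply.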
233
+ num_layers_to_run = self.config.inference_exit_layer if self.config.inference_exit_layer else len(self.layers)
234
+ num_layers_to_run = min(num_layers_to_run, len(self.layers))
235
+
236
+ for idx in range(num_layers_to_run):
237
+ layer = self.layers[idx]
238
+ past_key_value = past_key_values[idx] if past_key_values else None
239
+
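+ # Quadratic layer dropout: during training this layer may be skipped entirely;
+ # its input is recorded unchanged so the early-exit loss still covers this depth.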
240
+ if self.training and self.layer_dropout.should_drop_layer(idx):
241
+ all_layer_hidden_states.append(hidden_states)
242
+ continue
243
+
244
+ if self.gradient_checkpointing and self.training:
245
+ layer_outputs = self._gradient_checkpointing_func(
246
+ layer.__call__,
247
+ hidden_states,
248
+ attention_mask,
249
+ position_ids,
250
+ past_key_value,
251
+ use_cache,
252
+ )
253
+ else:
254
+ layer_outputs = layer(hidden_states, attention_mask, position_ids, past_key_value, use_cache)
255
+
256
+ hidden_states = layer_outputs[0]
257
+ all_layer_hidden_states.append(hidden_states)
258
+
259
+ if use_cache:
260
+ next_decoder_cache += (layer_outputs[1],)
261
+
262
+ hidden_states = self.norm(hidden_states)
263
+ all_layer_hidden_states.append(hidden_states)
264
+
265
+ if return_all_layer_outputs:
266
+ return hidden_states, next_decoder_cache, all_layer_hidden_states
267
+ else:
268
+ return hidden_states, next_decoder_cache, None
269
+
270
+
271
+ class BitSkipV2ForCausalLMWithEarlyExit(BitSkipV2PreTrainedModel, GenerationMixin):
272
+ _tied_weights_keys = ["lm_head.weight"]
273
+
274
+ def __init__(self, config):
275
+ super().__init__(config)
276
+ self.model = BitSkipV2Model(config)
277
+ self.vocab_size = config.vocab_size
278
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
279
+ self.post_init()
280
+
281
+ def get_input_embeddings(self):
282
+ return self.model.embed_tokens
283
+
284
+ def set_input_embeddings(self, value):
285
+ self.model.embed_tokens = value
286
+
287
+ def get_output_embeddings(self):
288
+ return self.lm_head
289
+
290
+ def set_output_embeddings(self, new_embeddings):
291
+ self.lm_head = new_embeddings
292
+
293
+ def compute_early_exit_loss(self, all_layer_hidden_states, labels):
294
+ """Compute early exit loss with layer-proportional weighting."""
295
+ num_layers = len(all_layer_hidden_states)
296
+
297
+ weights = [(i + 1) / num_layers for i in range(num_layers)]
298
+ weight_sum = sum(weights)
299
+ weights = [w / weight_sum for w in weights]
300
+
301
+ total_exit_loss = 0.0
302
+
303
+ for i, hidden_states in enumerate(all_layer_hidden_states):
304
+ logits = self.lm_head(hidden_states)
305
+ shift_logits = logits[..., :-1, :].contiguous()
306
+ shift_labels = labels[..., 1:].contiguous()
307
+
308
+ loss_fct = nn.CrossEntropyLoss()
309
+ layer_loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))
310
+
311
+ total_exit_loss += weights[i] * layer_loss
312
+
313
+ return total_exit_loss
314
+
315
+ def forward(
316
+ self,
317
+ input_ids=None,
318
+ attention_mask=None,
319
+ position_ids=None,
320
+ past_key_values=None,
321
+ inputs_embeds=None,
322
+ labels=None,
323
+ use_cache=None,
324
+ output_attentions=None,
325
+ output_hidden_states=None,
326
+ return_dict=None,
327
+ ):
328
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
329
+ return_all = self.training and labels is not None
330
+
331
+ hidden_states, past_key_values_output, all_layer_hidden_states = self.model(
332
+ input_ids=input_ids,
333
+ attention_mask=attention_mask,
334
+ position_ids=position_ids,
335
+ past_key_values=past_key_values,
336
+ use_cache=use_cache,
337
+ output_hidden_states=output_hidden_states,
338
+ return_all_layer_outputs=return_all,
339
+ )
340
+
341
+ logits = self.lm_head(hidden_states)
342
+ logits = logits.float()
343
+
344
+ loss = None
345
+ if labels is not None:
346
+ shift_logits = logits[..., :-1, :].contiguous()
347
+ shift_labels = labels[..., 1:].contiguous()
348
+ loss_fct = nn.CrossEntropyLoss()
349
+ main_loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))
350
+
351
+ if all_layer_hidden_states is not None and len(all_layer_hidden_states) > 0:
352
+ early_exit_loss = self.compute_early_exit_loss(all_layer_hidden_states[:-1], labels)
353
+ loss = main_loss + self.config.early_exit_loss_weight * early_exit_loss
354
+ else:
355
+ loss = main_loss
356
+
357
+ if not return_dict:
358
+ output = (logits,) + (past_key_values_output,)
359
+ return (loss,) + output if loss is not None else output
360
+
361
+ return CausalLMOutputWithPast(
362
+ loss=loss,
363
+ logits=logits,
364
+ past_key_values=past_key_values_output,
365
+ hidden_states=None,
366
+ attentions=None,
367
+ )
368
+
369
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
370
+ if past_key_values is not None:
371
+ past_length = past_key_values[0][0].shape[2]
372
+ if input_ids.shape[1] > past_length:
373
+ remove_prefix_length = past_length
374
+ else:
375
+ remove_prefix_length = input_ids.shape[1] - 1
376
+ input_ids = input_ids[:, remove_prefix_length:]
377
+
378
+ position_ids = kwargs.get("position_ids", None)
379
+ if attention_mask is not None and position_ids is None:
380
+ position_ids = attention_mask.long().cumsum(-1) - 1
381
+ position_ids.masked_fill_(attention_mask == 0, 1)
382
+ if past_key_values:
383
+ position_ids = position_ids[:, -input_ids.shape[1] :]
384
+
385
+ if inputs_embeds is not None and past_key_values is None:
386
+ model_inputs = {"inputs_embeds": inputs_embeds}
387
+ else:
388
+ model_inputs = {"input_ids": input_ids}
389
+
390
+ model_inputs.update({
391
+ "position_ids": position_ids,
392
+ "past_key_values": past_key_values,
393
+ "use_cache": kwargs.get("use_cache"),
394
+ "attention_mask": attention_mask,
395
+ })
396
+ return model_inputs
397
+
398
+ @staticmethod
399
+ def _reorder_cache(past_key_values, beam_idx):
400
+ reordered_past = ()
401
+ for layer_past in past_key_values:
402
+ reordered_past += (
403
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
404
+ )
405
+ return reordered_past
406
+
407
+ def set_exit_layer(self, exit_layer):
408
+ self.config.inference_exit_layer = exit_layer
409
+ self.model.config.inference_exit_layer = exit_layer
410
+
411
+
412
+ BitSkipV2EarlyExitConfig.register_for_auto_class()
413
+ BitSkipV2ForCausalLMWithEarlyExit.register_for_auto_class("AutoModelForCausalLM")
special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "bos_token": "<|endoftext|>",
3
+ "eos_token": "<|endoftext|>",
4
+ "pad_token": "<|endoftext|>",
5
+ "unk_token": "<|endoftext|>"
6
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,20 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "50256": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ }
12
+ },
13
+ "bos_token": "<|endoftext|>",
14
+ "clean_up_tokenization_spaces": false,
15
+ "eos_token": "<|endoftext|>",
16
+ "model_max_length": 1024,
17
+ "pad_token": "<|endoftext|>",
18
+ "tokenizer_class": "GPT2Tokenizer",
19
+ "unk_token": "<|endoftext|>"
20
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff