File size: 18,328 Bytes
edc9020
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
"""
BitSkip v3: v1 architecture WITH Hadamard transform
- 8-bit activations (like v1)
- Hadamard transform (like v2)
- Tests if Hadamard improves 8-bit quantization
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast


def hadamard_transform(x):
    """Apply the orthonormal fast Walsh-Hadamard transform along the last axis.

    The transform runs ``log2(n)`` butterfly stages and scales the result by
    ``1 / sqrt(n)``, so applying it twice returns the original values (up to
    floating-point rounding).

    Args:
        x: Tensor whose last dimension ``n`` is a power of two.

    Returns:
        A new tensor with the same shape as ``x``. The input is never
        modified, so the function is safe to use under autograd.

    Raises:
        AssertionError: If the last dimension is not a power of two.
    """
    orig_shape = x.shape
    n = x.shape[-1]

    assert n & (n - 1) == 0, f"Dimension must be power of 2, got {n}"

    x = x.reshape(-1, n)

    h = 1
    while h < n:
        # Split every block of 2*h consecutive elements into its two halves.
        x = x.reshape(-1, n // (2 * h), 2, h)
        x_even = x[:, :, 0, :]
        x_odd = x[:, :, 1, :]

        # Build the butterfly in a fresh tensor. Writing the sum back into
        # ``x`` first would also change ``x_even`` (it is a view), turning
        # the difference into ``(e + o) - o`` and corrupting the transform,
        # as well as mutating the caller's tensor in place.
        x = torch.stack((x_even + x_odd, x_even - x_odd), dim=2)

        x = x.reshape(-1, n)
        h *= 2

    x = x / math.sqrt(n)
    return x.reshape(orig_shape)


class BitLinearV3(nn.Module):
    """Linear layer with ternary weights and 8-bit activations around a Hadamard transform.

    The forward pass is: LayerNorm -> ``hadamard_transform`` over the input
    features -> 8-bit activation fake-quantization -> matmul with ternary
    fake-quantized weights -> ``hadamard_transform`` over the output features.
    Both feature sizes must be powers of two because the transform is applied
    on each side of the matmul.
    """

    def __init__(self, in_features, out_features, bias=False):
        super().__init__()

        assert in_features & (in_features - 1) == 0, f"in_features must be power of 2, got {in_features}"
        assert out_features & (out_features - 1) == 0, f"out_features must be power of 2, got {out_features}"

        self.in_features = in_features
        self.out_features = out_features

        # Latent full-precision weights; they are ternarized on every forward.
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.bias = None
        self.norm = nn.LayerNorm(in_features)

    def _fake_quant_activations(self, x):
        """Round ``x`` onto a 127-step grid scaled by the absmax of its last axis."""
        absmax = x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
        dequantized = (x / absmax * 127).round().clamp(-128, 127) / 127 * absmax
        if self.training:
            # Straight-through estimator: quantized values forward, identity gradient.
            dequantized = x + (dequantized - x).detach()
        return dequantized

    def _fake_quant_weights(self):
        """Map the weights to {-1, 0, +1} times their mean absolute value."""
        latent = self.weight
        scale = latent.abs().mean().clamp(min=1e-5)
        threshold = 0.5 * scale
        # +1 above the threshold, -1 below its negative, 0 in between.
        ternary = (latent > threshold).to(latent.dtype) - (latent < -threshold).to(latent.dtype)
        quantized = ternary * scale
        if self.training:
            # Straight-through estimator so gradients reach the latent weights.
            quantized = latent + (quantized - latent).detach()
        return quantized

    def forward(self, x):
        """Return the quantized projection of ``x`` (last axis ``in_features`` -> ``out_features``)."""
        rotated = hadamard_transform(self.norm(x))
        activations = self._fake_quant_activations(rotated)
        weights = self._fake_quant_weights()
        projected = F.linear(activations, weights, self.bias)
        # Second transform, this time over the output features.
        return hadamard_transform(projected)


class BitSkipV3Config(PretrainedConfig):
    """Hyper-parameters for the BitSkip v3 model family.

    How the values are consumed elsewhere in this file:

    * ``hidden_size`` / ``num_attention_heads`` / ``num_key_value_heads``:
      the attention head size is ``hidden_size // num_attention_heads`` and
      key/value heads are repeated to match the query heads.
    * ``intermediate_size``: width of the gated MLP.
    * ``hidden_size``, ``intermediate_size`` and
      ``num_key_value_heads * head_dim`` all feed ``BitLinearV3``, which
      asserts that its feature sizes are powers of two.
    * ``max_position_embeddings`` and ``rope_theta``: passed to
      ``RotaryEmbedding`` (``rope_theta`` as the frequency base).
    * ``rms_norm_eps``: epsilon of every ``RMSNorm``.
    * ``early_exit_loss_weight``: multiplier on the auxiliary early-exit loss.
    * ``max_dropout_prob``: passed to ``QuadraticLayerDropout``.
    * ``inference_exit_layer``: when truthy, the number of decoder layers the
      model runs (capped at ``num_hidden_layers``); ``None`` runs all layers.
    """

    model_type = "bitskip_v3"
    
    def __init__(
        self,
        vocab_size=50257,
        hidden_size=2048,
        num_hidden_layers=24,
        num_attention_heads=32,
        num_key_value_heads=8,
        intermediate_size=4096,
        max_position_embeddings=2048,
        rms_norm_eps=1e-5,
        rope_theta=10000.0,
        early_exit_loss_weight=0.3,
        max_dropout_prob=0.5,
        inference_exit_layer=None,
        **kwargs
    ):
        # Architecture sizes.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.intermediate_size = intermediate_size
        # Positional encoding and normalization.
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        # Layer dropout / early-exit settings.
        self.early_exit_loss_weight = early_exit_loss_weight
        self.max_dropout_prob = max_dropout_prob
        self.inference_exit_layer = inference_exit_layer
        # Remaining keyword arguments are handled by PretrainedConfig.
        super().__init__(**kwargs)


class QuadraticLayerDropout(nn.Module):
    """Per-layer skip probabilities that grow quadratically with depth.

    Layer ``i`` gets a raw weight ``max_dropout_prob * (i / (num_layers - 1)) ** 2``;
    the weights are then divided by their sum so that they add up to 1.

    NOTE(review): because of that normalization ``max_dropout_prob`` cancels
    out whenever the raw sum is positive, so it does not bound the final
    probabilities. Confirm this is intended.
    """

    def __init__(self, num_layers, max_dropout_prob=0.5):
        super().__init__()
        self.num_layers = num_layers

        # Guard the single-layer case against a zero denominator.
        deepest = max(num_layers - 1, 1)
        raw_probs = [max_dropout_prob * ((depth / deepest) ** 2) for depth in range(num_layers)]

        normalizer = sum(raw_probs)
        self.dropout_probs = [p / normalizer for p in raw_probs] if normalizer > 0 else raw_probs

    def should_drop_layer(self, layer_idx):
        """Return True when layer ``layer_idx`` should be skipped for this step.

        Nothing is dropped in eval mode, and the final layer is never dropped.
        A random number is drawn only when a drop is actually possible.
        """
        if self.training and layer_idx < self.num_layers - 1:
            return torch.rand(1).item() < self.dropout_probs[layer_idx]
        return False


class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis with a learned gain."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalize in float32, cast back to the input dtype, then apply the gain."""
        source_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = (upcast * inverse_rms).to(source_dtype)
        return self.weight * normalized


class RotaryEmbedding(nn.Module):
    """Produce the cosine/sine tables used for rotary position embeddings.

    ``max_position_embeddings`` is accepted for interface compatibility but is
    not used: the tables are computed from ``position_ids`` on every call.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        # One inverse frequency per pair of channels: base ** (-2i / dim).
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponents))

    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` shaped (batch, seq, dim) in the dtype of ``x``.

        ``x`` is used only for its dtype; ``position_ids`` is indexed as
        (batch, seq).
        """
        batch = position_ids.shape[0]
        freq_column = self.inv_freq[None, :, None].float().expand(batch, -1, 1)
        position_row = position_ids[:, None, :].float()
        # (batch, dim/2, 1) @ (batch, 1, seq) -> (batch, dim/2, seq) -> (batch, seq, dim/2)
        angles = torch.matmul(freq_column, position_row).transpose(1, 2)
        table = torch.cat((angles, angles), dim=-1)
        target_dtype = x.dtype
        return table.cos().to(target_dtype), table.sin().to(target_dtype)


def rotate_half(x):
    """Map the two halves ``(x1, x2)`` of the last axis to ``(-x2, x1)``."""
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate query and key tensors by the given cosine/sine tables.

    Args:
        q: Query tensor, e.g. (batch, heads, seq, head_dim).
        k: Key tensor, e.g. (batch, kv_heads, seq, head_dim).
        cos: Cosine table, either already broadcastable against ``q``/``k``
            or one rank lower, e.g. (batch, seq, head_dim).
        sin: Sine table with the same shape as ``cos``.

    Returns:
        Tuple ``(q_embed, k_embed)`` with the shapes of ``q`` and ``k``.
    """
    # A (batch, seq, dim) table has no head axis. Without inserting one,
    # broadcasting lines the table's batch axis up with the head axis of q/k,
    # which raises (or silently mixes positions across the batch) as soon as
    # the table's batch size is greater than one.
    if cos.dim() == q.dim() - 1:
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class BitSkipV3Attention(nn.Module):
    """Causal grouped-query self-attention built from ``BitLinearV3`` projections.

    Queries use ``num_attention_heads`` heads, keys and values use
    ``num_key_value_heads`` heads that are repeated to match the queries.
    Rotary position embeddings are applied to queries and keys.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        
        self.q_proj = BitLinearV3(self.hidden_size, self.num_heads * self.head_dim)
        self.k_proj = BitLinearV3(self.hidden_size, self.num_key_value_heads * self.head_dim)
        self.v_proj = BitLinearV3(self.hidden_size, self.num_key_value_heads * self.head_dim)
        self.o_proj = BitLinearV3(self.hidden_size, self.hidden_size)
        
        self.rotary_emb = RotaryEmbedding(self.head_dim, config.max_position_embeddings, config.rope_theta)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
        """Run causal self-attention over ``hidden_states``.

        Args:
            hidden_states: Tensor of shape (batch, q_len, hidden_size).
            attention_mask: Optional mask. A 2-D tensor is treated as a
                keep-mask over key positions, (batch, kv_len) with 0 marking
                positions to ignore; any other rank is treated as an additive
                mask broadcastable to (batch, heads, q_len, kv_len).
            position_ids: Optional (batch, q_len) or (1, q_len) positions for
                the rotary embedding. Defaults to the positions that follow
                the cached keys.
            past_key_value: Optional ``(key, value)`` tuple from previous
                steps, each (batch, kv_heads, past_len, head_dim).
            use_cache: When True, return the updated ``(key, value)`` tuple.

        Returns:
            Tuple ``(attn_output, None, present_key_value)``; the middle
            element is a placeholder for attention weights, which are not
            returned.
        """
        bsz, q_len, _ = hidden_states.size()
        
        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        past_len = 0 if past_key_value is None else past_key_value[0].shape[2]
        if position_ids is None:
            # New tokens continue right after whatever is already cached.
            position_ids = torch.arange(
                past_len, past_len + q_len, dtype=torch.long, device=hidden_states.device
            ).unsqueeze(0)

        cos, sin = self.rotary_emb(value_states, position_ids)
        # Insert the head axis so (batch, seq, dim) tables broadcast against
        # (batch, heads, seq, dim) for any batch size.
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        # Cache the un-repeated key/value heads to keep the cache small.
        past_key_value = (key_states, value_states) if use_cache else None

        key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
        value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        kv_len = key_states.shape[2]
        # Large finite negative instead of -inf so fully masked rows stay NaN-free.
        mask_value = torch.finfo(attn_weights.dtype).min

        if attention_mask is not None:
            if attention_mask.dim() == 2:
                # Padding keep-mask: adding it to the scores directly would
                # neither broadcast correctly nor mask anything.
                ignored_keys = (attention_mask == 0)[:, None, None, :]
                attn_weights = attn_weights.masked_fill(ignored_keys, mask_value)
            else:
                attn_weights = attn_weights + attention_mask

        if q_len > 1:
            # Causal constraint: the query at absolute position p may only see
            # keys at positions <= p. With a single query (incremental
            # decoding) every key is already in the past, so nothing to mask.
            query_positions = torch.arange(kv_len - q_len, kv_len, device=attn_weights.device)
            key_positions = torch.arange(kv_len, device=attn_weights.device)
            future_keys = key_positions[None, :] > query_positions[:, None]
            attn_weights = attn_weights.masked_fill(future_keys, mask_value)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value


class BitSkipV3MLP(nn.Module):
    """Gated feed-forward block: ``down(silu(gate(x)) * up(x))`` with BitLinearV3 layers."""

    def __init__(self, config):
        super().__init__()
        hidden, inner = config.hidden_size, config.intermediate_size
        self.gate_proj = BitLinearV3(hidden, inner)
        self.up_proj = BitLinearV3(hidden, inner)
        self.down_proj = BitLinearV3(inner, hidden)

    def forward(self, x):
        """Project up through the SiLU gate and back down to the hidden size."""
        gate = F.silu(self.gate_proj(x))
        candidate = self.up_proj(x)
        return self.down_proj(gate * candidate)


class BitSkipV3DecoderLayer(nn.Module):
    """Pre-norm transformer block: self-attention and MLP, each with a residual connection."""

    def __init__(self, config):
        super().__init__()
        self.self_attn = BitSkipV3Attention(config)
        self.mlp = BitSkipV3MLP(config)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
        """Return ``(hidden_states,)``, plus the layer's key/value cache when ``use_cache`` is set."""
        # Attention sub-block.
        attn_output, _, present_key_value = self.self_attn(
            self.input_layernorm(hidden_states),
            attention_mask,
            position_ids,
            past_key_value,
            use_cache,
        )
        hidden_states = hidden_states + attn_output

        # Feed-forward sub-block.
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_output

        if use_cache:
            return (hidden_states, present_key_value)
        return (hidden_states,)


class BitSkipV3PreTrainedModel(PreTrainedModel):
    """Common base class: binds the config class and defines weight initialization."""

    config_class = BitSkipV3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize one sub-module in place.

        Linear-style layers (``nn.Linear`` and ``BitLinearV3``) and embeddings
        get weights drawn from N(0, 0.02); linear biases are zeroed. Every
        other module type is left untouched.
        """
        init_std = 0.02

        if not isinstance(module, (nn.Linear, BitLinearV3)):
            if isinstance(module, nn.Embedding):
                module.weight.data.normal_(mean=0.0, std=init_std)
            return

        if hasattr(module, 'weight'):
            module.weight.data.normal_(mean=0.0, std=init_std)
        bias = getattr(module, 'bias', None)
        if bias is not None:
            bias.data.zero_()


class BitSkipV3Model(BitSkipV3PreTrainedModel):
    """Decoder stack: token embedding, decoder layers with layer dropout, final RMSNorm."""

    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([BitSkipV3DecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False
        self.layer_dropout = QuadraticLayerDropout(config.num_hidden_layers, config.max_dropout_prob)
        self.post_init()

    def forward(self, input_ids, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, output_hidden_states=False, return_all_layer_outputs=False):
        """Run the decoder stack.

        Args:
            input_ids: Token ids, (batch, seq_len).
            attention_mask: Optional mask forwarded unchanged to every layer.
            position_ids: Optional positions for the rotary embedding. When
                omitted they continue after the cached sequence length.
            past_key_values: Optional per-layer ``(key, value)`` caches.
            use_cache: When truthy, collect and return the updated caches.
            output_hidden_states: Accepted for interface compatibility; unused.
            return_all_layer_outputs: When True, also return the hidden state
                after every layer slot followed by the final normalized state.

        Returns:
            Tuple ``(hidden_states, next_decoder_cache, all_layer_hidden_states)``;
            the last element is ``None`` unless ``return_all_layer_outputs``.
        """
        hidden_states = self.embed_tokens(input_ids)
        
        if position_ids is None:
            # Offset by the cached length so incremental decoding gets the
            # true absolute positions instead of restarting at zero.
            past_length = past_key_values[0][0].shape[2] if past_key_values else 0
            position_ids = torch.arange(
                past_length, past_length + input_ids.shape[1], dtype=torch.long, device=input_ids.device
            )
            position_ids = position_ids.unsqueeze(0)
        
        next_decoder_cache = () if use_cache else None
        # Only keep per-layer states alive when the caller asked for them.
        all_layer_hidden_states = [] if return_all_layer_outputs else None
        
        # A truthy inference_exit_layer truncates the stack (early exit).
        num_layers_to_run = self.config.inference_exit_layer if self.config.inference_exit_layer else len(self.layers)
        num_layers_to_run = min(num_layers_to_run, len(self.layers))
        
        for idx in range(num_layers_to_run):
            layer = self.layers[idx]
            past_key_value = past_key_values[idx] if past_key_values else None
            
            if self.training and self.layer_dropout.should_drop_layer(idx):
                # Skipped layer: the hidden state passes through unchanged.
                # NOTE: a skipped layer contributes no cache entry.
                if all_layer_hidden_states is not None:
                    all_layer_hidden_states.append(hidden_states)
                continue
            
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(layer.__call__, hidden_states, attention_mask, position_ids, past_key_value, use_cache)
            else:
                layer_outputs = layer(hidden_states, attention_mask, position_ids, past_key_value, use_cache)
            
            hidden_states = layer_outputs[0]
            if all_layer_hidden_states is not None:
                all_layer_hidden_states.append(hidden_states)
            
            if use_cache:
                next_decoder_cache += (layer_outputs[1],)
        
        hidden_states = self.norm(hidden_states)
        if all_layer_hidden_states is not None:
            all_layer_hidden_states.append(hidden_states)
        
        return hidden_states, next_decoder_cache, all_layer_hidden_states


class BitSkipV3ForCausalLM(BitSkipV3PreTrainedModel, GenerationMixin):
    """BitSkip v3 decoder with a language-modeling head and an auxiliary early-exit loss."""

    _tied_weights_keys = ["lm_head.weight"]
    
    def __init__(self, config):
        super().__init__(config)
        self.model = BitSkipV3Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def compute_early_exit_loss(self, all_layer_hidden_states, labels):
        """Return the depth-weighted next-token loss over intermediate hidden states.

        Each hidden state is passed through the shared ``lm_head``. Entry ``i``
        is weighted proportionally to ``i + 1`` and the weights are normalized
        to sum to 1, so deeper layers count more.

        Args:
            all_layer_hidden_states: Sequence of (batch, seq, hidden) tensors.
            labels: Target token ids, (batch, seq).

        Returns:
            The weighted loss, or ``0.0`` when the sequence is empty.
        """
        num_layers = len(all_layer_hidden_states)
        weights = [(i + 1) / num_layers for i in range(num_layers)]
        weight_sum = sum(weights)
        weights = [w / weight_sum for w in weights]
        
        # The targets and the loss module are the same for every layer.
        shift_labels = labels[..., 1:].contiguous().view(-1)
        loss_fct = nn.CrossEntropyLoss()
        
        total_exit_loss = 0.0
        
        for weight, hidden_states in zip(weights, all_layer_hidden_states):
            # Float32 logits, matching how the main loss is computed in forward().
            logits = self.lm_head(hidden_states).float()
            shift_logits = logits[..., :-1, :].contiguous()
            layer_loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels)
            total_exit_loss += weight * layer_loss
        
        return total_exit_loss

    def forward(self, input_ids=None, attention_mask=None, position_ids=None, past_key_values=None, inputs_embeds=None, labels=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None):
        """Compute logits and, when ``labels`` is given, the training loss.

        In training mode with labels, the loss is the final-layer next-token
        loss plus ``config.early_exit_loss_weight`` times the early-exit loss
        over the per-layer hidden states.

        Returns:
            ``CausalLMOutputWithPast``, or a tuple
            ``(loss?, logits, past_key_values)`` when ``return_dict`` is false.

        Raises:
            ValueError: If ``input_ids`` is missing; ``inputs_embeds`` is
                accepted in the signature but not supported by the decoder.
        """
        if input_ids is None:
            # The decoder always embeds token ids itself. Fail with a clear
            # message instead of an opaque error from the embedding layer.
            raise ValueError("BitSkipV3ForCausalLM requires input_ids; inputs_embeds is not supported")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Per-layer states are only needed for the early-exit loss.
        return_all = self.training and labels is not None
        
        hidden_states, past_key_values_output, all_layer_hidden_states = self.model(
            input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids,
            past_key_values=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states,
            return_all_layer_outputs=return_all,
        )

        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            main_loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))
            
            if all_layer_hidden_states is not None and len(all_layer_hidden_states) > 0:
                # The last entry is the final normalized state, already covered by main_loss.
                early_exit_loss = self.compute_early_exit_loss(all_layer_hidden_states[:-1], labels)
                loss = main_loss + self.config.early_exit_loss_weight * early_exit_loss
            else:
                loss = main_loss

        if not return_dict:
            output = (logits,) + (past_key_values_output,)
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=past_key_values_output, hidden_states=None, attentions=None)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
        """Build the keyword arguments for one generation step.

        With a cache present, only the tokens that are not yet cached are
        kept. Position ids are derived from the attention mask when they are
        not supplied.
        """
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Keep only the last token when everything else is cached.
                remove_prefix_length = input_ids.shape[1] - 1
            input_ids = input_ids[:, remove_prefix_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # Count only attended tokens; masked slots get a dummy position of 1.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update({"position_ids": position_ids, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask})
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder every cached tensor along the batch axis according to ``beam_idx``."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),)
        return reordered_past
    
    def set_exit_layer(self, exit_layer):
        """Set how many decoder layers are run (``None`` runs the full stack)."""
        self.config.inference_exit_layer = exit_layer
        self.model.config.inference_exit_layer = exit_layer


# Register both classes with the transformers auto-class mechanism; the model
# is registered under "AutoModelForCausalLM".
BitSkipV3Config.register_for_auto_class()
BitSkipV3ForCausalLM.register_for_auto_class("AutoModelForCausalLM")