Hoffman37 committed
Commit 32055cf · verified · 1 Parent(s): d81212f

End of training

Files changed (4)
  1. README.md +56 -0
  2. bitllama_layers.py +163 -0
  3. generation_config.json +6 -0
  4. model.safetensors +1 -1
README.md ADDED
@@ -0,0 +1,56 @@
+ ---
+ tags:
+ - generated_from_trainer
+ model-index:
+ - name: bitllama-Llama2
+   results: []
+ ---
+
+ <!-- This model card has been generated automatically according to the information the Trainer had access to. You
+ should probably proofread and complete it, then remove this comment. -->
+
+ # bitllama-Llama2
+
+ This model is a fine-tuned version of [](https://huggingface.co/) on an unknown dataset.
+ It achieves the following results on the evaluation set:
+ - Loss: 3.7504
+
+ ## Model description
+
+ More information needed
+
+ ## Intended uses & limitations
+
+ More information needed
+
+ ## Training and evaluation data
+
+ More information needed
+
+ ## Training procedure
+
+ ### Training hyperparameters
+
+ The following hyperparameters were used during training:
+ - learning_rate: 0.0024
+ - train_batch_size: 96
+ - eval_batch_size: 96
+ - seed: 42
+ - optimizer: Adam with betas=(0.9,0.95) and epsilon=1e-08
+ - lr_scheduler_type: linear
+ - lr_scheduler_warmup_steps: 5000
+ - num_epochs: 1
+
+ ### Training results
+
+ | Training Loss | Epoch | Step | Validation Loss |
+ |:-------------:|:-----:|:----:|:---------------:|
+ | 4.6791        | 0.74  | 2000 | 3.7504          |
+
+
+ ### Framework versions
+
+ - Transformers 4.38.2
+ - PyTorch 2.8.0+cu126
+ - Datasets 4.0.0
+ - Tokenizers 0.15.2
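
For reference, the hyperparameters in the model card above can be expressed as `transformers.TrainingArguments`. This is a minimal sketch rather than the actual training script from this commit: `output_dir` is a placeholder, and the Adam betas/epsilon map onto the `adam_*` arguments.

from transformers import TrainingArguments

# Minimal sketch of the model card's hyperparameters; output_dir is hypothetical,
# and the real run's Trainer wiring (model, datasets, collator) is not shown here.
training_args = TrainingArguments(
    output_dir="bitllama-Llama2",
    learning_rate=2.4e-3,             # 0.0024
    per_device_train_batch_size=96,
    per_device_eval_batch_size=96,
    seed=42,
    adam_beta1=0.9,                   # Adam betas=(0.9, 0.95)
    adam_beta2=0.95,
    adam_epsilon=1e-8,
    lr_scheduler_type="linear",
    warmup_steps=5000,                # lr_scheduler_warmup_steps
    num_train_epochs=1,
)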
bitllama_layers.py ADDED
@@ -0,0 +1,163 @@
+ import warnings
+ from typing import Optional, Tuple
+ from transformers.models.llama.modeling_llama import (
+     LlamaConfig,
+     LlamaModel,
+     LlamaForCausalLM,
+     LlamaAttention,
+     LlamaFlashAttention2,
+     LlamaSdpaAttention,
+     LlamaMLP,
+     LlamaDecoderLayer,
+ )
+ from bitnet import BitLinear, BitLinear158b
+ import torch
+ from torch import nn
+
+ class BitLlamaConfig(LlamaConfig):
+     model_type = "bit_llama"
+
+     def __init__(self, bitnet_type="1.58b", bits=8, **kwargs):
+         super().__init__(**kwargs)
+         self.bitnet_type = bitnet_type
+         if self.bitnet_type not in ["1.58b", "1b"]:
+             raise ValueError("bitnet_type must be either '1.58b' or '1b'.")
+         self.bits = bits
+
+ class BitLlamaMLP(LlamaMLP):
+     def __init__(self, config: BitLlamaConfig):
+         super().__init__(config)
+         # Swap the dense projections built by LlamaMLP for quantized BitLinear layers.
+         if config.bitnet_type == "1b":
+             self.gate_proj = BitLinear(self.hidden_size, self.intermediate_size, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits, flg_before_linear=False)
+             self.up_proj = BitLinear(self.hidden_size, self.intermediate_size, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits, flg_before_linear=True)
+             self.down_proj = BitLinear(self.intermediate_size, self.hidden_size, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits, flg_before_linear=True)
+         elif config.bitnet_type == "1.58b":
+             self.gate_proj = BitLinear158b(self.hidden_size, self.intermediate_size, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits)
+             self.up_proj = BitLinear158b(self.hidden_size, self.intermediate_size, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits)
+             self.down_proj = BitLinear158b(self.intermediate_size, self.hidden_size, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits)
+         else:
+             raise ValueError("bitnet_type must be either '1.58b' or '1b'.")
+
+ class BitLlamaAttention(LlamaAttention):
+     def __init__(self, config: BitLlamaConfig, layer_idx: Optional[int] = None):
+         super().__init__(config, layer_idx)
+         if config.bitnet_type == "1b":
+             self.q_proj = BitLinear(self.hidden_size, self.num_heads * self.head_dim, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits, flg_before_linear=True)
+             self.k_proj = BitLinear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits, flg_before_linear=True)
+             self.v_proj = BitLinear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits, flg_before_linear=True)
+             self.o_proj = BitLinear(self.hidden_size, self.hidden_size, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits, flg_before_linear=True)
+         elif config.bitnet_type == "1.58b":
+             self.q_proj = BitLinear158b(self.hidden_size, self.num_heads * self.head_dim, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits)
+             self.k_proj = BitLinear158b(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits)
+             self.v_proj = BitLinear158b(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits)
+             self.o_proj = BitLinear158b(self.hidden_size, self.hidden_size, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits)
+         else:
+             raise ValueError("bitnet_type must be either '1.58b' or '1b'.")
+
+ class BitLlamaFlashAttention2(LlamaFlashAttention2):
+     def __init__(self, config: BitLlamaConfig, layer_idx: Optional[int] = None):
+         super().__init__(config, layer_idx)
+         # Same projection swap as BitLlamaAttention, for the flash-attention backend.
+         if config.bitnet_type == "1b":
+             self.q_proj = BitLinear(self.hidden_size, self.num_heads * self.head_dim, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits, flg_before_linear=True)
+             self.k_proj = BitLinear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits, flg_before_linear=True)
+             self.v_proj = BitLinear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits, flg_before_linear=True)
+             self.o_proj = BitLinear(self.hidden_size, self.hidden_size, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits, flg_before_linear=True)
+         elif config.bitnet_type == "1.58b":
+             self.q_proj = BitLinear158b(self.hidden_size, self.num_heads * self.head_dim, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits)
+             self.k_proj = BitLinear158b(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits)
+             self.v_proj = BitLinear158b(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits)
+             self.o_proj = BitLinear158b(self.hidden_size, self.hidden_size, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits)
+         else:
+             raise ValueError("bitnet_type must be either '1.58b' or '1b'.")
+
+ class BitLlamaSdpaAttention(LlamaSdpaAttention):
+     def __init__(self, config: BitLlamaConfig, layer_idx: Optional[int] = None):
+         super().__init__(config, layer_idx)
+         # Same projection swap as BitLlamaAttention, for the SDPA backend.
+         if config.bitnet_type == "1b":
+             self.q_proj = BitLinear(self.hidden_size, self.num_heads * self.head_dim, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits, flg_before_linear=True)
+             self.k_proj = BitLinear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits, flg_before_linear=True)
+             self.v_proj = BitLinear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits, flg_before_linear=True)
+             self.o_proj = BitLinear(self.hidden_size, self.hidden_size, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits, flg_before_linear=True)
+         elif config.bitnet_type == "1.58b":
+             self.q_proj = BitLinear158b(self.hidden_size, self.num_heads * self.head_dim, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits)
+             self.k_proj = BitLinear158b(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits)
+             self.v_proj = BitLinear158b(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits)
+             self.o_proj = BitLinear158b(self.hidden_size, self.hidden_size, bias=False, rms_norm_eps=config.rms_norm_eps, bits=config.bits)
+         else:
+             raise ValueError("bitnet_type must be either '1.58b' or '1b'.")
+
+ BITLLAMA_ATTENTION_CLASSES = {
+     "eager": BitLlamaAttention,
+     "flash_attention_2": BitLlamaFlashAttention2,
+     "sdpa": BitLlamaSdpaAttention,
+ }
+
+ class BitLlamaDecoderLayer(LlamaDecoderLayer):
+     def __init__(self, config: BitLlamaConfig, layer_idx: int):
+         super().__init__(config, layer_idx)
+         self.self_attn = BITLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+         self.mlp = BitLlamaMLP(config)
+         # The BitLinear layers take rms_norm_eps and normalize their own input,
+         # so the standalone RMSNorms from LlamaDecoderLayer are removed.
+         del self.input_layernorm
+         del self.post_attention_layernorm
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         output_attentions: Optional[bool] = False,
+         use_cache: Optional[bool] = False,
+         cache_position: Optional[torch.LongTensor] = None,
+         **kwargs,
+     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+         if "padding_mask" in kwargs:
+             warnings.warn(
+                 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
+             )
+
+         # Self-attention block with residual connection (no pre-norm; see __init__).
+         residual = hidden_states
+         hidden_states, self_attn_weights, present_key_value = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_value=past_key_value,
+             output_attentions=output_attentions,
+             use_cache=use_cache,
+             cache_position=cache_position,
+             **kwargs,
+         )
+         hidden_states = residual + hidden_states
+
+         # Feed-forward block with residual connection.
+         residual = hidden_states
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+
+         outputs = (hidden_states,)
+
+         if output_attentions:
+             outputs += (self_attn_weights,)
+
+         if use_cache:
+             outputs += (present_key_value,)
+
+         return outputs
+
+ class BitLlamaModel(LlamaModel):
+     config_class = BitLlamaConfig
+
+     def __init__(self, config: BitLlamaConfig):
+         super().__init__(config)
+         # Rebuild the decoder stack with BitLlama layers.
+         self.layers = nn.ModuleList(
+             [BitLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+         )
+
+ class BitLlamaForCausalLM(LlamaForCausalLM):
+     config_class = BitLlamaConfig
+
+     def __init__(self, config: BitLlamaConfig):
+         super().__init__(config)
+         self.model = BitLlamaModel(config)
+         self.lm_head = BitLinear(config.hidden_size, config.vocab_size, bias=False, bits=config.bits, flg_before_linear=True)
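
A minimal usage sketch for the classes above, assuming the `bitnet` dependency is importable and this file is on the Python path. The tiny config values are illustrative only, not this checkpoint's actual dimensions.

import torch
from bitllama_layers import BitLlamaConfig, BitLlamaForCausalLM

# Illustrative toy dimensions; the real checkpoint weighs in around 511 MB.
config = BitLlamaConfig(
    bitnet_type="1.58b",   # or "1b" for the original BitNet linear layer
    bits=8,
    hidden_size=256,
    intermediate_size=688,
    num_hidden_layers=2,
    num_attention_heads=4,
    vocab_size=32000,
)
model = BitLlamaForCausalLM(config)

# Forward pass on random token ids to check shapes end to end.
input_ids = torch.randint(0, config.vocab_size, (1, 16))
logits = model(input_ids).logits   # (1, 16, vocab_size)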
generation_config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "transformers_version": "4.38.2"
+ }
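
This file only pins the BOS/EOS token ids and round-trips through `transformers.GenerationConfig`. A small sketch, with the repo id below assumed from the committer and model-index name rather than stated in the commit:

from transformers import GenerationConfig

# Hypothetical repo id pieced together from the committer and model name.
gen_config = GenerationConfig.from_pretrained("Hoffman37/bitllama-Llama2")
print(gen_config.bos_token_id, gen_config.eos_token_id)  # -> 1 2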
model.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:cc186dae20a199ce00a1d07bc4da59dba3743cfb308d034721a3a2b606153e68
+ oid sha256:c269b518dbf4a5c57251f3ef15b172074e8a98174de85da677c454cf3b4fd152
  size 510960712