yagizdevre commited on
Commit
de97b5e
·
1 Parent(s): 6ff2080
__pycache__/__init__.cpython-312.pyc ADDED
Binary file (314 Bytes). View file
 
__pycache__/attn.cpython-312.pyc ADDED
Binary file (12.1 kB). View file
 
__pycache__/attn_masks.cpython-312.pyc ADDED
Binary file (8.24 kB). View file
 
__pycache__/attn_mods.cpython-312.pyc ADDED
Binary file (5.77 kB). View file
 
__pycache__/configuration_minitransformer.cpython-312.pyc ADDED
Binary file (2.15 kB). View file
 
__pycache__/convolve.cpython-312.pyc ADDED
Binary file (5.46 kB). View file
 
__pycache__/layers.cpython-312.pyc ADDED
Binary file (2.41 kB). View file
 
__pycache__/mlp.cpython-312.pyc ADDED
Binary file (1.95 kB). View file
 
__pycache__/modeling_minitransformer.cpython-312.pyc ADDED
Binary file (10.9 kB). View file
 
__pycache__/modules.cpython-312.pyc ADDED
Binary file (312 Bytes). View file
 
__pycache__/rotary_emb.cpython-312.pyc ADDED
Binary file (5.78 kB). View file
 
__pycache__/stu.cpython-312.pyc ADDED
Binary file (4.51 kB). View file
 
__pycache__/utils.cpython-312.pyc ADDED
Binary file (5.91 kB). View file
 
config.json CHANGED
@@ -16,14 +16,14 @@
16
  "global_bsz": 524288,
17
  "bsz": 2,
18
  "warmup_steps": 1907,
19
- "eval_peruse_alibiiod": 50,
20
  "save_period": 500,
21
  "max_lr": 3.0e-4,
22
  "min_lr": 3.0e-5,
23
  "max_norm": 1.0,
24
  "dilation": 1,
25
- "fsdp": true,
26
- "ddp": false,
27
  "mixed_precision": true,
28
  "torch_dtype": "bfloat16",
29
  "cpu_offload": false,
@@ -48,4 +48,4 @@
48
  "theta": 10000.0,
49
  "use_alibi": false,
50
  "torch_compile": false
51
- }
 
16
  "global_bsz": 524288,
17
  "bsz": 2,
18
  "warmup_steps": 1907,
19
+ "eval_period": 50,
20
  "save_period": 500,
21
  "max_lr": 3.0e-4,
22
  "min_lr": 3.0e-5,
23
  "max_norm": 1.0,
24
  "dilation": 1,
25
+ "fsdp": false,
26
+ "ddp": true,
27
  "mixed_precision": true,
28
  "torch_dtype": "bfloat16",
29
  "cpu_offload": false,
 
48
  "theta": 10000.0,
49
  "use_alibi": false,
50
  "torch_compile": false
51
+ }
configuration_minitransformer.py CHANGED
@@ -1,6 +1,5 @@
1
  import torch
2
  from transformers import PretrainedConfig, AutoConfig
3
-
4
  class MiniTransformerConfig(PretrainedConfig):
5
  model_type = "minitransformer"
6
 
@@ -11,6 +10,7 @@ class MiniTransformerConfig(PretrainedConfig):
11
  num_heads: int = 8,
12
  num_layers: int = 12,
13
  seq_len: int = 8192,
 
14
  window_size: int = 8192,
15
  vocab_size: int = 200064,
16
  mlp_scale: int = 12,
@@ -18,7 +18,7 @@ class MiniTransformerConfig(PretrainedConfig):
18
  dropout: float = 0.0,
19
  softcap: float = 50.0,
20
  theta: float = 10_000.0,
21
- use_alibi: bool = False,
22
  torch_dtype: torch.dtype = torch.bfloat16,
23
  device: torch.device = None,
24
  **kwargs,
@@ -29,6 +29,7 @@ class MiniTransformerConfig(PretrainedConfig):
29
  self.num_heads = num_heads
30
  self.num_layers = num_layers
31
  self.seq_len = seq_len
 
32
  self.window_size = window_size
33
  self.vocab_size = vocab_size
34
  self.hidden_size = dim
@@ -40,5 +41,4 @@ class MiniTransformerConfig(PretrainedConfig):
40
  self.theta = theta
41
  self.use_alibi = use_alibi
42
  self.torch_dtype = torch_dtype
43
- self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu') # Store as string
44
-
 
1
  import torch
2
  from transformers import PretrainedConfig, AutoConfig
 
3
  class MiniTransformerConfig(PretrainedConfig):
4
  model_type = "minitransformer"
5
 
 
10
  num_heads: int = 8,
11
  num_layers: int = 12,
12
  seq_len: int = 8192,
13
+ weight_tying: bool = True,
14
  window_size: int = 8192,
15
  vocab_size: int = 200064,
16
  mlp_scale: int = 12,
 
18
  dropout: float = 0.0,
19
  softcap: float = 50.0,
20
  theta: float = 10_000.0,
21
+ use_alibi: bool = False, # Default to RoPE
22
  torch_dtype: torch.dtype = torch.bfloat16,
23
  device: torch.device = None,
24
  **kwargs,
 
29
  self.num_heads = num_heads
30
  self.num_layers = num_layers
31
  self.seq_len = seq_len
32
+ self.weight_tying = weight_tying
33
  self.window_size = window_size
34
  self.vocab_size = vocab_size
35
  self.hidden_size = dim
 
41
  self.theta = theta
42
  self.use_alibi = use_alibi
43
  self.torch_dtype = torch_dtype
44
+ self.device = device
 
modeling_minitransformer.py CHANGED
@@ -1,46 +1,199 @@
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
 
5
- from transformers import PreTrainedModel
6
- from transformers.modeling_outputs import CausalLMOutput
7
-
8
- from .modules import Attention
9
- from .utils import nearest_power_of_two
10
- from .layers import AttentionLayer
11
  from .configuration_minitransformer import MiniTransformerConfig
12
-
13
- from .attn_masks import causal_mask
14
- from .attn_mods import generate_tanh_softcap
15
- from .rotary_emb import precompute_freqs_cis
16
-
17
  try:
18
- from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm
19
- triton_norm = True
20
  except ImportError as e:
21
  print(
22
- f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  )
24
- from torch.nn import RMSNorm
25
- triton_norm = False
26
- # Load the tokenizer
27
 
28
- from transformers import AutoModelForCausalLM, AutoTokenizer
29
- model_name = "Hazan-Lab/Transformer_500M"
30
- tokenizer = AutoTokenizer.from_pretrained(
31
- model_name,
32
- trust_remote_code=True
33
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  class MiniTransformer(PreTrainedModel):
 
36
  config_class = MiniTransformerConfig
37
 
38
  def __init__(self, config) -> None:
39
- super(MiniTransformer, self).__init__(config)
40
  self.num_layers = config.num_layers
41
  assert config.dim % config.num_heads == 0, f"dim ({self.dim}) must be divisible num_heads ({self.num_heads})"
42
  self.head_dim = config.dim // config.num_heads
43
- logit_softcap = generate_tanh_softcap(soft_cap=config.softcap)
44
 
45
  # From pytorch/pytorch#123411, we set persistent=True for torch.compile and PP compatibility
46
  self.register_buffer("freqs_cis", precompute_freqs_cis(
@@ -54,55 +207,36 @@ class MiniTransformer(PreTrainedModel):
54
 
55
  self.layers = nn.ModuleList()
56
  for _ in range(self.num_layers):
57
- layer = AttentionLayer(config, mask_mod=causal_mask, score_mod=logit_softcap)
58
  self.layers.append(layer)
59
 
60
  self.norm = nn.RMSNorm(config.dim)
61
  self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=config.bias)
62
- # self.tok_emb.weight = self.lm_head.weight
63
 
64
- self.std = (config.dim) ** -0.5
 
 
 
65
  self.apply(self._init_weights)
66
  print("Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,))
67
 
68
- def forward(
69
- self,
70
- input_ids: torch.Tensor,
71
- labels: torch.Tensor = None,
72
- **kwargs
73
- ) -> CausalLMOutput:
74
- # Compute embeddings
75
- tok_emb = self.tok_emb(input_ids)
76
 
77
  for layer in self.layers:
78
- tok_emb = layer(tok_emb, self.freqs_cis)
79
-
80
- # Normalize and project to vocabulary
81
- tok_emb = self.norm(tok_emb)
82
- logits = self.lm_head(tok_emb)
83
-
84
- loss = None
85
- if labels is not None:
86
- # Shift so that tokens predict the next token
87
- shift_logits = logits[..., :-1, :].contiguous()
88
- shift_labels = labels[..., 1:].contiguous()
89
- loss_fct = nn.CrossEntropyLoss()
90
- loss = loss_fct(
91
- shift_logits.view(-1, shift_logits.size(-1)),
92
- shift_labels.view(-1)
93
- )
94
-
95
- return CausalLMOutput(
96
- loss=loss,
97
- logits=logits,
98
- )
99
 
100
  def _get_num_params(self):
101
  n_params = sum(p.numel() for p in self.parameters())
 
102
  if hasattr(self, "pos_emb") and self.pos_emb is not None:
103
  n_params -= self.pos_emb.weight.numel()
104
- if self.tok_emb.weight is self.lm_head.weight:
105
- n_params -= self.tok_emb.weight.numel()
106
  return n_params
107
 
108
  def _init_weights(self, module):
@@ -114,105 +248,10 @@ class MiniTransformer(PreTrainedModel):
114
  torch.nn.init.zeros_(module.bias)
115
  elif isinstance(module, nn.Embedding):
116
  torch.nn.init.normal_(module.weight, mean=0.0, std=self.std)
117
-
118
- @staticmethod
119
- def top_k_top_p_filtering(
120
- logits: torch.Tensor,
121
- top_k: int = 50,
122
- top_p: float = 0.95,
123
- filter_value: float = float("-inf"),
124
- ):
125
- """
126
- Filters a distribution of logits using top-k and/or nucleus (top-p) filtering.
127
- """
128
- # top_k
129
- if top_k > 0:
130
- top_k = min(top_k, logits.size(-1))
131
- # Remove all logits that are not in the top k
132
- indices_to_remove = logits < torch.topk(logits, top_k, dim=-1).values[:, -1, None]
133
- logits[indices_to_remove] = filter_value
134
-
135
- # top_p (nucleus)
136
- if 0 < top_p < 1.0:
137
- sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
138
- cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
139
-
140
- # Remove tokens with cumulative probability above the threshold
141
- sorted_indices_to_remove = cumulative_probs > top_p
142
- # Shift the indices to the right to keep also the first token above the threshold
143
- sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
144
- sorted_indices_to_remove[:, 0] = False
145
-
146
- indices_to_remove = sorted_indices_to_remove.scatter(
147
- dim=1, index=sorted_indices, src=sorted_indices_to_remove
148
- )
149
- logits[indices_to_remove] = filter_value
150
-
151
- return logits
152
-
153
- def generate(
154
- self,
155
- input_ids: torch.LongTensor,
156
- max_new_tokens: int = 50,
157
- temperature: float = 0.5,
158
- top_k: int = 50,
159
- top_p: float = 0.95,
160
- eos_token_id: int = None,
161
- pad_token_id: int = 0,
162
- **kwargs
163
- ):
164
- """
165
- Naive token-by-token generation loop that uses top-k/top-p filtering and optional temperature.
166
-
167
- Args:
168
- input_ids (torch.LongTensor): shape (batch_size, sequence_length).
169
- max_new_tokens (int): max number of tokens to generate (beyond input_ids length).
170
- temperature (float): sampling temperature (>=0).
171
- top_k (int): Top-K sampling cutoff.
172
- top_p (float): Nucleus sampling cutoff.
173
- eos_token_id (int): If set, stop generation when this token is produced.
174
- pad_token_id (int): If set, can be used to pad sequences. (Not fully used here.)
175
- kwargs: Unused arguments (like num_beams) for compatibility.
176
-
177
- Returns:
178
- torch.LongTensor: shape (batch_size, sequence_length + generated_tokens).
179
- """
180
- device = input_ids.device
181
- print("1=====================")
182
- print(tokenizer.decode(input_ids[0], skip_special_tokens=True))
183
- print("1=====================")
184
-
185
- # We'll accumulate new tokens into generated_ids
186
- generated_ids = input_ids.clone()
187
-
188
- for _ in range(max_new_tokens):
189
- # Forward pass to get logits for the last token
190
- outputs = self.forward(generated_ids)
191
- logits = outputs.logits[:, -1, :] # shape: (batch_size, vocab_size)
192
-
193
- # Scale logits by temperature
194
- if temperature != 1.0:
195
- logits = logits / temperature
196
-
197
- # Filter logits using top-k and/or top-p
198
- logits = self.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
199
-
200
- # Convert to probabilities
201
- probabilities = F.softmax(logits, dim=-1)
202
-
203
- # Sample from the distribution
204
- next_token = torch.multinomial(probabilities, num_samples=1) # (batch_size, 1)
205
-
206
- # Append next token
207
- generated_ids = torch.cat([generated_ids, next_token], dim=1)
208
-
209
- # If eos_token_id is set and any sample produced it, we optionally could break early
210
- if eos_token_id is not None:
211
- # Check if all sequences in the batch ended
212
- # or if you want to do a more fine-grained approach
213
- if (next_token == eos_token_id).all():
214
- break
215
- print("2=====================")
216
- print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
217
- print("2=====================")
218
- return generated_ids
 
1
+ import json
2
+ import math
3
+
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
 
8
+ from transformers import PreTrainedModel, PretrainedConfig
 
 
 
 
 
9
  from .configuration_minitransformer import MiniTransformerConfig
 
 
 
 
 
10
  try:
11
+ from flash_attn import flash_attn_func
 
12
  except ImportError as e:
13
  print(
14
+ f"Unable to import Triton-based flash attention: {e}. No alternative currently available."
15
+ )
16
+
17
+
18
def precompute_freqs_cis(head_dim: int, max_seq_len: int, theta: float = 10000.0):
    """Precompute the complex rotary-embedding table for RoPE.

    Returns a complex tensor of shape [max_seq_len, head_dim // 2] whose
    entry (t, j) is e^{i * t * theta^(-2j/head_dim)}.
    """
    # Inverse frequency for every even dimension index.
    inv_freq = theta ** -(torch.arange(0, head_dim, 2).float() / head_dim)

    # Angle for each (position, frequency) pair via an outer product.
    positions = torch.arange(max_seq_len, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)

    # Unit-magnitude complex exponentials e^{i * angle}.
    return torch.polar(torch.ones_like(angles), angles)
33
+
34
+
35
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Slice the rotary table to x's sequence length and add broadcast dims.

    x is [B, n_heads, seq_len, head_dim//2] (as complex); the table is
    reshaped from [max_seq_len, half_dim] to [1, 1, seq_len, half_dim] so it
    broadcasts over batch and heads.
    """
    current_len = x.size(2)
    return freqs_cis[:current_len].view(1, 1, current_len, -1)
44
+
45
+
46
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate query/key tensors by the precomputed complex exponentials (RoPE).

    Returns (xq_rotated, xk_rotated) with the same shapes and dtypes as the
    inputs; rotation is done in float32 complex arithmetic internally.
    """
    # Group the last dim in pairs and reinterpret as complex numbers:
    # [B, n_heads, seq_len, head_dim] -> [B, n_heads, seq_len, head_dim//2]
    q_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    k_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    # Broadcast the table to [1, 1, seq_len, head_dim//2].
    table = reshape_for_broadcast(freqs_cis, q_complex)

    # Complex multiplication performs the rotation.
    q_complex = q_complex * table
    k_complex = k_complex * table

    # Back to real pairs, restoring the original layout and dtype.
    q_out = torch.view_as_real(q_complex).reshape(*xq.shape)
    k_out = torch.view_as_real(k_complex).reshape(*xk.shape)
    return q_out.type_as(xq), k_out.type_as(xk)
67
+
68
+
69
def nearest_power_of_two(x: int, round_up: bool = False) -> int:
    """Return the power of two at or below x (default) or at or above x.

    A power of two is returned unchanged either way. Propagates the
    ValueError from math.log2 for x <= 0.
    """
    exponent = math.ceil(math.log2(x)) if round_up else math.floor(math.log2(x))
    return 1 << exponent
 
 
 
73
 
74
+
75
class Attention(nn.Module):
    """Multi-head attention with a fused QKV projection and the
    FlashAttention-2 kernel, positionally encoded with either RoPE
    (config.use_alibi == False) or ALiBi slopes (config.use_alibi == True).
    """

    def __init__(self, config):
        super(Attention, self).__init__()
        self.dim, self.num_heads = config.dim, config.num_heads
        assert config.dim % config.num_heads == 0, f"dim ({self.dim}) must be divisible num_heads ({self.num_heads})"
        self.head_dim = config.dim // config.num_heads

        # Fused projection producing Q, K, V in a single matmul.
        self.c_attn = nn.Linear(self.dim, 3*self.dim, bias=config.bias)
        self.c_proj = nn.Linear(config.dim, config.dim, bias=config.bias)
        self.c_proj.SCALE_INIT = 1  # marker attribute for scaled weight init

        # NOTE(review): slopes are materialized directly on "cuda" inside
        # _get_alibi_slopes; this will fail on CPU-only hosts — confirm intended.
        self.alibi_slopes = self._get_alibi_slopes(self.num_heads) if config.use_alibi else None
        self.window_size = config.window_size
        self.softcap = config.softcap

        self.dropout = config.dropout
        self.resid_dropout = nn.Dropout(self.dropout)

    def _generate_slopes(self, n: int):
        # Geometric ALiBi slope sequence; n is expected to be a power of two here.
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * (start**i) for i in range(n)]

    def _get_alibi_slopes(self, num_heads: int, interpolation_factor: float = 0.25):
        # If n_heads is a power of 2, generate slopes directly
        if math.log2(num_heads).is_integer():
            slopes = self._generate_slopes(num_heads)
        else:
            # Get slopes for the nearest power of two
            n = nearest_power_of_two(num_heads, round_up=False)
            slopes_power_of_two = self._generate_slopes(n)

            # Generate extra slopes
            extra_slopes = self._generate_slopes(2 * n)
            extra_slopes_trunc = extra_slopes[0::2][: num_heads - n]
            slopes = slopes_power_of_two + extra_slopes_trunc
        slopes = torch.tensor(slopes, device=torch.device("cuda"))
        slopes = slopes * interpolation_factor  # https://arxiv.org/pdf/2310.13017
        return slopes

    def forward(
        self,
        x: torch.Tensor = None,
        q: torch.Tensor = None,
        k: torch.Tensor = None,
        v: torch.Tensor = None,
        freqs_cis: torch.Tensor = None,
    ) -> torch.Tensor:
        """Self-attention when x is given; cross-attention when q/k/v are given.

        Returns a tensor of shape [bsz, q_len, dim].
        """
        if x is not None:
            q = k = v = x
        if any(t is None for t in [q, k, v]):
            raise ValueError("Must provide either x for self-attention or q/k/v for cross-attention.")

        bsz, q_len, dim = q.shape
        _, k_len, _ = k.shape
        _, v_len, _ = v.shape

        if q is k and k is v:
            # Self-attention: one fused matmul covers all three projections.
            qkv = self.c_attn(q)
            q, k, v = torch.chunk(qkv, 3, dim=2)
        else:
            # Cross-attention fix: the previous code unconditionally computed
            # self.c_attn(x), which crashed with x=None. Project each stream
            # with its own slice of the fused weight instead (equivalent math).
            w_q, w_k, w_v = self.c_attn.weight.chunk(3, dim=0)
            if self.c_attn.bias is None:
                b_q = b_k = b_v = None
            else:
                b_q, b_k, b_v = self.c_attn.bias.chunk(3, dim=0)
            q = F.linear(q, w_q, b_q)
            k = F.linear(k, w_k, b_k)
            v = F.linear(v, w_v, b_v)

        q = q.view(bsz, q_len, self.num_heads, self.head_dim)
        k = k.view(bsz, k_len, self.num_heads, self.head_dim)
        v = v.view(bsz, v_len, self.num_heads, self.head_dim)

        if self.alibi_slopes is None:  # Use either ALiBi or RoPE
            q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)

        y = flash_attn_func(  # https://arxiv.org/pdf/2307.08691
            q=q, k=k, v=v,
            dropout_p=self.dropout if self.training else 0.0,
            causal=True,
            window_size=(self.window_size, 0),  # Set to config.seq_len if full attention
            alibi_slopes=self.alibi_slopes,  # https://arxiv.org/pdf/2108.12409
            softcap=self.softcap,  # https://arxiv.org/pdf/2408.00118
        )

        y = y.contiguous().view(bsz, q_len, -1)
        y = self.resid_dropout(self.c_proj(y))
        return y
153
+
154
+
155
class AttentionLayer(nn.Module):
    """Pre-norm transformer block: RMS-normed attention and RMS-normed MLP,
    each wrapped in a residual connection."""

    def __init__(self, config) -> None:
        super(AttentionLayer, self).__init__()
        self.attn_norm = nn.RMSNorm(config.dim)
        self.attn = Attention(config=config)
        self.mlp_norm = nn.RMSNorm(config.dim)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor=None) -> torch.Tensor:
        """Run attention then MLP, adding each sub-layer's output back to x."""
        attn_out = self.attn(x=self.attn_norm(x), freqs_cis=freqs_cis)
        x = x + attn_out
        mlp_out = self.mlp(self.mlp_norm(x))
        return x + mlp_out
167
+
168
class MLP(nn.Module):
    """Gated feed-forward network (GLU variant, https://arxiv.org/pdf/2002.05202).

    Expands config.dim by config.mlp_scale, gates with tanh-approximated GELU,
    projects back down, then applies dropout.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.dim
        self.intermediate_size = config.dim * config.mlp_scale
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Return dropout(down_proj(gelu_tanh(gate_proj(x)) * up_proj(x)))."""
        activated_gate = F.gelu(self.gate_proj(x), approximate="tanh")
        hidden = activated_gate * self.up_proj(x)
        return self.dropout(self.down_proj(hidden))
187
 
188
  class MiniTransformer(PreTrainedModel):
189
+
190
  config_class = MiniTransformerConfig
191
 
192
  def __init__(self, config) -> None:
193
+ super(Transformer, self).__init__(config)
194
  self.num_layers = config.num_layers
195
  assert config.dim % config.num_heads == 0, f"dim ({self.dim}) must be divisible num_heads ({self.num_heads})"
196
  self.head_dim = config.dim // config.num_heads
 
197
 
198
  # From pytorch/pytorch#123411, we set persistent=True for torch.compile and PP compatibility
199
  self.register_buffer("freqs_cis", precompute_freqs_cis(
 
207
 
208
  self.layers = nn.ModuleList()
209
  for _ in range(self.num_layers):
210
+ layer = AttentionLayer(config=config)
211
  self.layers.append(layer)
212
 
213
  self.norm = nn.RMSNorm(config.dim)
214
  self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=config.bias)
 
215
 
216
+ if config.weight_tying:
217
+ self.tok_emb.weight = self.lm_head.weight
218
+
219
+ self.std = config.dim ** -0.5
220
  self.apply(self._init_weights)
221
  print("Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,))
222
 
223
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
224
+ tok_emb = self.tok_emb(x)
225
+ x = self.dropout(tok_emb)
 
 
 
 
 
226
 
227
  for layer in self.layers:
228
+ x = layer(x, self.freqs_cis)
229
+
230
+ y_hat = self.lm_head(self.norm(x))
231
+
232
+ return y_hat
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
 
234
  def _get_num_params(self):
235
  n_params = sum(p.numel() for p in self.parameters())
236
+
237
  if hasattr(self, "pos_emb") and self.pos_emb is not None:
238
  n_params -= self.pos_emb.weight.numel()
239
+
 
240
  return n_params
241
 
242
  def _init_weights(self, module):
 
248
  torch.nn.init.zeros_(module.bias)
249
  elif isinstance(module, nn.Embedding):
250
  torch.nn.init.normal_(module.weight, mean=0.0, std=self.std)
251
+ elif isinstance(module, Attention):
252
+ torch.nn.init.xavier_normal_(module.c_attn.weight)
253
+ torch.nn.init.xavier_normal_(module.c_proj.weight)
254
+ if module.c_attn.bias is not None:
255
+ torch.nn.init.zeros_(module.c_attn.bias)
256
+ if module.c_proj.bias is not None:
257
+ torch.nn.init.zeros_(module.c_proj.bias)