Marmik committed on
Commit
97a6b05
·
verified ·
1 Parent(s): 585483e

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. config.json +5 -1
  2. generation_config.json +1 -1
  3. modeling_tiny_gpt.py +501 -0
config.json CHANGED
@@ -2,6 +2,10 @@
2
  "architectures": [
3
  "TinyGPTForCausalLM"
4
  ],
 
 
 
 
5
  "attn_dropout": 0.0,
6
  "attn_eps": 1e-06,
7
  "d_head": 64,
@@ -15,6 +19,6 @@
15
  "norm_eps": 1e-06,
16
  "top_k": null,
17
  "torch_dtype": "float32",
18
- "transformers_version": "4.47.1",
19
  "vocab_size": 50257
20
  }
 
2
  "architectures": [
3
  "TinyGPTForCausalLM"
4
  ],
5
+ "auto_map": {
6
+ "AutoConfig": "modeling_tiny_gpt.TinyGPTConfig",
7
+ "AutoModelForCausalLM": "modeling_tiny_gpt.TinyGPTForCausalLM"
8
+ },
9
  "attn_dropout": 0.0,
10
  "attn_eps": 1e-06,
11
  "d_head": 64,
 
19
  "norm_eps": 1e-06,
20
  "top_k": null,
21
  "torch_dtype": "float32",
22
+ "transformers_version": "4.53.2",
23
  "vocab_size": 50257
24
  }
generation_config.json CHANGED
@@ -1,4 +1,4 @@
1
  {
2
  "_from_model_config": true,
3
- "transformers_version": "4.47.1"
4
  }
 
1
  {
2
  "_from_model_config": true,
3
+ "transformers_version": "4.53.2"
4
  }
modeling_tiny_gpt.py ADDED
@@ -0,0 +1,501 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from transformers import PreTrainedModel, GenerationMixin, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import copy
import sys
import os
# Make the directory three levels above this file importable. Presumably needed
# when running this file from inside the original training repo; it is a no-op
# side effect when loaded as Hugging Face remote code. NOTE(review): confirm
# this is still required — it mutates sys.path at import time.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

14
+
15
@dataclass
class ModelConfig:
    """Inference-time hyperparameters for TinyGPT (no training-only fields, dropout off)."""
    vocab_size:int=50_257  # GPT-2 BPE vocab size; NOTE(review): ModelArgs uses 50_256 — confirm which is canonical
    d_model: int = 768  # embedding / residual-stream width
    d_head: int = 64  # per-attention-head dimension
    n_heads:int = 12  # number of attention heads
    n_layers:int = 5  # number of transformer blocks
    max_seq_len:int = 1024  # maximum sequence length (sizes the RoPE table)
    attn_eps:float = 1e-6  # eps for the pre-attention RMSNorm
    ffn_eps:float = 1e-6  # eps for the pre-FFN RMSNorm
    norm_eps:float = 1e-6  # eps for the final RMSNorm
    attn_dropout:float = 0.0  # attention dropout (disabled for inference)
    device:str = 'cuda' if torch.cuda.is_available() else 'cpu'  # resolved once at import time
29
+
30
+
31
@dataclass
class ModelArgs:
    """Training-time hyperparameters for TinyGPT.

    NOTE(review): several fields (n_kv_heads, window_size) are not consumed by
    the model classes in this file — presumably leftovers from a GQA /
    sliding-window variant; confirm before relying on them.
    """
    vocab_size:int=50_256  # NOTE(review): ModelConfig uses 50_257 — confirm which is canonical
    d_model: int = 768  # embedding / residual-stream width
    d_head: int = 64  # per-attention-head dimension
    n_heads:int = 12  # number of attention heads
    n_kv_heads:int = 8  # number of key-value heads (unused by this file's model)
    n_layers:int = 5  # number of transformer blocks
    train_epochs:int = 1  # number of training epochs
    batch_size:int = 256  # batch size
    val_epochs:int = 1  # number of validation epochs
    window_size:int = 128  # attention window size (unused by this file's model)
    seq_len:int = 512  # training sequence length
    max_seq_len:int = 1024  # maximum sequence length (sizes the RoPE table)
    max_lr:float = 1e-3  # maximum learning rate
    val_steps:int = 300  # validation interval in steps (250-500 depending on total_train_steps)
    save_steps:int = 1000  # checkpoint interval; total_train_steps = total_toks // (batch_size * seq_len)
    # do not change
    clip:int = 1  # gradient clipping max norm
    attn_dropout:float = 0.1  # attention dropout
    dropout:float = 0.1  # residual dropout
    beta1:float = 0.9  # Adam beta1
    beta2:float = 0.999  # Adam beta2
    device:str = 'cuda' if torch.cuda.is_available() else 'cpu'  # resolved once at import time
    wandb_project:str = 'dense'  # Weights & Biases project name
    norm_eps:float = 1e-6  # eps for the final RMSNorm
    attn_eps:float = 1e-6  # eps for the pre-attention RMSNorm
    ffn_eps:float = 1e-6  # eps for the pre-FFN RMSNorm
59
+
60
+
61
class TinyGPTConfig(PretrainedConfig):
    """Hugging Face config wrapper for TinyGPT; defaults mirror ModelConfig."""

    model_type = "tiny_gpt"

    def __init__(self,
                 vocab_size = ModelConfig.vocab_size,
                 d_model = ModelConfig.d_model,
                 d_head = ModelConfig.d_head,
                 n_heads = ModelConfig.n_heads,
                 n_layers = ModelConfig.n_layers,
                 max_seq_len = ModelConfig.max_seq_len,
                 norm_eps = ModelConfig.norm_eps,
                 attn_eps = ModelConfig.attn_eps,
                 ffn_eps = ModelConfig.ffn_eps,
                 attn_dropout = ModelConfig.attn_dropout,
                 device = ModelConfig.device,
                 **kwargs
                 ):
        """Store TinyGPT hyperparameters on the config object.

        Forces `auto_map` into kwargs before the parent constructor runs so
        every serialized config points back at this remote-code module.
        """
        kwargs["auto_map"] = {
            "AutoConfig": "modeling_tiny_gpt.TinyGPTConfig",
            "AutoModelForCausalLM": "modeling_tiny_gpt.TinyGPTForCausalLM",
        }
        super().__init__(**kwargs)
        # Mirror every hyperparameter onto the instance, one attribute per field.
        for attr_name, attr_value in (
            ("vocab_size", vocab_size),
            ("d_model", d_model),
            ("d_head", d_head),
            ("n_heads", n_heads),
            ("n_layers", n_layers),
            ("max_seq_len", max_seq_len),
            ("norm_eps", norm_eps),
            ("attn_eps", attn_eps),
            ("ffn_eps", ffn_eps),
            ("attn_dropout", attn_dropout),
            ("device", device),
        ):
            setattr(self, attr_name, attr_value)
93
+
94
+
95
+
96
+
97
+
98
+
99
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization: rescales by 1/RMS of the last
    dimension (no mean-centering, no bias) and applies a learned gain."""

    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Args:
            dim (int): Size of the feature (last) dimension being normalized.
            eps (float, optional): Small constant added under the square root
                for numerical stability. Defaults to 1e-6.
        """
        super().__init__()
        self.eps = eps
        self.w = nn.Parameter(torch.ones(dim))

    def norm(self, x: torch.Tensor):
        """Divide x by the RMS of its last dimension (plus eps).

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Tensor with unit RMS along the last dimension.
        """
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: torch.Tensor):
        """Normalize in float32 for stability, cast back, apply the gain.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Normalized tensor, same shape and dtype as x.
        """
        normalized = self.norm(x.float()).type_as(x)
        return self.w * normalized
134
+
135
+
136
+
137
+
138
+ #----Rotary Embeddings---
139
+
140
def precompute_theta_pos_frequencies(d_head:int,seq_len:int,device:str,theta:float=10000.0):
    """
    Precompute the complex rotation table for Rotary Position Embeddings.

    Args:
        d_head (int): Dimension of one attention head (must be even).
        seq_len (int): Number of positions to precompute.
        device (str): Device on which to allocate the table.
        theta (float, optional): Frequency base. Defaults to 10000.0.

    Returns:
        torch.Tensor: Complex tensor of shape (seq_len, d_head/2) holding
        e^{i * m * theta_j} for position m and frequency index j.
    """
    assert d_head%2==0,"d_head must be even"
    # Inverse frequencies theta^{-2j/d_head} for j = 0 .. d_head/2 - 1.
    even_indices = torch.arange(0, d_head, 2, device=device)
    inv_freq = (1.0 / (theta ** (even_indices / d_head))).to(device)
    # Angle for every (position, frequency) pair.
    positions = torch.arange(seq_len, device=device)
    angles = torch.outer(positions, inv_freq).float()
    # Unit-magnitude complex numbers e^{i*angle}: shape (seq_len, d_head/2).
    return torch.polar(torch.ones_like(angles), angles)
163
+
164
+
165
def apply_rotary_embeddings(x:torch.Tensor,freq_complex:torch.Tensor,device:str):
    """
    Rotate query/key vectors with precomputed RoPE frequencies.

    Args:
        x (torch.Tensor): Input of shape (batch, seq_len, n_heads, head_dim).
        freq_complex (torch.Tensor): Complex rotations of shape (seq_len, head_dim/2).
        device (str): Device for the result.

    Returns:
        torch.Tensor: Rotated tensor, same shape and dtype as x.
    """
    # Pair up the last dimension and reinterpret each (even, odd) pair as one
    # complex number: (B, seq, h, head_dim) -> (B, seq, h, head_dim/2) complex.
    paired = x.float().reshape(*x.shape[:-1], -1, 2)
    as_complex = torch.view_as_complex(paired)
    # Broadcast per-position rotations over batch and heads:
    # (seq_len, head_dim/2) -> (1, seq_len, 1, head_dim/2).
    rotations = freq_complex.unsqueeze(0).unsqueeze(2)
    # Complex multiply rotates each pair; then flatten back to real layout.
    rotated = torch.view_as_real(as_complex * rotations)
    return rotated.reshape(*x.shape).type_as(x).to(device)
185
+
186
+
187
+
188
class SubLayerConnection(nn.Module):
    """Pre-norm residual wrapper: returns x + dropout(sublayer(LayerNorm(x)))."""

    def __init__(self,size,dropout):
        """
        Args:
            size (int): Feature size for the layer normalization.
            dropout (float): Dropout rate applied to the sublayer output.
        """
        super().__init__()
        self.norm=nn.LayerNorm(size)
        self.dropout=nn.Dropout(dropout)

    def forward(self,x,sublayer):
        """Apply norm, the sublayer, dropout, then the residual connection.

        Args:
            x (torch.Tensor): Input of shape (batch, seq_len, d_model).
            sublayer (nn.Module): Callable applied to the normalized input.

        Returns:
            torch.Tensor: Output of shape (batch, seq_len, d_model).
        """
        normalized = self.norm(x)
        transformed = sublayer(normalized)
        return x + self.dropout(transformed)
214
+
215
+
216
def clones(module,N):
    """
    Build an nn.ModuleList of N independent deep copies of *module*.

    Args:
        module (nn.Module): Module to replicate.
        N (int): Number of copies.

    Returns:
        nn.ModuleList: N deep-copied modules (parameters not shared).
    """
    replicas = [copy.deepcopy(module) for _ in range(N)]
    return nn.ModuleList(replicas)
228
+
229
+
230
+
231
+
232
class SimpleMultiHeadAttention(nn.Module):
    """Simple causal multi-head attention without GQA, sliding window, or KV cache.

    Uses one fused QKV projection; attention is computed with
    scaled_dot_product_attention when available, otherwise a manual
    softmax(QK^T/sqrt(d))V with an explicit lower-triangular mask.
    """

    def __init__(self, dim: int, num_heads: int, device, dropout: float = 0.0, bias: bool = False):
        """
        Initialize the SimpleMultiHeadAttention module.

        Args:
            dim (int): The dimensionality of the input and output features.
            num_heads (int): The number of attention heads; must divide dim.
            device: The device used for RoPE application and the manual causal mask.
            dropout (float, optional): Dropout rate for attention weights and the
                residual projection. Defaults to 0.0.
            bias (bool, optional): Whether to use bias in linear layers. Defaults to False.
        """
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} must be divisible by num_heads {num_heads}"

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.device = device
        self.dropout = dropout

        # Combined projection for queries, keys, and values (one matmul for all three).
        self.c_attn = nn.Linear(dim, 3 * dim, bias=bias)
        # Output projection back to the residual stream.
        self.c_proj = nn.Linear(dim, dim, bias=bias)

        # Dropout layers (attn_dropout only used on the manual path; SDPA takes
        # the rate directly).
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        # Use fused scaled_dot_product_attention (PyTorch >= 2.0) if available.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")

    def forward(self, x: torch.Tensor, freqs_complex: torch.Tensor = None, start_pos: int = 0):
        """
        Compute causal multi-head attention.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
            freqs_complex (torch.Tensor, optional): Complex RoPE rotations of shape
                (seq_len, head_dim/2); skipped when None. Defaults to None.
            start_pos (int, optional): Starting position (unused — no KV cache here).
                Defaults to 0.

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len, dim).
        """
        batch_size, seq_len, _ = x.shape

        # Calculate query, key, values for all heads in one projection, then
        # split the 3*dim output into three dim-sized chunks.
        q, k, v = self.c_attn(x).split(self.dim, dim=2)

        # Reshape and transpose for multi-head attention:
        # (batch_size, seq_len, num_heads, head_dim) -> (batch_size, num_heads, seq_len, head_dim)
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Apply rotary embeddings if provided.
        if freqs_complex is not None:
            # apply_rotary_embeddings expects (batch, seq_len, num_heads, head_dim),
            # so transpose to that layout and back.
            q_rotary = q.transpose(1, 2)  # (batch_size, seq_len, num_heads, head_dim)
            k_rotary = k.transpose(1, 2)  # (batch_size, seq_len, num_heads, head_dim)

            q_rotary = apply_rotary_embeddings(q_rotary, freqs_complex, device=self.device)
            k_rotary = apply_rotary_embeddings(k_rotary, freq_complex=freqs_complex, device=self.device)

            q = q_rotary.transpose(1, 2)  # Back to (batch_size, num_heads, seq_len, head_dim)
            k = k_rotary.transpose(1, 2)  # Back to (batch_size, num_heads, seq_len, head_dim)

        # Compute attention.
        if self.flash:
            # Fused kernel; causal masking and dropout handled internally.
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0,
                is_causal=True
            )
        else:
            # Manual implementation of attention.
            # Scaled dot-product scores: (B, H, S, S).
            attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

            # Lower-triangular causal mask; positions may not attend forward.
            # NOTE(review): mask is built on self.device — assumes x lives on the
            # same device; confirm for multi-device use.
            causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=self.device))
            causal_mask = causal_mask.view(1, 1, seq_len, seq_len)
            attn_scores = attn_scores.masked_fill(causal_mask == 0, float('-inf'))

            # Normalize to attention weights, then apply dropout.
            attn_weights = F.softmax(attn_scores, dim=-1)
            attn_weights = self.attn_dropout(attn_weights)

            # Weighted sum over values.
            y = torch.matmul(attn_weights, v)

        # Merge heads back: (B, H, S, head_dim) -> (batch_size, seq_len, dim).
        y = y.transpose(1, 2).contiguous().view(batch_size, seq_len, self.dim)

        # Output projection with residual-path dropout.
        y = self.resid_dropout(self.c_proj(y))

        return y

    def reset_cache(self):
        """Reset cache (no-op: this attention keeps no KV cache)."""
        pass
341
+
342
+
343
+
344
+
345
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward network: out(w_1(x) * SiLU(w_2(x)))."""

    def __init__(self,input_dim:int,hidden_dim:int):
        """
        Args:
            input_dim (int): Dimensionality of the input (and output) features.
            hidden_dim (int): Dimensionality of the hidden layer.

        Layers:
            w_1: value projection to the hidden dimension.
            w_2: gate projection to the hidden dimension.
            out: projection back to the input dimension.
        """
        super().__init__()
        self.w_1=nn.Linear(input_dim,hidden_dim)
        self.w_2=nn.Linear(input_dim,hidden_dim)
        self.out=nn.Linear(hidden_dim,input_dim)

    def forward(self,x:torch.Tensor):
        """Gate the value path with SiLU(w_2(x)) and project back down."""
        value = self.w_1(x)
        gate = F.silu(self.w_2(x))
        return self.out(value * gate)
368
+
369
+
370
+
371
class layer(nn.Module):
    """One pre-norm transformer block: RMSNorm -> attention -> residual,
    then RMSNorm -> SwiGLU FFN -> residual."""

    def __init__(self, d_model: int, n_heads: int, device, attn_eps: float, dropout: float, ffn_eps: float = 1e-6):
        """
        Args:
            d_model (int): Width of the input and output features.
            n_heads (int): Number of attention heads.
            device (str): Device to use (cpu or cuda).
            attn_eps (float): Eps for the pre-attention RMSNorm.
            dropout (float): Dropout rate inside the attention module.
            ffn_eps (float, optional): Eps for the pre-FFN RMSNorm. Defaults to 1e-6.
        """
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.device = device

        self.attention = SimpleMultiHeadAttention(
            dim=d_model,
            num_heads=n_heads,
            device=device,
            dropout=dropout,
            bias=False,
        )

        # Standard transformer sizing: FFN hidden width is 4x the model width.
        self.ffn = SwiGLUFFN(input_dim=d_model, hidden_dim=4 * d_model)

        self.attn_norm = RMSNorm(dim=d_model, eps=attn_eps)
        self.ffn_norm = RMSNorm(dim=d_model, eps=ffn_eps)

    def forward(self, x: torch.Tensor, freqs_complex: torch.Tensor, start_pos: int):
        """
        Run one transformer block.

        Args:
            x (torch.Tensor): Input of shape (batch_size, seq_len, d_model).
            freqs_complex (torch.Tensor): Complex RoPE rotations, shape (seq_len, d_head/2).
            start_pos (int): Starting position of the sequence.

        Returns:
            torch.Tensor: Output of shape (batch_size, seq_len, d_model).
        """
        # Pre-norm attention with residual connection.
        attended = x + self.attention(self.attn_norm(x), freqs_complex=freqs_complex, start_pos=start_pos)
        # Pre-norm feed-forward with residual connection (dense model: no router loss).
        return attended + self.ffn(self.ffn_norm(attended))
425
+
426
+
427
+
428
+
429
class tiny_gpt(nn.Module):
    """Dense decoder-only transformer backbone: token embedding, n_layers
    pre-norm blocks with RoPE, a final RMSNorm, and an LM head."""

    def __init__(self, args: ModelArgs):
        """
        Args:
            args (ModelArgs): Hyperparameters (vocab_size, d_model, n_heads,
                n_layers, max_seq_len, eps values, dropout, device).
        """
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        self.tok_embedding = nn.Embedding(self.vocab_size, args.d_model)
        self.layers = clones(
            layer(
                d_model=args.d_model,
                n_heads=args.n_heads,
                device=args.device,
                attn_eps=args.attn_eps,
                dropout=args.attn_dropout,
                ffn_eps=args.ffn_eps,
            ),
            self.n_layers,
        )
        self.norm = RMSNorm(args.d_model, eps=args.norm_eps)

        self.output = nn.Linear(in_features=args.d_model, out_features=self.vocab_size)

        # Register the RoPE table as a non-persistent buffer (was a plain
        # attribute): it now follows .to()/.cuda() moves with the module, while
        # staying out of the state_dict since it is deterministic and rebuilt
        # on construction.
        self.register_buffer(
            "freqs_complex",
            precompute_theta_pos_frequencies(
                d_head=args.d_model // args.n_heads,
                seq_len=args.max_seq_len,
                device=args.device,
            ),
            persistent=False,
        )

    def forward(self, x: torch.Tensor, start_pos: int):
        """
        Args:
            x (torch.Tensor): Token ids of shape (batch_size, seq_len).
            start_pos (int): Offset into the RoPE table for the first token.

        Returns:
            torch.Tensor: Logits of shape (batch_size, seq_len, vocab_size), float32.
        """
        batch_size, seq_len = x.shape
        h = self.tok_embedding(x)
        freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len]

        # Loop variable renamed from `layer` — it previously shadowed the
        # module-level `layer` class inside this method.
        for block in self.layers:
            h = block(h, freqs_complex=freqs_complex, start_pos=start_pos)

        h = self.norm(h)
        return self.output(h).float()
461
+
462
+
463
+
464
+
465
class TinyGPTForCausalLM(PreTrainedModel, GenerationMixin):
    """Hugging Face causal-LM wrapper around the tiny_gpt backbone."""

    config_class = TinyGPTConfig
    base_model_prefix = "gpt_model"

    def __init__(self, config):
        """Build the backbone from the serialized config's hyperparameters."""
        super().__init__(config)
        args = ModelConfig(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            d_head=config.d_head,
            n_heads=config.n_heads,
            n_layers=config.n_layers,
            max_seq_len=config.max_seq_len,
            norm_eps=config.norm_eps,
            attn_eps=config.attn_eps,
            ffn_eps=config.ffn_eps,
            attn_dropout=config.attn_dropout,
            device=config.device,
        )
        self.model = tiny_gpt(args=args)
        self.config = config
        self.post_init()

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Run the backbone and, when labels are given, compute the LM loss.

        Args:
            input_ids (torch.LongTensor): Token ids, shape (batch, seq_len).
            attention_mask: Accepted for HF compatibility but unused — the
                backbone applies only its internal causal mask.
            labels (torch.LongTensor, optional): Targets of shape (batch, seq_len);
                positions set to -100 are ignored, per HF convention.

        Returns:
            CausalLMOutputWithPast: logits (batch, seq_len, vocab_size) and,
            when labels were provided, the cross-entropy loss.
        """
        logits = self.model(input_ids, start_pos=0)

        # Fix: labels were previously accepted but ignored (loss stayed None),
        # which breaks Trainer-style fine-tuning. Compute the standard shifted
        # next-token cross-entropy when labels are provided.
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            attentions=None,
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """No KV cache: re-feed the full sequence at every generation step."""
        return {
            "input_ids": input_ids,
        }