VJyzCELERY committed
Commit 3920b5f · 1 Parent(s): cb5e58d

Added application file

app.py ADDED
@@ -0,0 +1,184 @@
+ import gradio as gr
+ import torch
+ import tiktoken
+ import os
+ from src.model import Config, GPT
+ from src.inference import GPTInfer
+ from huggingface_hub import hf_hub_download
+
+ os.environ['GRADIO_DEFAULT_CONCURRENCY_LIMIT'] = "1"
+
+ # Pick the best available device.
+ device = 'cpu'
+ if torch.cuda.is_available():
+     device = 'cuda'
+ elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+     device = 'mps'
+ print(f'using device: {device}')
+
+ # Download the checkpoint from the Hub and restore model weights.
+ model_path = hf_hub_download(
+     repo_id="VJyzCELERY/GPT2-GutenbergStoryGenerator",
+     filename="GPT2-GutenbergStoryGenerator.pt"
+ )
+ checkpoint = torch.load(model_path, weights_only=False)
+ model = GPT(config=checkpoint['config'])
+ model.load_state_dict(checkpoint['model'])
+ model = model.to(device)
+ token_encoder = tiktoken.get_encoding('gpt2')
+ generator = GPTInfer(model, token_encoder, device)
+
+ def generate_story(
+     prompt,
+     max_new_tokens=50,
+     seed=42,
+     temperature=0.8,
+     top_k=None,
+     top_p=0.9,
+     repetition_penalty=1.2,
+     frequency_penalty=0.6,
+     no_repeat_ngram_size=3,
+     longer_story=True,
+     context_window=512
+ ):
+     if not prompt.strip():
+         # This function is a generator, so outputs must be yielded;
+         # a plain `return value` would never reach the UI.
+         yield prompt, gr.update()
+         return
+
+     # The Top-K slider uses 0 for "disabled"; normalize 0 (or None) to None.
+     if not top_k or top_k <= 0:
+         top_k = None
+
+     output_text = prompt
+     last_piece = ""
+     # Lock the textbox and the button while generation streams.
+     yield gr.update(value=output_text, interactive=False), gr.update(interactive=False)
+     for piece in generator.generate(
+         prompt,
+         max_new_tokens=max_new_tokens,
+         seed=seed,
+         temperature=temperature,
+         top_k=top_k,
+         top_p=top_p,
+         repetition_penalty=repetition_penalty,
+         frequency_penalty=frequency_penalty,
+         no_repeat_ngram_size=no_repeat_ngram_size,
+         longer_story=longer_story,
+         context_window=context_window
+     ):
+         if piece == last_piece:
+             continue
+         last_piece = piece
+         output_text += piece
+         yield output_text, gr.update(interactive=False)
+
+     # Unlock the UI once generation finishes.
+     yield gr.update(value=output_text, interactive=True), gr.update(interactive=True)
+
+ with gr.Blocks(title="Story Generator") as demo:
+
+     gr.Markdown("# ✨ Story Generator ✨")
+     gr.Markdown(
+         "Type a prompt or the start of a story below. "
+         "Press **Generate** to continue the story. "
+         "You can edit the result and generate again to keep going."
+     )
+
+     story_box = gr.Textbox(
+         label="Story / Prompt",
+         lines=15,
+         placeholder="Write a prompt or the start of a story here...",
+     )
+
+     generate_btn = gr.Button("Generate Story", variant="primary")
+
+     with gr.Accordion("Generation Settings", open=False):
+         context_window = gr.Slider(
+             minimum=128,
+             maximum=2048,
+             value=512,
+             step=64,
+             label="Context Window (tokens to use from end of text)",
+             info="Limits how much previous text is used. Lower = faster but less context."
+         )
+
+         max_new_tokens = gr.Slider(
+             minimum=20,
+             maximum=2048,
+             value=1024,
+             step=10,
+             label="Max New Tokens"
+         )
+
+         seed = gr.Number(
+             value=42,
+             label="Seed"
+         )
+
+         temperature = gr.Slider(
+             minimum=0.1,
+             maximum=1.0,
+             value=0.8,
+             step=0.05,
+             label="Temperature"
+         )
+
+         top_k = gr.Slider(
+             minimum=0,
+             maximum=200,
+             value=0,
+             step=1,
+             label="Top-K (0 = disabled)"
+         )
+
+         top_p = gr.Slider(
+             minimum=0.0,
+             maximum=1.0,
+             value=0.9,
+             step=0.01,
+             label="Top-P"
+         )
+
+         repetition_penalty = gr.Slider(
+             minimum=1.0,
+             maximum=2.0,
+             value=1.2,
+             step=0.05,
+             label="Repetition Penalty"
+         )
+
+         frequency_penalty = gr.Slider(
+             minimum=0.0,
+             maximum=1.0,
+             value=0.6,
+             step=0.05,
+             label="Frequency Penalty"
+         )
+
+         no_repeat = gr.Slider(
+             minimum=1,
+             maximum=10,
+             value=3,
+             step=1,
+             label="No-Repeat N-gram Size"
+         )
+
+     generate_btn.click(
+         fn=generate_story,
+         inputs=[
+             story_box,
+             max_new_tokens,
+             seed,
+             temperature,
+             top_k,
+             top_p,
+             repetition_penalty,
+             frequency_penalty,
+             no_repeat,
+             # Hidden, always-checked component that pins longer_story=True.
+             gr.Checkbox(value=True, visible=False),
+             context_window
+         ],
+         outputs=[story_box, generate_btn]
+     )
+
+ # Run app
+ if __name__ == "__main__":
+     demo.launch(share=False)
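
For a quick smoke test outside the UI, the loaded generator can be driven directly; a minimal sketch, assuming the model/encoder setup above has already run (the prompt text is arbitrary):

    # Stream a short continuation to stdout, mirroring what the Gradio callback does.
    for piece in generator.generate("Once upon a time,", max_new_tokens=40, seed=0):
        print(piece, end="", flush=True)
    print()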
src/__pycache__/inference.cpython-312.pyc ADDED
Binary file (9.8 kB)

src/__pycache__/model.cpython-312.pyc ADDED
Binary file (13.9 kB)

src/__pycache__/trainer.cpython-312.pyc ADDED
Binary file (14 kB)
src/inference.py ADDED
@@ -0,0 +1,226 @@
+ import torch
+ import torch.nn.functional as F
+
+ def concat(prev, new):
+     # Insert a space when gluing two alphanumeric pieces together.
+     if prev and prev[-1].isalnum() and new and new[0].isalnum():
+         return prev + " " + new
+     return prev + new
+
+ class GPTInfer:
+     def __init__(self, model, token_encoder, device):
+         self.model = model
+         self.token_encoder = token_encoder
+         self.device = device
+         self.device_type = 'cuda' if device.startswith('cuda') else 'cpu'
+
+     def get_token_length(self, text):
+         return len(self.token_encoder.encode(text, allowed_special={"<|endoftext|>"}))
+
+     def apply_frequency_penalty_and_blocking(
+         self,
+         logits,
+         gen_tokens,
+         frequency_penalty=0.5,
+         no_repeat_ngram_size=3,
+     ):
+         logits = logits.clone().float()
+
+         # Frequency penalty: subtract count * penalty from every token
+         # already present in the generated sequence.
+         if frequency_penalty and frequency_penalty > 0.0:
+             counts = {}
+             for t in gen_tokens[0].tolist():
+                 counts[t] = counts.get(t, 0) + 1
+             if counts:
+                 vocab_size = logits.shape[-1]
+                 penalty = torch.zeros(vocab_size, dtype=logits.dtype, device=logits.device)
+                 for tok, c in counts.items():
+                     if 0 <= tok < vocab_size:
+                         penalty[tok] = float(c) * float(frequency_penalty)
+                 logits = logits - penalty.unsqueeze(0)
+
+         # N-gram blocking: ban any token that would complete an n-gram
+         # already seen after the current (n-1)-token prefix.
+         if no_repeat_ngram_size and no_repeat_ngram_size > 0:
+             n = no_repeat_ngram_size
+             cur = gen_tokens[0].tolist()
+             if len(cur) >= n - 1:
+                 last_prefix = tuple(cur[-(n - 1):]) if n > 1 else tuple()
+                 for i in range(len(cur) - (n - 1)):
+                     if tuple(cur[i:i + (n - 1)]) == last_prefix and i + (n - 1) < len(cur):
+                         banned_token = cur[i + (n - 1)]
+                         if 0 <= banned_token < logits.shape[-1]:
+                             logits[0, banned_token] = -1e9
+
+         return logits
+
+     def sample_next_token(
+         self,
+         logits,
+         gen_tokens,
+         seed_rng,
+         temperature=0.8,
+         top_k=None,
+         top_p=0.9,
+         repetition_penalty=1.2,
+         frequency_penalty=0.5,
+         no_repeat_ngram_size=3,
+         recent_tokens_window=200,
+     ):
+         logits = logits.clone().float()
+
+         # Repetition penalty over the most recent tokens.
+         recent = gen_tokens[0, -recent_tokens_window:].tolist()
+         if repetition_penalty is not None and repetition_penalty != 1.0:
+             for t in set(recent):
+                 if 0 <= t < logits.shape[-1]:
+                     logits[0, t] /= float(repetition_penalty)
+
+         logits = self.apply_frequency_penalty_and_blocking(
+             logits,
+             gen_tokens,
+             frequency_penalty=frequency_penalty,
+             no_repeat_ngram_size=no_repeat_ngram_size,
+         )
+
+         if temperature is not None and temperature != 1.0:
+             logits = logits / float(temperature)
+
+         sorted_logits, sorted_idx = torch.sort(logits, descending=True)
+         sorted_probs = F.softmax(sorted_logits, dim=-1)
+
+         # Top-k: keep only the k most likely tokens.
+         if top_k is not None:
+             k = min(int(top_k), sorted_logits.shape[-1])
+             sorted_logits = sorted_logits[:, :k]
+             sorted_idx = sorted_idx[:, :k]
+             sorted_probs = sorted_probs[:, :k]
+
+         # Top-p (nucleus): keep the smallest prefix whose cumulative mass <= top_p.
+         if top_p is not None and 0.0 < top_p < 1.0:
+             cum_probs = torch.cumsum(sorted_probs, dim=-1)
+             mask = cum_probs <= top_p
+             if not mask.any():
+                 mask[0, 0] = True
+             keep_count = int(mask.sum(dim=-1).item())
+             sorted_probs = sorted_probs[:, :keep_count]
+             sorted_idx = sorted_idx[:, :keep_count]
+
+         # Renormalize the truncated distribution and sample from it.
+         sorted_probs = sorted_probs / (sorted_probs.sum(dim=-1, keepdim=True) + 1e-12)
+         next_index_in_sorted = torch.multinomial(sorted_probs, 1, generator=seed_rng)
+         next_tok = sorted_idx.gather(-1, next_index_in_sorted)
+
+         return int(next_tok.item())
+
+     def generate(
+         self,
+         prompt,
+         max_new_tokens=50,
+         seed=42,
+         longer_story=True,
+         temperature=0.8,
+         top_k=None,
+         top_p=0.9,
+         repetition_penalty=1.2,
+         frequency_penalty=0.5,
+         no_repeat_ngram_size=3,
+         context_window=None,
+         stream=True,
+     ):
+         self.model.eval()
+
+         tokens = self.token_encoder.encode(prompt)
+         if context_window is not None and len(tokens) > context_window:
+             tokens = tokens[-context_window:]
+
+         tokens = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(self.device)
+         gen_tokens = tokens.clone()
+
+         if seed is not None:
+             sample_rng = torch.Generator(device=self.device).manual_seed(seed)
+         else:
+             sample_rng = torch.Generator(device=self.device)
+
+         eos_id = self.token_encoder.encode("<|endoftext|>", allowed_special={"<|endoftext|>"})[0]
+         context_len = self.model.config.context_length
+         new_tokens_generated = 0
+         HARD_MAX_TOTAL = context_len + max_new_tokens + 10
+
+         while new_tokens_generated < max_new_tokens and gen_tokens.shape[1] < HARD_MAX_TOTAL:
+             # Condition on at most the last context_len tokens.
+             if gen_tokens.shape[1] > context_len:
+                 idx_cond = gen_tokens[:, -context_len:]
+             else:
+                 idx_cond = gen_tokens
+
+             with torch.no_grad():
+                 try:
+                     # Prefer bfloat16 autocast; fall back to full precision if unsupported.
+                     with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
+                         logits, _ = self.model(idx_cond)
+                 except Exception:
+                     logits, _ = self.model(idx_cond)
+
+             next_logits = logits[:, -1:, :].squeeze(1)
+
+             # Discourage an immediate end-of-text at the start of a story.
+             if longer_story and new_tokens_generated < 5:
+                 next_logits[0, eos_id] = next_logits[0, eos_id] / 4.0
+
+             next_token_id = self.sample_next_token(
+                 logits=next_logits,
+                 gen_tokens=gen_tokens,
+                 seed_rng=sample_rng,
+                 temperature=temperature,
+                 top_k=top_k,
+                 top_p=top_p,
+                 repetition_penalty=repetition_penalty,
+                 frequency_penalty=frequency_penalty,
+                 no_repeat_ngram_size=no_repeat_ngram_size,
+                 recent_tokens_window=200,
+             )
+
+             if next_token_id == eos_id:
+                 break
+
+             next_tok_tensor = torch.tensor([[next_token_id]], dtype=torch.long).to(self.device)
+             gen_tokens = torch.cat([gen_tokens, next_tok_tensor], dim=1)
+             new_tokens_generated += 1
+
+             if stream:
+                 yield self.token_encoder.decode([next_token_id], errors='ignore')
+
+         if not stream:
+             yield self.token_encoder.decode(gen_tokens[0, :].tolist(), errors='ignore')
+
+     def print_stream(
+         self,
+         prompt,
+         max_new_tokens=200,
+         seed=42,
+         longer_story=True,
+         temperature=0.8,
+         top_k=None,
+         top_p=0.9,
+         repetition_penalty=1.2,
+         frequency_penalty=0.6,
+         no_repeat_ngram_size=3,
+         context_window=512,
+     ):
+         text = prompt
+         last_piece = ""
+         print(prompt, end="", flush=True)
+         for piece in self.generate(
+             prompt,
+             max_new_tokens=max_new_tokens,
+             seed=seed,
+             longer_story=longer_story,
+             temperature=temperature,
+             top_k=top_k,
+             top_p=top_p,
+             repetition_penalty=repetition_penalty,
+             frequency_penalty=frequency_penalty,
+             no_repeat_ngram_size=no_repeat_ngram_size,
+             context_window=context_window,
+         ):
+             if piece == last_piece:
+                 continue
+             last_piece = piece
+             text = concat(text, piece)
+             print(piece, end="", flush=True)
+         return text
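
A minimal usage sketch for GPTInfer on its own, assuming a checkpoint restored as in app.py above (the local checkpoint path here is illustrative, not part of the commit):

    import tiktoken
    import torch
    from src.model import GPT
    from src.inference import GPTInfer

    checkpoint = torch.load("GPT2-GutenbergStoryGenerator.pt", weights_only=False)
    model = GPT(config=checkpoint['config'])
    model.load_state_dict(checkpoint['model'])
    infer = GPTInfer(model, tiktoken.get_encoding('gpt2'), 'cpu')
    # print_stream echoes tokens as they are sampled and returns the full text.
    story = infer.print_stream("The old lighthouse keeper", max_new_tokens=60, seed=7)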
src/model.py ADDED
@@ -0,0 +1,179 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as f
+ from dataclasses import dataclass
+ import inspect
+
+ @dataclass
+ class Config:
+     context_length: int = 1024
+     vocab_size: int = 50257
+     num_layers: int = 12
+     embedding_dim: int = 768
+     num_heads: int = 12
+
+ class MultiHeadAttention(nn.Module):
+     def __init__(self, config: Config, masked=False):
+         super().__init__()
+         self.num_heads = config.num_heads
+         self.masked = masked
+         self.embedding_dim = config.embedding_dim
+         # One fused projection produces Q, K, and V together.
+         self.c_attention = nn.Linear(config.embedding_dim, 3 * config.embedding_dim)
+         self.c_projection = nn.Linear(config.embedding_dim, config.embedding_dim)
+         self.c_projection.SCALE_INIT = 1.0
+
+     def forward(self, x):
+         B, T, C = x.shape
+         QKV = self.c_attention(x)
+         Query_q, Key_k, Value_v = QKV.split(self.embedding_dim, dim=-1)
+         # Reshape to (B, num_heads, T, head_dim).
+         Query_q = Query_q.view(B, T, self.num_heads, self.embedding_dim // self.num_heads).transpose(1, 2)
+         Key_k = Key_k.view(B, T, self.num_heads, self.embedding_dim // self.num_heads).transpose(1, 2)
+         Value_v = Value_v.view(B, T, self.num_heads, self.embedding_dim // self.num_heads).transpose(1, 2)
+
+         out = f.scaled_dot_product_attention(Query_q, Key_k, Value_v, is_causal=self.masked)
+         out = out.transpose(1, 2).contiguous().view(B, T, C)
+         return self.c_projection(out)
+
+ class MLP(nn.Module):
+     def __init__(self, config: Config):
+         super().__init__()
+         self.c_fc = nn.Linear(config.embedding_dim, 4 * config.embedding_dim)
+         self.gelu = nn.GELU(approximate='tanh')
+         self.c_projection = nn.Linear(4 * config.embedding_dim, config.embedding_dim)
+         self.c_projection.SCALE_INIT = 1.0
+
+     def forward(self, x):
+         x = self.c_fc(x)
+         x = self.gelu(x)
+         x = self.c_projection(x)
+         return x
+
+ class DecoderBlock(nn.Module):
+     def __init__(self, config: Config):
+         """Decoder block without the encoder cross-attention."""
+         super().__init__()
+         self.masked_attention = MultiHeadAttention(config, masked=True)
+         self.layer_norm1 = nn.LayerNorm(config.embedding_dim)
+         # self.attention = MultiHeadAttention(config, masked=False)
+         # self.layer_norm2 = nn.LayerNorm(config.embedding_dim)
+         self.mlp = MLP(config)
+         self.layer_norm3 = nn.LayerNorm(config.embedding_dim)
+
+     def forward(self, x):
+         # Pre-norm residual blocks.
+         x = x + self.masked_attention(self.layer_norm1(x))
+         # x = x + self.attention(self.layer_norm2(x))
+         x = x + self.mlp(self.layer_norm3(x))
+         return x
+
+ class TransformerDecoder(nn.Module):
+     def __init__(self, config: Config):
+         super().__init__()
+         self.config = config
+         self.word_token_embedding = nn.Embedding(self.config.vocab_size, self.config.embedding_dim)
+         self.word_position_embedding = nn.Embedding(self.config.context_length, self.config.embedding_dim)
+         layers = [DecoderBlock(config) for _ in range(config.num_layers)]
+         self.hidden_layers = nn.Sequential(*layers)
+         self.layer_norm = nn.LayerNorm(self.config.embedding_dim)
+
+     def forward(self, idx):
+         B, T = idx.shape
+         pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
+         pos_embed = self.word_position_embedding(pos)
+         token_embed = self.word_token_embedding(idx)
+         x = pos_embed + token_embed
+         x = self.hidden_layers(x)
+         x = self.layer_norm(x)
+         return x
+
+ class GPT(nn.Module):
+     def __init__(self, config: Config):
+         super().__init__()
+         self.config = config
+         self.transformerDecoder = TransformerDecoder(config)
+         self.language_modeling_head = nn.Linear(config.embedding_dim, config.vocab_size, bias=False)
+         # Weight tying between the token embedding and the output head.
+         self.transformerDecoder.word_token_embedding.weight = self.language_modeling_head.weight
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             std = 0.02
+             # Scale residual projections down by 1/sqrt(2 * num_layers).
+             if hasattr(module, 'SCALE_INIT'):
+                 std /= (2 * self.config.num_layers) ** 0.5
+             torch.nn.init.normal_(module.weight, mean=0, std=std)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0, std=0.02)
+
+     def forward(self, idx, targets=None):
+         x = self.transformerDecoder(idx)
+         logits = self.language_modeling_head(x)
+         loss = None
+         if targets is not None:
+             loss = f.cross_entropy(logits.view(-1, logits.shape[-1]), targets.view(-1))
+         return logits, loss
+
+     @torch.no_grad()
+     def generate(self, idx, max_new_tokens=50, temperature=0.8, top_k=None, do_sample=False, eos_token_id=None):
+         self.eval()
+
+         B, T = idx.shape
+         context_len = self.config.context_length
+
+         if T > context_len:
+             idx = idx[:, -context_len:]
+
+         generated = idx.clone()
+
+         for _ in range(max_new_tokens):
+             input_ids = generated[:, -context_len:]
+
+             logits, _ = self.forward(input_ids, targets=None)
+             next_logits = logits[:, -1, :]
+
+             if temperature != 1.0 and temperature > 0.0:
+                 next_logits = next_logits / temperature
+
+             if do_sample:
+                 if top_k is not None and top_k > 0:
+                     vals, idxs = next_logits.topk(top_k, dim=-1)
+                     min_vals = vals[:, -1].unsqueeze(-1)
+                     mask = next_logits < min_vals
+                     next_logits = next_logits.masked_fill(mask, float('-inf'))
+
+                 probs = torch.softmax(next_logits, dim=-1)
+                 next_token = torch.multinomial(probs, num_samples=1)
+             else:
+                 next_token = torch.argmax(next_logits, dim=-1, keepdim=True)
+
+             generated = torch.cat([generated, next_token], dim=1)
+
+             # Stop once every sequence in the batch contains an EOS token.
+             if eos_token_id is not None:
+                 if (generated == eos_token_id).any(dim=1).all():
+                     break
+
+         return generated
+
+     def configure_optimizer(self, weight_decay, lr, device_type, master_process):
+         param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
+
+         # Weight-decay 2D+ tensors (matmul weights, embeddings); skip biases and norms.
+         decay_params = [p for pn, p in param_dict.items() if p.dim() >= 2]
+         nodecay_params = [p for pn, p in param_dict.items() if p.dim() < 2]
+         optim_groups = [
+             {'params': decay_params, 'weight_decay': weight_decay},
+             {'params': nodecay_params, 'weight_decay': 0.0}
+         ]
+         num_decay_params = sum(p.numel() for p in decay_params)
+         num_nodecay_params = sum(p.numel() for p in nodecay_params)
+         if master_process:
+             print(f'num decay parameter tensors: {len(decay_params)} with {num_decay_params:,} parameters')
+             print(f'num nodecay parameter tensors: {len(nodecay_params)} with {num_nodecay_params:,} parameters')
+         fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
+         use_fused = fused_available and device_type == 'cuda'
+         if master_process:
+             print(f'using fused AdamW optimizer: {use_fused}')
+         optimizer = torch.optim.AdamW(optim_groups, lr=lr, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
+         return optimizer
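
With the GPT-2-small defaults in Config, the model should come out at roughly 124M parameters, since the tied embedding/head weight is counted once; a quick sanity-check sketch:

    from src.model import Config, GPT

    model = GPT(Config())  # 12 layers, 12 heads, 768-dim embeddings, 1024-token context
    # parameters() deduplicates shared tensors, so the tied weight is counted once.
    n_params = sum(p.numel() for p in model.parameters())
    print(f"{n_params / 1e6:.1f}M parameters")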
src/trainer.py ADDED
@@ -0,0 +1,235 @@
+ import torch
+ import torch.nn as nn
+ import time
+ import os
+ import math
+ from src.model import GPT, Config
+ from sacrebleu import corpus_bleu
+ from rouge_score import rouge_scorer
+
+ torch.set_float32_matmul_precision('high')
+
+ def repetition_rate(text, n=3):
+     """Fraction of n-grams that repeat an earlier n-gram."""
+     tokens = text.split()
+     if len(tokens) < n:
+         return 0.0
+     ngrams = [tuple(tokens[i:i+n]) for i in range(len(tokens)-n+1)]
+     return (len(ngrams) - len(set(ngrams))) / len(ngrams)
+
+ def distinct_n(text, n=1):
+     """Ratio of unique n-grams to total n-grams (Distinct-N)."""
+     tokens = text.split()
+     if len(tokens) < n:
+         return 0.0
+
+     ngrams = [tuple(tokens[i:i+n]) for i in range(len(tokens) - n + 1)]
+     return len(set(ngrams)) / len(ngrams)
+
+ def compute_self_bleu(generated_texts):
+     """Average BLEU of each text against all the others (diversity proxy)."""
+     if len(generated_texts) < 2:
+         return 0.0
+
+     scores = []
+     N = len(generated_texts)
+
+     for i in range(N):
+         hyp = generated_texts[i]
+         refs = generated_texts[:i] + generated_texts[i+1:]
+
+         # sacrebleu expects references as a list of reference streams, each the
+         # same length as the hypothesis list.
+         bleu = corpus_bleu([hyp], [[r] for r in refs]).score
+         scores.append(bleu)
+
+     return sum(scores) / len(scores)
+
+ class Trainer:
+     def __init__(self, model: GPT, optimizer, train_loader, val_loader, token_encoder,
+                  eval_freq, grad_accum_steps, device, master_process, logpath):
+         self.model = model
+         self.optimizer = optimizer
+         self.train_loader = train_loader
+         self.val_loader = val_loader
+         self.token_encoder = token_encoder
+         self.master_process = master_process
+         self.eval_freq = eval_freq
+         self.grad_accum_steps = grad_accum_steps
+         self.device = device
+         self.device_type = 'cuda' if device.startswith('cuda') else 'cpu'
+         self.logpath = logpath
+
+     def train(self, max_steps, warmup_steps, max_lr, min_lr):
+         history = {
+             'val_losses': [],
+             'perplexities': [],
+             'train_losses': []
+         }
+         for step in range(max_steps):
+             val_loss = None
+             perplexity = None
+             t0 = time.time()
+             self.is_last_step = (step == max_steps - 1)
+             self.model.train()
+             self.optimizer.zero_grad()
+             batch_loss = 0.0
+
+             # Gradient accumulation over grad_accum_steps mini-batches.
+             for mini_step in range(self.grad_accum_steps):
+                 inp, target = self.train_loader.next_batch()
+                 inp, target = inp.to(self.device), target.to(self.device)
+
+                 with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
+                     logits, loss = self.model(inp, target)
+                 loss /= self.grad_accum_steps
+                 batch_loss += loss.detach()
+                 loss.backward()
+             norm = nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
+             lr = self.estimate_lr(step, warmup_steps, max_steps, max_lr, min_lr)
+             for param_group in self.optimizer.param_groups:
+                 param_group['lr'] = lr
+             self.optimizer.step()
+             if self.device_type == 'cuda':
+                 torch.cuda.synchronize()
+             dt = time.time() - t0  # seconds
+             tokens_processed = self.train_loader.B * self.train_loader.T * self.grad_accum_steps * 1  # single process
+             tokens_per_sec = tokens_processed / dt
+
+             if step % self.eval_freq == 0 or self.is_last_step:
+                 val_loss, perplexity = self.evaluate_validation(step)
+                 history['val_losses'].append(val_loss)
+                 history['perplexities'].append(perplexity)
+
+             history['train_losses'].append(batch_loss.item())
+             if self.master_process:
+                 val_str = f' | val loss: {val_loss:.2f}' if val_loss is not None else ''
+                 ppl_str = f' | perplexity: {perplexity:.2f}' if perplexity is not None else ''
+                 print(f'step {step:4d} | train loss: {batch_loss.item():.2f}{val_str}{ppl_str} | lr: {lr:.2e} | norm: {norm:.4f} | dt: {dt*1000.0:.4f}ms | tok/sec: {tokens_per_sec:.4f}')
+                 with open(self.logpath, 'a') as logfile:
+                     logfile.write(f'{step} train {batch_loss.item():.6f}\n')
+
+         evaluation = self.evaluate_text_metrics(
+             max_samples=60,
+             gen_len=256,
+             do_sample=False,
+             top_k=None,
+             temperature=0.2,
+             eos_token_id=None
+         )
+         return history, evaluation
+
+     def evaluate_text_metrics(self, max_samples=100, gen_len=50, do_sample=False, top_k=None, temperature=1.0, eos_token_id=None):
+         self.model.eval()
+         self.val_loader.reset()
+
+         hyps = []
+         refs = []
+         scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
+
+         samples_collected = 0
+         while samples_collected < max_samples:
+             try:
+                 inp, target = self.val_loader.next_batch()
+             except StopIteration:
+                 break
+
+             inp = inp.to(self.device)
+             target = target.to(self.device)
+
+             if inp.shape[1] > self.model.config.context_length:
+                 inp = inp[:, -self.model.config.context_length:]
+
+             with torch.no_grad():
+                 generated = self.model.generate(
+                     inp,
+                     max_new_tokens=gen_len,
+                     temperature=temperature,
+                     top_k=top_k,
+                     do_sample=do_sample,
+                     eos_token_id=eos_token_id
+                 )
+
+             B = generated.shape[0]
+             for i in range(B):
+                 # Score only the continuation, not the prompt.
+                 gen_ids = generated[i, inp.shape[1]:].tolist()
+
+                 pred_text = self.token_encoder.decode(gen_ids)
+                 ref_text = self.token_encoder.decode(target[i].tolist())
+
+                 hyps.append(pred_text)
+                 refs.append(ref_text)
+                 samples_collected += 1
+                 if samples_collected >= max_samples:
+                     break
+
+         if len(hyps) == 0:
+             return {"bleu": 0.0, "rouge-l": [], "self-bleu": 0.0, "repetition": [], "D1": [], "D2": []}
+
+         rep_scores = []
+         distinct1_scores = []
+         distinct2_scores = []
+
+         for txt in hyps:
+             rep_scores.append(repetition_rate(txt, n=3))
+             distinct1_scores.append(distinct_n(txt, n=1))
+             distinct2_scores.append(distinct_n(txt, n=2))
+
+         avg_rep = sum(rep_scores) / len(rep_scores)
+         avg_d1 = sum(distinct1_scores) / len(distinct1_scores)
+         avg_d2 = sum(distinct2_scores) / len(distinct2_scores)
+
+         # One reference stream, aligned 1:1 with hyps.
+         bleu = corpus_bleu(hyps, [refs]).score
+         self_bleu = compute_self_bleu(hyps)
+         rouge_scores = []
+         for h, r in zip(hyps, refs):
+             sc = scorer.score(r, h)['rougeL'].fmeasure
+             rouge_scores.append(sc)
+         rouge_l = sum(rouge_scores) / len(rouge_scores)
+
+         if self.master_process:
+             print(f"[Text Eval] samples={len(hyps)} BLEU={bleu:.2f} ROUGE-L={rouge_l:.4f} SELF-BLEU={self_bleu:.2f} REP={avg_rep:.4f} D1={avg_d1:.4f} D2={avg_d2:.4f}")
+             with open(self.logpath, 'a') as logfile:
+                 logfile.write(f"eval samples={len(hyps)} BLEU={bleu:.2f} ROUGE-L={rouge_l:.4f} SELF-BLEU={self_bleu:.2f} REP={avg_rep:.4f} D1={avg_d1:.4f} D2={avg_d2:.4f}\n")
+
+         return {"bleu": bleu, "rouge-l": rouge_scores, "self-bleu": self_bleu, "repetition": rep_scores, "D1": distinct1_scores, "D2": distinct2_scores}
+
+     def evaluate_validation(self, step):
+         self.model.eval()
+         self.val_loader.reset()
+         with torch.no_grad():
+             val_loss_accum = 0.0
+             val_steps = 20
+             for _ in range(val_steps):
+                 inp, target = self.val_loader.next_batch()
+                 inp, target = inp.to(self.device), target.to(self.device)
+
+                 with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
+                     logits, loss = self.model(inp, target)
+                 loss /= val_steps
+                 val_loss_accum += loss.detach()
+
+         # Compute perplexity unconditionally so the return value is always defined.
+         perplexity = math.exp(val_loss_accum.item())
+         if self.master_process:
+             with open(self.logpath, 'a') as logfile:
+                 logfile.write(f'{step} val {val_loss_accum.item():.4f}\n')
+
+             if step > 0 and (step % 10000 == 0 or self.is_last_step):
+                 raw_model = self.model
+                 logdir = os.path.dirname(self.logpath)
+                 ckpt_path = os.path.join(logdir, f'model_{step:05d}.pt')
+                 checkpoint = {
+                     'model': raw_model.state_dict(),
+                     'config': raw_model.config,
+                     'step': step,
+                     'val_loss': val_loss_accum.item()
+                 }
+                 torch.save(checkpoint, ckpt_path)
+         return val_loss_accum.item(), perplexity
+
+     def estimate_lr(self, step, warmup_steps, max_steps, max_lr, min_lr):
+         # Linear warmup followed by cosine decay to min_lr.
+         if step < warmup_steps:
+             return max_lr * (step+1) / warmup_steps
+         if step > max_steps:
+             return min_lr
+         decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
+         assert 0 <= decay_ratio <= 1
+         coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
+         return min_lr + coeff * (max_lr - min_lr)
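
Trainer never defines its data loaders; it only assumes objects exposing next_batch(), reset(), and B/T attributes. A hypothetical in-memory stand-in that satisfies that interface (the class name and wrap-around policy are assumptions, not part of the commit):

    import torch

    class TensorLoader:
        """Serves (input, target) pairs from one long token tensor."""
        def __init__(self, tokens, B, T):
            self.tokens = tokens  # 1-D LongTensor of token ids
            self.B, self.T = B, T
            self.pos = 0

        def reset(self):
            self.pos = 0

        def next_batch(self):
            span = self.B * self.T + 1
            if self.pos + span > len(self.tokens):
                self.pos = 0  # wrap around instead of raising StopIteration
            buf = self.tokens[self.pos:self.pos + span]
            x = buf[:-1].view(self.B, self.T)   # inputs
            y = buf[1:].view(self.B, self.T)    # next-token targets
            self.pos += self.B * self.T
            return x, y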