Spaces (build status: Build error)

Commit 3920b5f
Parent(s): cb5e58d

Added application file

Files changed:
- app.py (+184, -0)
- src/__pycache__/inference.cpython-312.pyc (binary, +0 -0)
- src/__pycache__/model.cpython-312.pyc (binary, +0 -0)
- src/__pycache__/trainer.cpython-312.pyc (binary, +0 -0)
- src/inference.py (+226, -0)
- src/model.py (+179, -0)
- src/trainer.py (+235, -0)
app.py
ADDED
@@ -0,0 +1,184 @@

import gradio as gr
import torch
from src.model import Config, GPT
from src.inference import GPTInfer
import tiktoken
import os
from huggingface_hub import hf_hub_download

os.environ['GRADIO_DEFAULT_CONCURRENCY_LIMIT'] = "1"

device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = 'mps'
print(f'using device: {device}')

model_path = hf_hub_download(
    repo_id="VJyzCELERY/GPT2-GutenbergStoryGenerator",
    filename="GPT2-GutenbergStoryGenerator.pt"
)
checkpoint = torch.load(model_path, weights_only=False)
model = GPT(config=checkpoint['config'])
model.load_state_dict(checkpoint['model'])
model = model.to(device)
token_encoder = tiktoken.get_encoding('gpt2')
generator = GPTInfer(model, token_encoder, device)

def generate_story(
    prompt,
    max_new_tokens=50,
    seed=42,
    temperature=0.8,
    top_k=None,
    top_p=0.9,
    repetition_penalty=1.2,
    frequency_penalty=0.6,
    no_repeat_ngram_size=3,
    longer_story=True,
    context_window=512
):
    if not prompt.strip():
        # This function is a generator, so a bare `return value` would be
        # swallowed; yield the unchanged state instead.
        yield prompt, gr.update()
        return

    if top_k is None or top_k <= 0:
        top_k = None

    seed = int(seed)  # gr.Number passes a float; torch's manual_seed needs an int

    output_text = prompt
    last_piece = ""
    # Lock the textbox and button while streaming.
    yield gr.update(value=output_text, interactive=False), gr.update(interactive=False)
    for piece in generator.generate(
        prompt,
        max_new_tokens=max_new_tokens,
        seed=seed,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        frequency_penalty=frequency_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        longer_story=longer_story,
        context_window=context_window
    ):
        if piece == last_piece:
            continue
        last_piece = piece
        output_text += piece
        yield output_text, gr.update(interactive=False)

    yield gr.update(value=output_text, interactive=True), gr.update(interactive=True)

with gr.Blocks(title="Story Generator") as demo:

    gr.Markdown("# ✨ Story Generator ✨")
    gr.Markdown(
        "Type a prompt or the start of a story below. "
        "Press **Generate** to continue the story. "
        "You can edit the result and generate again to continue."
    )

    story_box = gr.Textbox(
        label="Story / Prompt",
        lines=15,
        placeholder="Write a prompt or the start of a story here...",
    )

    generate_btn = gr.Button("Generate Story", variant="primary")

    with gr.Accordion("Generation Settings", open=False):
        context_window = gr.Slider(
            minimum=128,
            maximum=2048,
            value=512,
            step=64,
            label="Context Window (tokens to use from end of text)",
            info="Limits how much previous text is used. Lower = faster but less context."
        )

        max_new_tokens = gr.Slider(
            minimum=20,
            maximum=2048,
            value=1024,
            step=10,
            label="Max New Tokens"
        )

        seed = gr.Number(
            value=42,
            label="Seed"
        )

        temperature = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.8,
            step=0.05,
            label="Temperature"
        )

        top_k = gr.Slider(
            minimum=0,
            maximum=200,
            value=0,
            step=1,
            label="Top-K (0 = disabled)"
        )

        top_p = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.9,
            step=0.01,
            label="Top-P"
        )

        repetition_penalty = gr.Slider(
            minimum=1.0,
            maximum=2.0,
            value=1.2,
            step=0.05,
            label="Repetition Penalty"
        )

        frequency_penalty = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.6,
            step=0.05,
            label="Frequency Penalty"
        )

        no_repeat = gr.Slider(
            minimum=1,
            maximum=10,
            value=3,
            step=1,
            label="No-Repeat N-gram Size"
        )

    generate_btn.click(
        fn=generate_story,
        inputs=[
            story_box,
            max_new_tokens,
            seed,
            temperature,
            top_k,
            top_p,
            repetition_penalty,
            frequency_penalty,
            no_repeat,
            gr.Checkbox(value=True, visible=False),  # hidden flag wired to longer_story
            context_window
        ],
        outputs=[story_box, generate_btn]
    )

# Run app
if __name__ == "__main__":
    demo.launch(share=False)
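For reference, the same checkpoint can be driven without the Gradio UI. A minimal sketch using only the calls app.py makes above (the prompt string is illustrative; map_location='cpu' is added here so the load also works on a CPU-only machine):

import torch
import tiktoken
from huggingface_hub import hf_hub_download
from src.model import GPT
from src.inference import GPTInfer

# Same download/load sequence as app.py, but on CPU and without the UI.
path = hf_hub_download(repo_id="VJyzCELERY/GPT2-GutenbergStoryGenerator",
                       filename="GPT2-GutenbergStoryGenerator.pt")
ckpt = torch.load(path, map_location='cpu', weights_only=False)
model = GPT(config=ckpt['config'])
model.load_state_dict(ckpt['model'])

gen = GPTInfer(model, tiktoken.get_encoding('gpt2'), device='cpu')
story = gen.print_stream("Once upon a time", max_new_tokens=50, seed=42)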
src/__pycache__/inference.cpython-312.pyc
ADDED: Binary file (9.8 kB)

src/__pycache__/model.cpython-312.pyc
ADDED: Binary file (13.9 kB)

src/__pycache__/trainer.cpython-312.pyc
ADDED: Binary file (14 kB)
src/inference.py
ADDED
@@ -0,0 +1,226 @@

import torch
import torch.nn.functional as F

def concat(prev, new):
    # Insert a space only when two alphanumeric spans would otherwise fuse.
    if prev and prev[-1].isalnum() and new and new[0].isalnum():
        return prev + " " + new
    return prev + new

class GPTInfer:
    def __init__(self, model, token_encoder, device):
        self.model = model
        self.token_encoder = token_encoder
        self.device = device
        self.device_type = 'cuda' if device.startswith('cuda') else 'cpu'

    def get_token_length(self, text):
        return len(self.token_encoder.encode(text, allowed_special={"<|endoftext|>"}))

    def apply_frequency_penalty_and_blocking(
        self,
        logits,
        gen_tokens,
        frequency_penalty=0.5,
        no_repeat_ngram_size=3,
    ):
        logits = logits.clone().float()

        # Subtract a penalty proportional to how often each token has appeared.
        if frequency_penalty and frequency_penalty > 0.0:
            counts = {}
            for t in gen_tokens[0].tolist():
                counts[t] = counts.get(t, 0) + 1
            if counts:
                vocab_size = logits.shape[-1]
                penalty = torch.zeros(vocab_size, dtype=logits.dtype, device=logits.device)
                for tok, c in counts.items():
                    if 0 <= tok < vocab_size:
                        penalty[tok] = float(c) * float(frequency_penalty)
                logits = logits - penalty.unsqueeze(0)

        # Ban any token that would complete an n-gram already present in the text.
        if no_repeat_ngram_size and no_repeat_ngram_size > 0:
            n = no_repeat_ngram_size
            cur = gen_tokens[0].tolist()
            if len(cur) >= n - 1:
                last_prefix = tuple(cur[-(n - 1):]) if n > 1 else tuple()
                for i in range(len(cur) - (n - 1)):
                    if tuple(cur[i:i + (n - 1)]) == last_prefix and i + (n - 1) < len(cur):
                        banned_token = cur[i + (n - 1)]
                        if 0 <= banned_token < logits.shape[-1]:
                            logits[0, banned_token] = -1e9

        return logits

    def sample_next_token(
        self,
        logits,
        gen_tokens,
        seed_rng,
        temperature=0.8,
        top_k=None,
        top_p=0.9,
        repetition_penalty=1.2,
        frequency_penalty=0.5,
        no_repeat_ngram_size=3,
        recent_tokens_window=200,
    ):
        logits = logits.clone().float()

        # Repetition penalty over a recent window of generated tokens.
        recent = gen_tokens[0, -recent_tokens_window:].tolist()
        if repetition_penalty is not None and repetition_penalty != 1.0:
            for t in set(recent):
                if 0 <= t < logits.shape[-1]:
                    logits[0, t] /= float(repetition_penalty)

        logits = self.apply_frequency_penalty_and_blocking(
            logits,
            gen_tokens,
            frequency_penalty=frequency_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
        )

        if temperature is not None and temperature != 1.0:
            logits = logits / float(temperature)

        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        sorted_probs = F.softmax(sorted_logits, dim=-1)

        # Top-k: keep only the k highest-probability tokens.
        if top_k is not None:
            k = min(int(top_k), sorted_logits.shape[-1])
            sorted_logits = sorted_logits[:, :k]
            sorted_idx = sorted_idx[:, :k]
            sorted_probs = sorted_probs[:, :k]

        # Top-p (nucleus): keep the smallest prefix whose cumulative mass <= top_p.
        if top_p is not None and 0.0 < top_p < 1.0:
            cum_probs = torch.cumsum(sorted_probs, dim=-1)
            mask = cum_probs <= top_p
            if not mask.any():
                mask[0, 0] = True
            keep_count = int(mask.sum(dim=-1).item())
            sorted_probs = sorted_probs[:, :keep_count]
            sorted_idx = sorted_idx[:, :keep_count]

        sorted_probs = sorted_probs / (sorted_probs.sum(dim=-1, keepdim=True) + 1e-12)
        next_index_in_sorted = torch.multinomial(sorted_probs, 1, generator=seed_rng)
        next_tok = sorted_idx.gather(-1, next_index_in_sorted)

        return int(next_tok.item())

    def generate(
        self,
        prompt,
        max_new_tokens=50,
        seed=42,
        longer_story=True,
        temperature=0.8,
        top_k=None,
        top_p=0.9,
        repetition_penalty=1.2,
        frequency_penalty=0.5,
        no_repeat_ngram_size=3,
        context_window=None,
        stream=True,
    ):
        self.model.eval()

        tokens = self.token_encoder.encode(prompt)
        if context_window is not None and len(tokens) > context_window:
            tokens = tokens[-context_window:]

        tokens = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(self.device)
        gen_tokens = tokens.clone()

        if seed is not None:
            sample_rng = torch.Generator(device=self.device).manual_seed(seed)
        else:
            sample_rng = torch.Generator(device=self.device)

        eos_id = self.token_encoder.encode("<|endoftext|>", allowed_special={"<|endoftext|>"})[0]
        context_len = self.model.config.context_length
        new_tokens_generated = 0
        HARD_MAX_TOTAL = context_len + max_new_tokens + 10

        while new_tokens_generated < max_new_tokens and gen_tokens.shape[1] < HARD_MAX_TOTAL:
            if gen_tokens.shape[1] > context_len:
                idx_cond = gen_tokens[:, -context_len:]
            else:
                idx_cond = gen_tokens

            with torch.no_grad():
                try:
                    with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                        logits, _ = self.model(idx_cond)
                except Exception:
                    # Fall back to full precision where bfloat16 autocast is unsupported.
                    logits, _ = self.model(idx_cond)

            next_logits = logits[:, -1:, :].squeeze(1)

            # Discourage an early end-of-text token for the first few steps.
            if longer_story and new_tokens_generated < 5:
                next_logits[0, eos_id] = next_logits[0, eos_id] / 4.0

            next_token_id = self.sample_next_token(
                logits=next_logits,
                gen_tokens=gen_tokens,
                seed_rng=sample_rng,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                frequency_penalty=frequency_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                recent_tokens_window=200,
            )

            if next_token_id == eos_id:
                break

            next_tok_tensor = torch.tensor([[next_token_id]], dtype=torch.long).to(self.device)
            gen_tokens = torch.cat([gen_tokens, next_tok_tensor], dim=1)
            new_tokens_generated += 1

            if stream:
                yield self.token_encoder.decode([next_token_id], errors='ignore')

        if not stream:
            yield self.token_encoder.decode(gen_tokens[0, :].tolist(), errors='ignore')

    def print_stream(
        self,
        prompt,
        max_new_tokens=200,
        seed=42,
        longer_story=True,
        temperature=0.8,
        top_k=None,
        top_p=0.9,
        repetition_penalty=1.2,
        frequency_penalty=0.6,
        no_repeat_ngram_size=3,
        context_window=512,
    ):
        text = prompt
        last_piece = ""
        print(prompt, end="", flush=True)
        for piece in self.generate(
            prompt,
            max_new_tokens=max_new_tokens,
            seed=seed,
            longer_story=longer_story,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            frequency_penalty=frequency_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            context_window=context_window,
        ):
            if piece == last_piece:
                continue
            last_piece = piece
            text = concat(text, piece)
            print(piece, end="", flush=True)
        return text
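GPTInfer.generate is a Python generator: with stream=True it yields one decoded token at a time, which is exactly what app.py consumes. A minimal sketch, assuming model, token_encoder, and device are already set up as in app.py:

infer = GPTInfer(model, token_encoder, device)

text = "The old ship creaked"  # illustrative prompt
for piece in infer.generate(text, max_new_tokens=30, seed=7,
                            temperature=0.8, top_p=0.9, stream=True):
    text += piece  # each piece is one token decoded with errors='ignore'
print(text)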
src/model.py
ADDED
@@ -0,0 +1,179 @@

import torch
import torch.nn as nn
import torch.nn.functional as f
from dataclasses import dataclass
import inspect

@dataclass
class Config:
    context_length: int = 1024
    vocab_size: int = 50257
    num_layers: int = 12
    embedding_dim: int = 768
    num_heads: int = 12

class MultiHeadAttention(nn.Module):
    def __init__(self, config: Config, masked=False):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = config.num_heads
        self.masked = masked
        self.embedding_dim = config.embedding_dim
        self.c_attention = nn.Linear(config.embedding_dim, 3 * config.embedding_dim)
        self.c_projection = nn.Linear(config.embedding_dim, config.embedding_dim)
        self.c_projection.SCALE_INIT = 1.0

    def forward(self, x):
        B, T, C = x.shape
        QKV = self.c_attention(x)
        Query_q, Key_k, Value_v = QKV.split(self.embedding_dim, dim=-1)
        Query_q = Query_q.view(B, T, self.num_heads, self.embedding_dim // self.num_heads).transpose(1, 2)
        Key_k = Key_k.view(B, T, self.num_heads, self.embedding_dim // self.num_heads).transpose(1, 2)
        Value_v = Value_v.view(B, T, self.num_heads, self.embedding_dim // self.num_heads).transpose(1, 2)

        if self.masked:
            out = f.scaled_dot_product_attention(Query_q, Key_k, Value_v, is_causal=True)
        else:
            out = f.scaled_dot_product_attention(Query_q, Key_k, Value_v, is_causal=False)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_projection(out)

class MLP(nn.Module):
    def __init__(self, config: Config):
        super(MLP, self).__init__()
        self.c_fc = nn.Linear(config.embedding_dim, 4 * config.embedding_dim)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_projection = nn.Linear(4 * config.embedding_dim, config.embedding_dim)
        self.c_projection.SCALE_INIT = 1.0

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_projection(x)
        return x

class DecoderBlock(nn.Module):
    def __init__(self, config: Config):
        """Decoder block without the encoder cross-attention."""
        super(DecoderBlock, self).__init__()
        self.masked_attention = MultiHeadAttention(config, masked=True)
        self.layer_norm1 = nn.LayerNorm(config.embedding_dim)
        # self.attention = MultiHeadAttention(config, masked=False)
        # self.layer_norm2 = nn.LayerNorm(config.embedding_dim)
        self.mlp = MLP(config)
        self.layer_norm3 = nn.LayerNorm(config.embedding_dim)

    def forward(self, x):
        x = x + self.masked_attention(self.layer_norm1(x))
        # x = x + self.attention(self.layer_norm2(x))
        x = x + self.mlp(self.layer_norm3(x))
        return x

class TransformerDecoder(nn.Module):
    def __init__(self, config: Config):
        super(TransformerDecoder, self).__init__()
        self.config = config
        self.word_token_embedding = nn.Embedding(self.config.vocab_size, self.config.embedding_dim)
        self.word_position_embedding = nn.Embedding(self.config.context_length, self.config.embedding_dim)
        layers = [DecoderBlock(config) for _ in range(config.num_layers)]
        self.hidden_layers = nn.Sequential(*layers)
        self.layer_norm = nn.LayerNorm(self.config.embedding_dim)

    def forward(self, idx):
        B, T = idx.shape
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        pos_embed = self.word_position_embedding(pos)
        token_embed = self.word_token_embedding(idx)
        x = pos_embed + token_embed
        x = self.hidden_layers(x)
        x = self.layer_norm(x)
        return x

class GPT(nn.Module):
    def __init__(self, config: Config):
        super(GPT, self).__init__()
        self.config = config
        self.transformerDecoder = TransformerDecoder(config)
        self.language_modeling_head = nn.Linear(config.embedding_dim, config.vocab_size, bias=False)
        # Weight tying between the token embedding and the LM head.
        self.transformerDecoder.word_token_embedding.weight = self.language_modeling_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'SCALE_INIT'):
                std /= (2 * self.config.num_layers) ** 0.5
            torch.nn.init.normal_(module.weight, mean=0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0, std=0.02)

    def forward(self, idx, targets=None):
        x = self.transformerDecoder(idx)
        logits = self.language_modeling_head(x)
        loss = None
        if targets is not None:
            loss = f.cross_entropy(logits.view(-1, logits.shape[-1]), targets.view(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=50, temperature=0.8, top_k=None, do_sample=False, eos_token_id=None):
        self.eval()

        B, T = idx.shape
        device = idx.device
        context_len = self.config.context_length

        if T > context_len:
            idx = idx[:, -context_len:]
            T = idx.shape[1]

        generated = idx.clone()

        for _ in range(max_new_tokens):
            input_ids = generated[:, -context_len:]

            logits, _ = self.forward(input_ids, targets=None)
            next_logits = logits[:, -1, :]

            if temperature != 1.0 and temperature > 0.0:
                next_logits = next_logits / temperature

            if do_sample:
                if top_k is not None and top_k > 0:
                    vals, idxs = next_logits.topk(top_k, dim=-1)
                    min_vals = vals[:, -1].unsqueeze(-1)
                    mask = next_logits < min_vals
                    next_logits = next_logits.masked_fill(mask, float('-inf'))

                probs = torch.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(next_logits, dim=-1, keepdim=True)

            generated = torch.cat([generated, next_token], dim=1)

            if eos_token_id is not None:
                if (generated == eos_token_id).any(dim=1).all():
                    break

        return generated

    def configure_optimizer(self, weight_decay, lr, device_type, master_process):
        param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}

        decay_params = [p for pn, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for pn, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        if master_process:
            print(f'num decay parameter tensors: {len(decay_params)} with {num_decay_params:,} parameters')
            print(f'num nodecay parameter tensors: {len(nodecay_params)} with {num_nodecay_params:,} parameters')
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        if master_process:
            print(f'using fused AdamW optimizer: {use_fused}')
        optimizer = torch.optim.AdamW(optim_groups, lr=lr, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
        return optimizer
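As a quick sanity check of the architecture above, a minimal forward/generate smoke test (the tiny Config values here are illustrative, not the released checkpoint's):

import torch
from src.model import Config, GPT

# Tiny illustrative config; the released checkpoint stores its own Config.
config = Config(context_length=64, vocab_size=50257, num_layers=2,
                embedding_dim=128, num_heads=4)
model = GPT(config)

idx = torch.randint(0, config.vocab_size, (2, 16))  # (batch, time)
logits, loss = model(idx)                           # no targets -> loss is None
assert logits.shape == (2, 16, config.vocab_size)

# Greedy continuation via the built-in generate helper.
out = model.generate(idx, max_new_tokens=8, do_sample=False)
print(out.shape)  # torch.Size([2, 24])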
src/trainer.py
ADDED
@@ -0,0 +1,235 @@

import torch
import torch.nn as nn
import torch.nn.functional as f
from dataclasses import dataclass
import time
import os
from src.model import GPT, Config
from sacrebleu import corpus_bleu
from rouge_score import rouge_scorer
import numpy as np
import math

torch.set_float32_matmul_precision('high')

def repetition_rate(text, n=3):
    tokens = text.split()
    if len(tokens) < n:
        return 0.0
    ngrams = [tuple(tokens[i:i+n]) for i in range(len(tokens) - n + 1)]
    return (len(ngrams) - len(set(ngrams))) / len(ngrams)

def distinct_n(text, n=1):
    tokens = text.split()
    if len(tokens) < n:
        return 0.0

    ngrams = [tuple(tokens[i:i+n]) for i in range(len(tokens) - n + 1)]
    return len(set(ngrams)) / len(ngrams)


def compute_self_bleu(generated_texts):
    if len(generated_texts) < 2:
        return 0.0

    scores = []
    N = len(generated_texts)

    for i in range(N):
        hyp = generated_texts[i]
        refs = generated_texts[:i] + generated_texts[i+1:]

        # sacrebleu wants one reference stream per reference, each parallel
        # to the hypothesis list (here a single hypothesis).
        bleu = corpus_bleu([hyp], [[r] for r in refs]).score
        scores.append(bleu)

    return sum(scores) / len(scores)

class Trainer:
    def __init__(self, model: GPT, optimizer, train_loader, val_loader, token_encoder, eval_freq, grad_accum_steps, device, master_process, logpath):
        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.token_encoder = token_encoder
        self.master_process = master_process
        self.eval_freq = eval_freq
        self.grad_accum_steps = grad_accum_steps
        self.device = device
        self.device_type = 'cuda' if device.startswith('cuda') else 'cpu'
        self.logpath = logpath


    def train(self, max_steps, warmup_steps, max_lr, min_lr):
        history = {
            'val_losses': [],
            'perplexities': [],
            'train_losses': []
        }
        for step in range(max_steps):
            val_loss = None
            perplexity = None
            t0 = time.time()
            self.is_last_step = (step == max_steps - 1)
            self.model.train()
            self.optimizer.zero_grad()
            batch_loss = 0.0

            for mini_step in range(self.grad_accum_steps):
                inp, target = self.train_loader.next_batch()
                inp, target = inp.to(self.device), target.to(self.device)

                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    logits, loss = self.model(inp, target)
                loss /= self.grad_accum_steps
                batch_loss += loss.detach()
                loss.backward()
            norm = nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            lr = self.estimate_lr(step, warmup_steps, max_steps, max_lr, min_lr)
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr
            self.optimizer.step()
            if self.device_type == 'cuda':
                torch.cuda.synchronize()
            dt = (time.time() - t0) * 1000.0  # in ms
            tokens_processed = self.train_loader.B * self.train_loader.T * self.grad_accum_steps * 1
            tokens_per_sec = tokens_processed / (dt / 1000.0)  # dt is in ms

            if step % self.eval_freq == 0 or self.is_last_step:
                val_loss, perplexity = self.evaluate_validation(step)
                history['val_losses'].append(val_loss)
                history['perplexities'].append(perplexity)

            history['train_losses'].append(batch_loss.item())
            if self.master_process:
                val_part = f' | val loss: {val_loss:.2f}' if val_loss is not None else ''
                ppl_part = f' | perplexity: {perplexity:.2f}' if perplexity is not None else ''
                print(f'step {step:4d} | train loss: {batch_loss.item():.2f}{val_part}{ppl_part} | lr: {lr:.2e} | norm: {norm:.4f} | dt: {dt:.4f}ms | tok/sec: {tokens_per_sec:.4f}')
                with open(self.logpath, 'a') as logfile:
                    logfile.write(f'{step} train {batch_loss.item():.6f}\n')

        evaluation = self.evaluate_text_metrics(
            max_samples=60,
            gen_len=256,
            do_sample=False,
            top_k=None,
            temperature=0.2,
            eos_token_id=None
        )
        return history, evaluation

    def evaluate_text_metrics(self, max_samples=100, gen_len=50, do_sample=False, top_k=None, temperature=1.0, eos_token_id=None):
        self.model.eval()
        self.val_loader.reset()

        hyps = []
        refs = []
        scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

        samples_collected = 0
        while samples_collected < max_samples:
            try:
                inp, target = self.val_loader.next_batch()
            except StopIteration:
                break

            inp = inp.to(self.device)
            target = target.to(self.device)

            if inp.shape[1] > self.model.config.context_length:
                inp = inp[:, -self.model.config.context_length:]

            with torch.no_grad():
                generated = self.model.generate(
                    inp,
                    max_new_tokens=gen_len,
                    temperature=temperature,
                    top_k=top_k,
                    do_sample=do_sample,
                    eos_token_id=eos_token_id
                )

            B = generated.shape[0]
            for i in range(B):
                gen_ids = generated[i, inp.shape[1]:].tolist()

                pred_text = self.token_encoder.decode(gen_ids)
                ref_text = self.token_encoder.decode(target[i].tolist())

                hyps.append(pred_text)
                refs.append(ref_text)
                samples_collected += 1
                if samples_collected >= max_samples:
                    break

        if len(hyps) == 0:
            # Nothing to score; keep the return shape consistent with the dict below.
            return {"bleu": 0.0, "rouge-l": [], "self-bleu": 0.0, "repetition": [], "D1": [], "D2": []}

        rep_scores = []
        distinct1_scores = []
        distinct2_scores = []

        for txt in hyps:
            rep_scores.append(repetition_rate(txt, n=3))
            distinct1_scores.append(distinct_n(txt, n=1))
            distinct2_scores.append(distinct_n(txt, n=2))

        avg_rep = sum(rep_scores) / len(rep_scores)
        avg_d1 = sum(distinct1_scores) / len(distinct1_scores)
        avg_d2 = sum(distinct2_scores) / len(distinct2_scores)

        # sacrebleu expects a list of reference streams, each parallel to the hypotheses.
        bleu = corpus_bleu(hyps, [refs]).score
        self_bleu = compute_self_bleu(hyps)
        rouge_scores = []
        for h, r in zip(hyps, refs):
            sc = scorer.score(r, h)['rougeL'].fmeasure
            rouge_scores.append(sc)
        rouge_l = sum(rouge_scores) / len(rouge_scores)

        if self.master_process:
            print(f"[Text Eval] samples={len(hyps)} BLEU={bleu:.2f} ROUGE-L={rouge_l:.4f} SELF-BLEU={self_bleu:.2f} REP={avg_rep:.4f} D1={avg_d1:.4f} D2={avg_d2:.4f}")
            with open(self.logpath, 'a') as logfile:
                logfile.write(f"eval samples={len(hyps)} BLEU={bleu:.2f} ROUGE-L={rouge_l:.4f} SELF-BLEU={self_bleu:.2f} REP={avg_rep:.4f} D1={avg_d1:.4f} D2={avg_d2:.4f}\n")

        return {"bleu": bleu, "rouge-l": rouge_scores, "self-bleu": self_bleu, "repetition": rep_scores, "D1": distinct1_scores, "D2": distinct2_scores}

    def evaluate_validation(self, step):
        self.model.eval()
        self.val_loader.reset()
        with torch.no_grad():
            val_loss_accum = 0.0
            val_steps = 20
            for _ in range(val_steps):
                inp, target = self.val_loader.next_batch()
                inp, target = inp.to(self.device), target.to(self.device)

                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    logits, loss = self.model(inp, target)
                loss /= val_steps
                val_loss_accum += loss.detach()

        # Compute perplexity unconditionally so the return below is always defined.
        perplexity = math.exp(val_loss_accum.item())
        if self.master_process:
            with open(self.logpath, 'a') as logfile:
                logfile.write(f'{step} val {val_loss_accum.item():.4f}\n')

            if step > 0 and (step % 10000 == 0 or self.is_last_step):
                raw_model = self.model
                logdir = os.path.dirname(self.logpath)
                ckpt_path = os.path.join(logdir, f'model_{step:05d}.pt')
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'config': raw_model.config,
                    'step': step,
                    'val_loss': val_loss_accum.item()
                }
                torch.save(checkpoint, ckpt_path)
        return val_loss_accum.item(), perplexity

    def estimate_lr(self, step, warmup_steps, max_steps, max_lr, min_lr):
        if step < warmup_steps:
            return max_lr * (step + 1) / warmup_steps
        if step > max_steps:
            return min_lr
        decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
        assert 0 <= decay_ratio <= 1
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
        return min_lr + coeff * (max_lr - min_lr)
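estimate_lr implements linear warmup followed by cosine decay. A standalone copy for inspecting the schedule without constructing a Trainer (the hyperparameter values in the loop are illustrative):

import math

def estimate_lr(step, warmup_steps, max_steps, max_lr, min_lr):
    # Linear warmup to max_lr, then cosine decay down to min_lr.
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    if step > max_steps:
        return min_lr
    decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (max_lr - min_lr)

for step in (0, 99, 100, 5000, 9999):
    print(step, f"{estimate_lr(step, 100, 10000, 6e-4, 6e-5):.2e}")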