vasiuuu commited on
Commit
2f07379
·
verified ·
1 Parent(s): 85d0ad5

new commits

Browse files
Files changed (5) hide show
  1. accuracy-plot.pdf +0 -0
  2. gpt_download.py +157 -0
  3. loss-plot.pdf +0 -0
  4. main.ipynb +0 -0
  5. transformer.py +320 -0
accuracy-plot.pdf ADDED
Binary file (14.1 kB). View file
 
gpt_download.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
2
+ # Source for "Build a Large Language Model From Scratch"
3
+ # - https://www.manning.com/books/build-a-large-language-model-from-scratch
4
+ # Code: https://github.com/rasbt/LLMs-from-scratch
5
+
6
+
7
+ import os
8
+ import urllib.request
9
+
10
+ # import requests
11
+ import json
12
+ import numpy as np
13
+ import tensorflow as tf
14
+ from tqdm import tqdm
15
+
16
+
17
+ def download_and_load_gpt2(model_size, models_dir):
18
+ # Validate model size
19
+ allowed_sizes = ("124M", "355M", "774M", "1558M")
20
+ if model_size not in allowed_sizes:
21
+ raise ValueError(f"Model size not in {allowed_sizes}")
22
+
23
+ # Define paths
24
+ model_dir = os.path.join(models_dir, model_size)
25
+ base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
26
+ backup_base_url = "https://f001.backblazeb2.com/file/LLMs-from-scratch/gpt2"
27
+ filenames = [
28
+ "checkpoint", "encoder.json", "hparams.json",
29
+ "model.ckpt.data-00000-of-00001", "model.ckpt.index",
30
+ "model.ckpt.meta", "vocab.bpe"
31
+ ]
32
+
33
+ # Download files
34
+ os.makedirs(model_dir, exist_ok=True)
35
+ for filename in filenames:
36
+ file_url = os.path.join(base_url, model_size, filename)
37
+ backup_url = os.path.join(backup_base_url, model_size, filename)
38
+ file_path = os.path.join(model_dir, filename)
39
+ download_file(file_url, file_path, backup_url)
40
+
41
+ # Load settings and params
42
+ tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
43
+ settings = json.load(open(os.path.join(model_dir, "hparams.json"), "r", encoding="utf-8"))
44
+ params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, settings)
45
+
46
+ return settings, params
47
+
48
+
49
+ def download_file(url, destination, backup_url=None):
50
+ def _attempt_download(download_url):
51
+ with urllib.request.urlopen(download_url) as response:
52
+ # Get the total file size from headers, defaulting to 0 if not present
53
+ file_size = int(response.headers.get("Content-Length", 0))
54
+
55
+ # Check if file exists and has the same size
56
+ if os.path.exists(destination):
57
+ file_size_local = os.path.getsize(destination)
58
+ if file_size == file_size_local:
59
+ print(f"File already exists and is up-to-date: {destination}")
60
+ return True # Indicate success without re-downloading
61
+
62
+ block_size = 1024 # 1 Kilobyte
63
+
64
+ # Initialize the progress bar with total file size
65
+ progress_bar_description = os.path.basename(download_url)
66
+ with tqdm(total=file_size, unit="iB", unit_scale=True, desc=progress_bar_description) as progress_bar:
67
+ with open(destination, "wb") as file:
68
+ while True:
69
+ chunk = response.read(block_size)
70
+ if not chunk:
71
+ break
72
+ file.write(chunk)
73
+ progress_bar.update(len(chunk))
74
+ return True
75
+
76
+ try:
77
+ if _attempt_download(url):
78
+ return
79
+ except (urllib.error.HTTPError, urllib.error.URLError):
80
+ if backup_url is not None:
81
+ print(f"Primary URL ({url}) failed. Attempting backup URL: {backup_url}")
82
+ try:
83
+ if _attempt_download(backup_url):
84
+ return
85
+ except urllib.error.HTTPError:
86
+ pass
87
+
88
+ # If we reach here, both attempts have failed
89
+ error_message = (
90
+ f"Failed to download from both primary URL ({url})"
91
+ f"{' and backup URL (' + backup_url + ')' if backup_url else ''}."
92
+ "\nCheck your internet connection or the file availability.\n"
93
+ "For help, visit: https://github.com/rasbt/LLMs-from-scratch/discussions/273"
94
+ )
95
+ print(error_message)
96
+ except Exception as e:
97
+ print(f"An unexpected error occurred: {e}")
98
+
99
+
100
+ # Alternative way using `requests`
101
+ """
102
+ def download_file(url, destination):
103
+ # Send a GET request to download the file in streaming mode
104
+ response = requests.get(url, stream=True)
105
+
106
+ # Get the total file size from headers, defaulting to 0 if not present
107
+ file_size = int(response.headers.get("content-length", 0))
108
+
109
+ # Check if file exists and has the same size
110
+ if os.path.exists(destination):
111
+ file_size_local = os.path.getsize(destination)
112
+ if file_size == file_size_local:
113
+ print(f"File already exists and is up-to-date: {destination}")
114
+ return
115
+
116
+ # Define the block size for reading the file
117
+ block_size = 1024 # 1 Kilobyte
118
+
119
+ # Initialize the progress bar with total file size
120
+ progress_bar_description = url.split("/")[-1] # Extract filename from URL
121
+ with tqdm(total=file_size, unit="iB", unit_scale=True, desc=progress_bar_description) as progress_bar:
122
+ # Open the destination file in binary write mode
123
+ with open(destination, "wb") as file:
124
+ # Iterate over the file data in chunks
125
+ for chunk in response.iter_content(block_size):
126
+ progress_bar.update(len(chunk)) # Update progress bar
127
+ file.write(chunk) # Write the chunk to the file
128
+ """
129
+
130
+
131
+ def load_gpt2_params_from_tf_ckpt(ckpt_path, settings):
132
+ # Initialize parameters dictionary with empty blocks for each layer
133
+ params = {"blocks": [{} for _ in range(settings["n_layer"])]}
134
+
135
+ # Iterate over each variable in the checkpoint
136
+ for name, _ in tf.train.list_variables(ckpt_path):
137
+ # Load the variable and remove singleton dimensions
138
+ variable_array = np.squeeze(tf.train.load_variable(ckpt_path, name))
139
+
140
+ # Process the variable name to extract relevant parts
141
+ variable_name_parts = name.split("/")[1:] # Skip the 'model/' prefix
142
+
143
+ # Identify the target dictionary for the variable
144
+ target_dict = params
145
+ if variable_name_parts[0].startswith("h"):
146
+ layer_number = int(variable_name_parts[0][1:])
147
+ target_dict = params["blocks"][layer_number]
148
+
149
+ # Recursively access or create nested dictionaries
150
+ for key in variable_name_parts[1:-1]:
151
+ target_dict = target_dict.setdefault(key, {})
152
+
153
+ # Assign the variable array to the last key
154
+ last_key = variable_name_parts[-1]
155
+ target_dict[last_key] = variable_array
156
+
157
+ return params
loss-plot.pdf ADDED
Binary file (12 kB). View file
 
main.ipynb ADDED
File without changes
transformer.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
2
+ # Source for "Build a Large Language Model From Scratch"
3
+ # - https://www.manning.com/books/build-a-large-language-model-from-scratch
4
+ # Code: https://github.com/rasbt/LLMs-from-scratch
5
+ #
6
+ # This file collects all the relevant code that we covered thus far
7
+ # throughout Chapters 2-5.
8
+ # This file can be run as a standalone script.
9
+
10
+ import numpy as np
11
+ import tiktoken
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.utils.data import Dataset, DataLoader
15
+
16
+ #####################################
17
+ # Chapter 2
18
+ #####################################
19
+
20
+
21
+ class GPTDatasetV1(Dataset):
22
+ def __init__(self, txt, tokenizer, max_length, stride):
23
+ self.input_ids = []
24
+ self.target_ids = []
25
+
26
+ # Tokenize the entire text
27
+ token_ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})
28
+
29
+ # Use a sliding window to chunk the book into overlapping sequences of max_length
30
+ for i in range(0, len(token_ids) - max_length, stride):
31
+ input_chunk = token_ids[i:i + max_length]
32
+ target_chunk = token_ids[i + 1: i + max_length + 1]
33
+ self.input_ids.append(torch.tensor(input_chunk))
34
+ self.target_ids.append(torch.tensor(target_chunk))
35
+
36
+ def __len__(self):
37
+ return len(self.input_ids)
38
+
39
+ def __getitem__(self, idx):
40
+ return self.input_ids[idx], self.target_ids[idx]
41
+
42
+
43
+ def create_dataloader_v1(txt, batch_size=4, max_length=256,
44
+ stride=128, shuffle=True, drop_last=True, num_workers=0):
45
+ # Initialize the tokenizer
46
+ tokenizer = tiktoken.get_encoding("gpt2")
47
+
48
+ # Create dataset
49
+ dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)
50
+
51
+ # Create dataloader
52
+ dataloader = DataLoader(
53
+ dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)
54
+
55
+ return dataloader
56
+
57
+
58
+ #####################################
59
+ # Chapter 3
60
+ #####################################
61
+ class MultiHeadAttention(nn.Module):
62
+ def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
63
+ super().__init__()
64
+ assert d_out % num_heads == 0, "d_out must be divisible by n_heads"
65
+
66
+ self.d_out = d_out
67
+ self.num_heads = num_heads
68
+ self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim
69
+
70
+ self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
71
+ self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
72
+ self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
73
+ self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
74
+ self.dropout = nn.Dropout(dropout)
75
+ self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
76
+
77
+ def forward(self, x):
78
+ b, num_tokens, d_in = x.shape
79
+
80
+ keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
81
+ queries = self.W_query(x)
82
+ values = self.W_value(x)
83
+
84
+ # We implicitly split the matrix by adding a `num_heads` dimension
85
+ # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
86
+ keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
87
+ values = values.view(b, num_tokens, self.num_heads, self.head_dim)
88
+ queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
89
+
90
+ # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
91
+ keys = keys.transpose(1, 2)
92
+ queries = queries.transpose(1, 2)
93
+ values = values.transpose(1, 2)
94
+
95
+ # Compute scaled dot-product attention (aka self-attention) with a causal mask
96
+ attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
97
+
98
+ # Original mask truncated to the number of tokens and converted to boolean
99
+ mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
100
+
101
+ # Use the mask to fill attention scores
102
+ attn_scores.masked_fill_(mask_bool, -torch.inf)
103
+
104
+ attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
105
+ attn_weights = self.dropout(attn_weights)
106
+
107
+ # Shape: (b, num_tokens, num_heads, head_dim)
108
+ context_vec = (attn_weights @ values).transpose(1, 2)
109
+
110
+ # Combine heads, where self.d_out = self.num_heads * self.head_dim
111
+ context_vec = context_vec.reshape(b, num_tokens, self.d_out)
112
+ context_vec = self.out_proj(context_vec) # optional projection
113
+
114
+ return context_vec
115
+
116
+
117
+ #####################################
118
+ # Chapter 4
119
+ #####################################
120
+ class LayerNorm(nn.Module):
121
+ def __init__(self, emb_dim):
122
+ super().__init__()
123
+ self.eps = 1e-5
124
+ self.scale = nn.Parameter(torch.ones(emb_dim))
125
+ self.shift = nn.Parameter(torch.zeros(emb_dim))
126
+
127
+ def forward(self, x):
128
+ mean = x.mean(dim=-1, keepdim=True)
129
+ var = x.var(dim=-1, keepdim=True, unbiased=False)
130
+ norm_x = (x - mean) / torch.sqrt(var + self.eps)
131
+ return self.scale * norm_x + self.shift
132
+
133
+
134
+ class GELU(nn.Module):
135
+ def __init__(self):
136
+ super().__init__()
137
+
138
+ def forward(self, x):
139
+ return 0.5 * x * (1 + torch.tanh(
140
+ torch.sqrt(torch.tensor(2.0 / torch.pi)) *
141
+ (x + 0.044715 * torch.pow(x, 3))
142
+ ))
143
+
144
+
145
+ class FeedForward(nn.Module):
146
+ def __init__(self, cfg):
147
+ super().__init__()
148
+ self.layers = nn.Sequential(
149
+ nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
150
+ GELU(),
151
+ nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
152
+ )
153
+
154
+ def forward(self, x):
155
+ return self.layers(x)
156
+
157
+
158
+ class TransformerBlock(nn.Module):
159
+ def __init__(self, cfg):
160
+ super().__init__()
161
+ self.att = MultiHeadAttention(
162
+ d_in=cfg["emb_dim"],
163
+ d_out=cfg["emb_dim"],
164
+ context_length=cfg["context_length"],
165
+ num_heads=cfg["n_heads"],
166
+ dropout=cfg["drop_rate"],
167
+ qkv_bias=cfg["qkv_bias"])
168
+ self.ff = FeedForward(cfg)
169
+ self.norm1 = LayerNorm(cfg["emb_dim"])
170
+ self.norm2 = LayerNorm(cfg["emb_dim"])
171
+ self.drop_resid = nn.Dropout(cfg["drop_rate"])
172
+
173
+ def forward(self, x):
174
+ # Shortcut connection for attention block
175
+ shortcut = x
176
+ x = self.norm1(x)
177
+ x = self.att(x) # Shape [batch_size, num_tokens, emb_size]
178
+ x = self.drop_resid(x)
179
+ x = x + shortcut # Add the original input back
180
+
181
+ # Shortcut connection for feed-forward block
182
+ shortcut = x
183
+ x = self.norm2(x)
184
+ x = self.ff(x)
185
+ x = self.drop_resid(x)
186
+ x = x + shortcut # Add the original input back
187
+
188
+ return x
189
+
190
+
191
+ class GPTModel(nn.Module):
192
+ def __init__(self, cfg):
193
+ super().__init__()
194
+ self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
195
+ self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
196
+ self.drop_emb = nn.Dropout(cfg["drop_rate"])
197
+
198
+ self.trf_blocks = nn.Sequential(
199
+ *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
200
+
201
+ self.final_norm = LayerNorm(cfg["emb_dim"])
202
+ self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
203
+
204
+ def forward(self, in_idx):
205
+ batch_size, seq_len = in_idx.shape
206
+ tok_embeds = self.tok_emb(in_idx)
207
+ pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
208
+ x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size]
209
+ x = self.drop_emb(x)
210
+ x = self.trf_blocks(x)
211
+ x = self.final_norm(x)
212
+ logits = self.out_head(x)
213
+ return logits
214
+
215
+
216
+ def generate_text_simple(model, idx, max_new_tokens, context_size):
217
+ # idx is (B, T) array of indices in the current context
218
+ for _ in range(max_new_tokens):
219
+
220
+ # Crop current context if it exceeds the supported context size
221
+ # E.g., if LLM supports only 5 tokens, and the context size is 10
222
+ # then only the last 5 tokens are used as context
223
+ idx_cond = idx[:, -context_size:]
224
+
225
+ # Get the predictions
226
+ with torch.no_grad():
227
+ logits = model(idx_cond)
228
+
229
+ # Focus only on the last time step
230
+ # (batch, n_token, vocab_size) becomes (batch, vocab_size)
231
+ logits = logits[:, -1, :]
232
+
233
+ # Get the idx of the vocab entry with the highest logits value
234
+ idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch, 1)
235
+
236
+ # Append sampled index to the running sequence
237
+ idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1)
238
+
239
+ return idx
240
+
241
+
242
+ #####################################
243
+ # Chapter 5
244
+ #####################################
245
+ def assign(left, right):
246
+ if left.shape != right.shape:
247
+ raise ValueError(f"Shape mismatch. Left: {left.shape}, Right: {right.shape}")
248
+ return torch.nn.Parameter(torch.tensor(right))
249
+
250
+
251
+ def load_weights_into_gpt(gpt, params):
252
+ gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params['wpe'])
253
+ gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params['wte'])
254
+
255
+ for b in range(len(params["blocks"])):
256
+ q_w, k_w, v_w = np.split(
257
+ (params["blocks"][b]["attn"]["c_attn"])["w"], 3, axis=-1)
258
+ gpt.trf_blocks[b].att.W_query.weight = assign(
259
+ gpt.trf_blocks[b].att.W_query.weight, q_w.T)
260
+ gpt.trf_blocks[b].att.W_key.weight = assign(
261
+ gpt.trf_blocks[b].att.W_key.weight, k_w.T)
262
+ gpt.trf_blocks[b].att.W_value.weight = assign(
263
+ gpt.trf_blocks[b].att.W_value.weight, v_w.T)
264
+
265
+ q_b, k_b, v_b = np.split(
266
+ (params["blocks"][b]["attn"]["c_attn"])["b"], 3, axis=-1)
267
+ gpt.trf_blocks[b].att.W_query.bias = assign(
268
+ gpt.trf_blocks[b].att.W_query.bias, q_b)
269
+ gpt.trf_blocks[b].att.W_key.bias = assign(
270
+ gpt.trf_blocks[b].att.W_key.bias, k_b)
271
+ gpt.trf_blocks[b].att.W_value.bias = assign(
272
+ gpt.trf_blocks[b].att.W_value.bias, v_b)
273
+
274
+ gpt.trf_blocks[b].att.out_proj.weight = assign(
275
+ gpt.trf_blocks[b].att.out_proj.weight,
276
+ params["blocks"][b]["attn"]["c_proj"]["w"].T)
277
+ gpt.trf_blocks[b].att.out_proj.bias = assign(
278
+ gpt.trf_blocks[b].att.out_proj.bias,
279
+ params["blocks"][b]["attn"]["c_proj"]["b"])
280
+
281
+ gpt.trf_blocks[b].ff.layers[0].weight = assign(
282
+ gpt.trf_blocks[b].ff.layers[0].weight,
283
+ params["blocks"][b]["mlp"]["c_fc"]["w"].T)
284
+ gpt.trf_blocks[b].ff.layers[0].bias = assign(
285
+ gpt.trf_blocks[b].ff.layers[0].bias,
286
+ params["blocks"][b]["mlp"]["c_fc"]["b"])
287
+ gpt.trf_blocks[b].ff.layers[2].weight = assign(
288
+ gpt.trf_blocks[b].ff.layers[2].weight,
289
+ params["blocks"][b]["mlp"]["c_proj"]["w"].T)
290
+ gpt.trf_blocks[b].ff.layers[2].bias = assign(
291
+ gpt.trf_blocks[b].ff.layers[2].bias,
292
+ params["blocks"][b]["mlp"]["c_proj"]["b"])
293
+
294
+ gpt.trf_blocks[b].norm1.scale = assign(
295
+ gpt.trf_blocks[b].norm1.scale,
296
+ params["blocks"][b]["ln_1"]["g"])
297
+ gpt.trf_blocks[b].norm1.shift = assign(
298
+ gpt.trf_blocks[b].norm1.shift,
299
+ params["blocks"][b]["ln_1"]["b"])
300
+ gpt.trf_blocks[b].norm2.scale = assign(
301
+ gpt.trf_blocks[b].norm2.scale,
302
+ params["blocks"][b]["ln_2"]["g"])
303
+ gpt.trf_blocks[b].norm2.shift = assign(
304
+ gpt.trf_blocks[b].norm2.shift,
305
+ params["blocks"][b]["ln_2"]["b"])
306
+
307
+ gpt.final_norm.scale = assign(gpt.final_norm.scale, params["g"])
308
+ gpt.final_norm.shift = assign(gpt.final_norm.shift, params["b"])
309
+ gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"])
310
+
311
+
312
+ def text_to_token_ids(text, tokenizer):
313
+ encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
314
+ encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension
315
+ return encoded_tensor
316
+
317
+
318
+ def token_ids_to_text(token_ids, tokenizer):
319
+ flat = token_ids.squeeze(0) # remove batch dimension
320
+ return tokenizer.decode(flat.tolist())