prism-lab committed on
Commit
3cdf87b
·
verified ·
1 Parent(s): 48f6f9f

Upload AIAYN_Baseline_Training.ipynb

Browse files

WMT14 Transformer Baseline training code. The BLEU calculation used for evaluation here is non-standard and can be misleading; for the results reported in the paper we used torchmetrics' SacreBLEU instead.

Files changed (1) hide show
  1. AIAYN_Baseline_Training.ipynb +872 -0
AIAYN_Baseline_Training.ipynb ADDED
@@ -0,0 +1,872 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {
7
+ "collapsed": true,
8
+ "id": "2s48Vmoo9EB5"
9
+ },
10
+ "outputs": [],
11
+ "source": [
12
+ "!pip install -q torchmetrics sacrebleu"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "markdown",
17
+ "metadata": {
18
+ "id": "Lz8buKsjvA_w"
19
+ },
20
+ "source": [
21
+ "## CONFIG"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": null,
27
+ "metadata": {
28
+ "id": "df355sdDrNSb"
29
+ },
30
+ "outputs": [],
31
+ "source": [
32
+ "# --- Data & Task Size ---\n",
33
+ "MAX_LENGTH = 128\n",
34
+ "\n",
35
+ "MODEL_CHOICE = \"Baseline\" # For save path\n",
36
+ "\n",
37
+ "# --- Model Architecture Config (\"Transformer-Small\") ---\n",
38
+ "D_MODEL = 512\n",
39
+ "NUM_HEADS = 8\n",
40
+ "D_FF = 2048\n",
41
+ "DROPOUT = 0.1\n",
42
+ "\n",
43
+ "# --- Layer counts ---\n",
44
+ "NUM_ENCODER_LAYERS = 6\n",
45
+ "NUM_DECODER_LAYERS = 6\n",
46
+ "\n",
47
+ "# --- Training Config ---\n",
48
+ "TARGET_TRAINING_STEPS = 50000\n",
49
+ "\n",
50
+ "VALIDATION_SCHEDULE = [\n",
51
+ " 2000, 4000, 5000, 7500, 10000, 15000, 20000,\n",
52
+ " 25000, 30000, 35000, 42500, 50000\n",
53
+ "]\n",
54
+ "PEAK_LEARNING_RATE = 8e-4\n",
55
+ "WARMUP_STEPS = 120 # This is a flex, Kaiming + Pre-LN + AdamW is so stable that we don't even need warmups\n",
56
+ "WEIGHT_DECAY = 0.01\n",
57
+ "\n",
58
+ "# --- Regularization Config ---\n",
59
+ "LABEL_SMOOTHING_EPSILON = 0.1\n",
60
+ "\n",
61
+ "# --- Other Constants ---\n",
62
+ "DRIVE_BASE_PATH = \"/content/drive/MyDrive/AIAYN\"\n",
63
+ "PREBATCHED_REPO_ID = \"prism-lab/wmt14-de-en-prebatched-w4\"\n",
64
+ "ORIGINAL_BUCKETED_REPO_ID = \"prism-lab/wmt14-de-en-bucketed-w4\"\n",
65
+ "MODEL_CHECKPOINT = \"Helsinki-NLP/opus-mt-de-en\" # We only use its tokenizer\n"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "markdown",
70
+ "source": [
71
+ "## DATALOADERS"
72
+ ],
73
+ "metadata": {
74
+ "id": "W5l1HHRFXxPA"
75
+ }
76
+ },
77
+ {
78
+ "cell_type": "code",
79
+ "execution_count": null,
80
+ "metadata": {
81
+ "id": "FA5SqFzeMrpK"
82
+ },
83
+ "outputs": [],
84
+ "source": [
85
+ "\n",
86
+ "import torch\n",
87
+ "import torch.nn as nn\n",
88
+ "from torch.utils.data import DataLoader\n",
89
+ "from transformers import AutoTokenizer\n",
90
+ "from datasets import load_dataset\n",
91
+ "import math\n",
92
+ "import os\n",
93
+ "from tqdm.auto import tqdm\n",
94
+ "from torchmetrics.text import BLEUScore\n",
95
+ "from torch.utils.tensorboard import SummaryWriter\n",
96
+ "import random\n",
97
+ "import numpy as np\n",
98
+ "import torch\n",
99
+ "from transformers import get_cosine_schedule_with_warmup\n",
100
+ "from typing import List\n",
101
+ "from transformers import AutoModel\n",
102
+ "\n",
103
+ "\n",
104
+ "def set_seed(seed_value=5):\n",
105
+ " \"\"\"Sets the seed for reproducibility.\"\"\"\n",
106
+ " random.seed(seed_value)\n",
107
+ " np.random.seed(seed_value)\n",
108
+ " torch.manual_seed(seed_value)\n",
109
+ " torch.cuda.manual_seed_all(seed_value)\n",
110
+ " torch.backends.cudnn.deterministic = True\n",
111
+ " torch.backends.cudnn.benchmark = False\n",
112
+ "\n",
113
+ "SEED = 116\n",
114
+ "set_seed(SEED)\n",
115
+ "print(f\"Reproducibility seed set to {SEED}\")\n",
116
+ "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n",
117
+ "\n",
118
+ "torch.use_deterministic_algorithms(True)\n",
119
+ "\n",
120
+ "print(\"--- Loading Modernized Configuration ---\")\n",
121
+ "def seed_worker(worker_id):\n",
122
+ " \"\"\"\n",
123
+ " DataLoader worker'ları için seed ayarlama fonksiyonu.\n",
124
+ " Her worker'ın farklı ama deterministik bir seed'e sahip olmasını sağlar.\n",
125
+ " \"\"\"\n",
126
+ " worker_seed = torch.initial_seed() % 2**32\n",
127
+ " np.random.seed(worker_seed)\n",
128
+ " random.seed(worker_seed)\n",
129
+ "\n",
130
+ "torch.set_float32_matmul_precision('high')\n",
131
+ "print(\"✅ PyTorch matmul precision set to 'high'\")\n",
132
+ "\n",
133
+ "# --- Device Setup ---\n",
134
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
135
+ "print(f\"Using device: {device}\")\n",
136
+ "\n",
137
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)\n",
138
+ "\n",
139
+ "VOCAB_SIZE = len(tokenizer)\n",
140
+ "print(f\"Vocab size: {VOCAB_SIZE}\")\n",
141
+ "\n",
142
+ "\n",
143
+ "# DATA LOADING & PREPARATION\n",
144
+ "from transformers import DataCollatorForSeq2Seq\n",
145
+ "\n",
146
+ "standard_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)\n",
147
+ "\n",
148
+ "class PreBatchedCollator:\n",
149
+ " def __init__(self, original_dataset_split):\n",
150
+ " self.original_dataset = original_dataset_split\n",
151
+ "\n",
152
+ " def __call__(self, features: List[dict]) -> dict:\n",
153
+ " # 'features' will be a list of size 1, e.g., [{'batch_indices': [10, 5, 123]}]\n",
154
+ " batch_indices = features[0]['batch_indices']\n",
155
+ "\n",
156
+ " # This returns a \"Dictionary of Lists\"\n",
157
+ " # e.g., {'input_ids': [[...], [...]], 'labels': [[...], [...]]}\n",
158
+ " dict_of_lists = self.original_dataset[batch_indices]\n",
159
+ "\n",
160
+ " # --- THE FIX ---\n",
161
+ " # We must convert it to a \"List of Dictionaries\" for the standard collator.\n",
162
+ " # e.g., [{'input_ids': [...], 'labels': [...]}, {'input_ids': [...], 'labels': [...]}]\n",
163
+ " list_of_dicts = []\n",
164
+ " keys = dict_of_lists.keys()\n",
165
+ " num_samples = len(dict_of_lists['input_ids'])\n",
166
+ "\n",
167
+ " for i in range(num_samples):\n",
168
+ " list_of_dicts.append({key: dict_of_lists[key][i] for key in keys})\n",
169
+ " # --- END OF FIX ---\n",
170
+ "\n",
171
+ " # Now, pass the correctly formatted data to the standard collator\n",
172
+ " return standard_collator(list_of_dicts)\n",
173
+ "\n",
174
+ "print(f\"Loading pre-batched dataset from: {PREBATCHED_REPO_ID}\")\n",
175
+ "prebatched_datasets = load_dataset(PREBATCHED_REPO_ID)\n",
176
+ "\n",
177
+ "print(f\"Loading original samples from: {ORIGINAL_BUCKETED_REPO_ID}\")\n",
178
+ "original_datasets = load_dataset(ORIGINAL_BUCKETED_REPO_ID)\n",
179
+ "train_collator = PreBatchedCollator(original_datasets[\"train\"])\n",
180
+ "\n",
181
+ "# --- The New, Simple DataLoader ---\n",
182
+ "# No more custom sampler!\n",
183
+ "g = torch.Generator()\n",
184
+ "g.manual_seed(SEED)\n",
185
+ "\n",
186
+ "train_dataloader = DataLoader(\n",
187
+ " prebatched_datasets[\"train\"],\n",
188
+ " batch_size=1, # Each row is already a batch\n",
189
+ " shuffle=True, # Shuffle the pre-calculated batches every epoch\n",
190
+ " num_workers=0,\n",
191
+ " collate_fn=train_collator,\n",
192
+ " pin_memory=True,\n",
193
+ " worker_init_fn=seed_worker,\n",
194
+ " generator=g,\n",
195
+ ")\n",
196
+ "\n",
197
+ "# Validation loader remains the same, using the original data\n",
198
+ "EVAL_BATCH_SIZE = 64\n",
199
+ "val_dataloader = DataLoader(\n",
200
+ " original_datasets[\"validation\"],\n",
201
+ " batch_size=EVAL_BATCH_SIZE,\n",
202
+ " collate_fn=standard_collator,\n",
203
+ " num_workers=0,\n",
204
+ " pin_memory=True,\n",
205
+ " worker_init_fn=seed_worker,\n",
206
+ " generator=g,\n",
207
+ ")\n",
208
+ "\n",
209
+ "print(f\"Train Dataloader is now a simple iterator over pre-calculated batches.\")\n",
210
+ "\n",
211
+ "# --- SANITY CHECK ---\n",
212
+ "print(\"\\n--- Running Sanity Check on new DataLoader ---\")\n",
213
+ "train_dataloader.generator.manual_seed(SEED) # Reset generator for check\n",
214
+ "temp_iterator = iter(train_dataloader)\n",
215
+ "print(\"Shapes of first 5 batches:\")\n",
216
+ "for i in range(5):\n",
217
+ " batch = next(temp_iterator)\n",
218
+ " print(f\" Batch {i+1}: input_ids shape = {batch['input_ids'].shape}\")\n",
219
+ "print(\"--- Sanity Check Complete ---\\n\")"
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "markdown",
224
+ "metadata": {
225
+ "id": "cS4JvJGRhClv"
226
+ },
227
+ "source": [
228
+ "## Models"
229
+ ]
230
+ },
231
+ {
232
+ "cell_type": "code",
233
+ "execution_count": null,
234
+ "metadata": {
235
+ "id": "SMhlM0YvO1A7"
236
+ },
237
+ "outputs": [],
238
+ "source": [
239
+ "import torch\n",
240
+ "import torch.nn as nn\n",
241
+ "import torch.nn.functional as F\n",
242
+ "import math\n",
243
+ "\n",
244
+ "class PositionalEncoding(nn.Module):\n",
245
+ " \"\"\"Injects positional information into the input embeddings.\"\"\"\n",
246
+ " def __init__(self, d_model: int, max_len: int = 5000):\n",
247
+ " super().__init__()\n",
248
+ " position = torch.arange(max_len).unsqueeze(1)\n",
249
+ " div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))\n",
250
+ " pe = torch.zeros(1, max_len, d_model)\n",
251
+ " pe[0, :, 0::2] = torch.sin(position * div_term)\n",
252
+ " pe[0, :, 1::2] = torch.cos(position * div_term)\n",
253
+ " self.register_buffer('pe', pe)\n",
254
+ "\n",
255
+ " def forward(self, x: torch.Tensor):\n",
256
+ " # x shape: [batch_size, seq_len, d_model]\n",
257
+ " return x + self.pe[:, :x.size(1)]\n",
258
+ "\n",
259
+ "class FeedForward(nn.Module):\n",
260
+ " \"\"\"A standard two-layer feed-forward network with a ReLU activation.\"\"\"\n",
261
+ " def __init__(self, d_model: int, dff: int, dropout_rate: float = 0.1):\n",
262
+ " super().__init__()\n",
263
+ " self.ffn = nn.Sequential(\n",
264
+ " nn.Linear(d_model, dff),\n",
265
+ " nn.ReLU(),\n",
266
+ " nn.Linear(dff, d_model),\n",
267
+ " nn.Dropout(dropout_rate)\n",
268
+ " )\n",
269
+ " def forward(self, x: torch.Tensor):\n",
270
+ " return self.ffn(x)\n",
271
+ "\n",
272
+ "class StandardTransformer(nn.Module):\n",
273
+ " def __init__(self, num_encoder_layers, num_decoder_layers, num_heads, d_model, dff, vocab_size, max_length, dropout):\n",
274
+ " super().__init__()\n",
275
+ " self.d_model = d_model\n",
276
+ " self.embedding = nn.Embedding(vocab_size, d_model)\n",
277
+ " self.pos_encoder = PositionalEncoding(d_model, max_length)\n",
278
+ " self.dropout = nn.Dropout(dropout)\n",
279
+ " encoder_layer = nn.TransformerEncoderLayer(\n",
280
+ " d_model, num_heads, dff, dropout, batch_first=True, norm_first=True # <-- THE FIX\n",
281
+ " )\n",
282
+ " self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)\n",
283
+ "\n",
284
+ " decoder_layer = nn.TransformerDecoderLayer(\n",
285
+ " d_model, num_heads, dff, dropout, batch_first=True, norm_first=True # <-- THE FIX\n",
286
+ " )\n",
287
+ " self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)\n",
288
+ "\n",
289
+ " self.final_linear = nn.Linear(d_model, vocab_size)\n",
290
+ " self.final_linear.weight = self.embedding.weight\n",
291
+ "\n",
292
+ " def forward(self, src, tgt, src_padding_mask, tgt_padding_mask, memory_key_padding_mask, tgt_mask):\n",
293
+ "\n",
294
+ " src_emb = self.embedding(src) * math.sqrt(self.d_model)\n",
295
+ " tgt_emb = self.embedding(tgt) * math.sqrt(self.d_model)\n",
296
+ " src_emb_pos = self.dropout(self.pos_encoder(src_emb))\n",
297
+ " tgt_emb_pos = self.dropout(self.pos_encoder(tgt_emb))\n",
298
+ "\n",
299
+ " memory = self.encoder(src_emb_pos, src_key_padding_mask=src_padding_mask)\n",
300
+ " decoder_output = self.decoder(\n",
301
+ " tgt=tgt_emb_pos, memory=memory, tgt_mask=tgt_mask,\n",
302
+ " tgt_key_padding_mask=tgt_padding_mask, memory_key_padding_mask=memory_key_padding_mask\n",
303
+ " )\n",
304
+ " return self.final_linear(decoder_output)\n",
305
+ "\n",
306
+ "\n",
307
+ " def create_masks(self, src, tgt):\n",
308
+ " src_padding_mask = (src == tokenizer.pad_token_id)\n",
309
+ " tgt_padding_mask = (tgt == tokenizer.pad_token_id)\n",
310
+ " # Creates a square causal mask for the decoder. This prevents any token from attending to future tokens. With this way model can not cheat.\n",
311
+ " tgt_mask = nn.Transformer.generate_square_subsequent_mask(\n",
312
+ " sz=tgt.size(1),\n",
313
+ " device=src.device,\n",
314
+ " dtype=torch.bool\n",
315
+ " )\n",
316
+ " return src_padding_mask, tgt_padding_mask, src_padding_mask, tgt_mask\n",
317
+ "\n",
318
+ " @torch.no_grad()\n",
319
+ " def generate(self, src: torch.Tensor, max_length: int, num_beams: int = 5) -> torch.Tensor:\n",
320
+ " self.eval()\n",
321
+ " src_padding_mask = (src == tokenizer.pad_token_id)\n",
322
+ "\n",
323
+ " src_emb = self.embedding(src) * math.sqrt(self.d_model)\n",
324
+ " src_emb_pos = self.pos_encoder(src_emb)\n",
325
+ " memory = self.encoder(self.dropout(src_emb_pos), src_key_padding_mask=src_padding_mask)\n",
326
+ "\n",
327
+ " batch_size = src.shape[0]\n",
328
+ " memory = memory.repeat_interleave(num_beams, dim=0)\n",
329
+ " memory_key_padding_mask = src_padding_mask.repeat_interleave(num_beams, dim=0)\n",
330
+ "\n",
331
+ " initial_token = tokenizer.pad_token_id\n",
332
+ " beams = torch.full((batch_size * num_beams, 1), initial_token, dtype=torch.long, device=src.device)\n",
333
+ "\n",
334
+ " beam_scores = torch.zeros(batch_size * num_beams, device=src.device)\n",
335
+ " finished_beams = torch.zeros(batch_size * num_beams, dtype=torch.bool, device=src.device)\n",
336
+ " for _ in range(max_length - 1):\n",
337
+ " if finished_beams.all(): break\n",
338
+ " tgt_mask = nn.Transformer.generate_square_subsequent_mask(beams.size(1)).to(src.device)\n",
339
+ " tgt_emb = self.embedding(beams) * math.sqrt(self.d_model) # FIX HERE TOO\n",
340
+ " tgt_emb_pos = self.pos_encoder(tgt_emb)\n",
341
+ " decoder_output = self.decoder(tgt=self.dropout(tgt_emb_pos), memory=memory, tgt_mask=tgt_mask, memory_key_padding_mask=memory_key_padding_mask)\n",
342
+ " logits = self.final_linear(decoder_output[:, -1, :])\n",
343
+ " log_probs = F.log_softmax(logits, dim=-1)\n",
344
+ " log_probs[:, tokenizer.pad_token_id] = -torch.inf\n",
345
+ " if finished_beams.any(): log_probs[finished_beams, tokenizer.eos_token_id] = 0\n",
346
+ " total_scores = beam_scores.unsqueeze(1) + log_probs\n",
347
+ " if _ == 0:\n",
348
+ " total_scores = total_scores.view(batch_size, num_beams, -1)\n",
349
+ " total_scores[:, 1:, :] = -torch.inf # Sadece ilk beam'in başlamasına izin ver\n",
350
+ " total_scores = total_scores.view(batch_size * num_beams, -1)\n",
351
+ " else:\n",
352
+ " total_scores = beam_scores.unsqueeze(1) + log_probs\n",
353
+ " total_scores = total_scores.view(batch_size, -1)\n",
354
+ " top_scores, top_indices = torch.topk(total_scores, k=num_beams, dim=1)\n",
355
+ " beam_indices = top_indices // log_probs.shape[-1]; token_indices = top_indices % log_probs.shape[-1]\n",
356
+ " batch_indices = torch.arange(batch_size, device=src.device).unsqueeze(1)\n",
357
+ " effective_indices = (batch_indices * num_beams + beam_indices).view(-1)\n",
358
+ " beams = beams[effective_indices]\n",
359
+ " beams = torch.cat([beams, token_indices.view(-1, 1)], dim=1)\n",
360
+ " beam_scores = top_scores.view(-1)\n",
361
+ " finished_beams = finished_beams | (beams[:, -1] == tokenizer.eos_token_id)\n",
362
+ " final_beams = beams.view(batch_size, num_beams, -1)\n",
363
+ " final_scores = beam_scores.view(batch_size, num_beams)\n",
364
+ " normalized_scores = final_scores / (final_beams != tokenizer.pad_token_id).sum(-1).float().clamp(min=1)\n",
365
+ " best_beams = final_beams[torch.arange(batch_size), normalized_scores.argmax(1), :]\n",
366
+ " self.train()\n",
367
+ " return best_beams\n"
368
+ ]
369
+ },
370
+ {
371
+ "cell_type": "code",
372
+ "execution_count": null,
373
+ "metadata": {
374
+ "id": "3QGBtTvj6Jrp"
375
+ },
376
+ "outputs": [],
377
+ "source": [
378
+ "# ==============================================================================\n",
379
+ "# --- Model Analysis & Parameter Counting ---\n",
380
+ "# ==============================================================================\n",
381
+ "from collections import defaultdict\n",
382
+ "\n",
383
+ "def count_parameters_correctly(model):\n",
384
+ " \"\"\"\n",
385
+ " Counts trainable parameters, correctly handling tied weights (e.g., embeddings).\n",
386
+ " \"\"\"\n",
387
+ " seen_params = set()\n",
388
+ " total_params = 0\n",
389
+ " for param in model.parameters():\n",
390
+ " if param.requires_grad:\n",
391
+ " param_id = id(param)\n",
392
+ " if param_id not in seen_params:\n",
393
+ " seen_params.add(param_id)\n",
394
+ " total_params += param.numel()\n",
395
+ " return total_params\n",
396
+ "\n",
397
+ "# --- Instantiate the model to analyze it ---\n",
398
+ "print(\"--- Analyzing Model Parameters ---\")\n",
399
+ "model_to_analyze = StandardTransformer(\n",
400
+ " num_encoder_layers=NUM_ENCODER_LAYERS,\n",
401
+ " num_decoder_layers=NUM_DECODER_LAYERS,\n",
402
+ " num_heads=NUM_HEADS,\n",
403
+ " d_model=D_MODEL,\n",
404
+ " dff=D_FF,\n",
405
+ " vocab_size=VOCAB_SIZE,\n",
406
+ " max_length=MAX_LENGTH,\n",
407
+ " dropout=DROPOUT\n",
408
+ ")\n",
409
+ "\n",
410
+ "# --- Perform the counting and display results ---\n",
411
+ "correct_total = count_parameters_correctly(model_to_analyze)\n",
412
+ "pytorch_naive_total = sum(p.numel() for p in model_to_analyze.parameters() if p.requires_grad)\n",
413
+ "\n",
414
+ "print(f\"Total Trainable Parameters (Correctly Counted): {correct_total:,}\")\n",
415
+ "print(f\"PyTorch's Naive Count (sum(p.numel())): {pytorch_naive_total:,}\")\n",
416
+ "if pytorch_naive_total != correct_total:\n",
417
+ " print(f\"Note: The naive count is higher due to double-counting the tied embedding weights.\")\n",
418
+ "\n",
419
+ "del model_to_analyze # Clean up memory\n",
420
+ "print(\"--- Analysis Complete ---\\n\")"
421
+ ]
422
+ },
423
+ {
424
+ "cell_type": "markdown",
425
+ "metadata": {
426
+ "id": "Zd3AFTmhrCJq"
427
+ },
428
+ "source": [
429
+ "## Functions (Loss, Eval etc)"
430
+ ]
431
+ },
432
+ {
433
+ "cell_type": "code",
434
+ "execution_count": null,
435
+ "metadata": {
436
+ "id": "Te1qTyUKrDEd"
437
+ },
438
+ "outputs": [],
439
+ "source": [
440
+ "\n",
441
+ "translation_loss_fn = nn.CrossEntropyLoss(\n",
442
+ " ignore_index=-100, # We don't calculate loss for pad tokens. Pad tokens are replaced with -100 by DataCollatorForSeq2Seq.\n",
443
+ " label_smoothing=LABEL_SMOOTHING_EPSILON\n",
444
+ ")\n",
445
+ "def calculate_combined_loss(model_outputs, target_labels):\n",
446
+ " \"\"\"Calculates the loss based on the model's output structure.\"\"\"\n",
447
+ " logits = model_outputs\n",
448
+ " translation_loss = translation_loss_fn(logits.reshape(-1, logits.shape[-1]), target_labels.reshape(-1))\n",
449
+ " loss_dict = {'total': translation_loss.item()}\n",
450
+ " return translation_loss, loss_dict\n",
451
+ "\n",
452
+ "def evaluate(model, dataloader, device):\n",
453
+ " \"\"\"Evaluates the model using beam search decoding.\"\"\"\n",
454
+ " bleu_metric = BLEUScore()\n",
455
+ "\n",
456
+ "\n",
457
+ " orig_model = getattr(model, '_orig_mod', model)\n",
458
+ " orig_model.eval()\n",
459
+ "\n",
460
+ " for batch in tqdm(dataloader, desc=\"Evaluating\", leave=False):\n",
461
+ " input_ids = batch['input_ids'].to(device)\n",
462
+ " labels = batch['labels']\n",
463
+ "\n",
464
+ " generated_ids = orig_model.generate(input_ids, max_length=MAX_LENGTH, num_beams=5)\n",
465
+ "\n",
466
+ " pred_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)\n",
467
+ " labels[labels == -100] = tokenizer.pad_token_id\n",
468
+ " ref_texts = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
469
+ " bleu_metric.update(pred_texts, [[ref] for ref in ref_texts])\n",
470
+ "\n",
471
+ " orig_model.train()\n",
472
+ " return bleu_metric.compute().item()\n",
473
+ "\n",
474
+ "def generate_sample_translations(model, device, sentences_de):\n",
475
+ " \"\"\"Generates and prints sample translations using beam search.\"\"\"\n",
476
+ " print(\"\\n--- Generating Sample Translations (with Beam Search) ---\")\n",
477
+ " orig_model = getattr(model, '_orig_mod', model)\n",
478
+ " orig_model.eval()\n",
479
+ "\n",
480
+ " inputs = tokenizer(sentences_de, return_tensors=\"pt\", padding=True, truncation=True, max_length=MAX_LENGTH)\n",
481
+ " input_ids = inputs.input_ids.to(device)\n",
482
+ " generated_ids = orig_model.generate(input_ids, max_length=MAX_LENGTH, num_beams=5)\n",
483
+ "\n",
484
+ " translations = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)\n",
485
+ " for src, out in zip(sentences_de, translations):\n",
486
+ " print(f\" DE Source: {src}\")\n",
487
+ " print(f\" EN Output: {out}\")\n",
488
+ " print(\"-\" * 20)\n",
489
+ " orig_model.train()\n",
490
+ "\n",
491
+ "sample_sentences_de_for_tracking = [\n",
492
+ " \"Eine Katze sitzt auf der Matte.\",\n",
493
+ " \"Ein Mann in einem roten Hemd liest ein Buch.\",\n",
494
+ " \"Was ist die Hauptstadt von Deutschland?\",\n",
495
+ " \"Ich gehe ins Kino, weil der Film sehr gut ist.\",\n",
496
+ "]\n",
497
+ "\n",
498
+ "def init_other_linear_weights(m):\n",
499
+ " if isinstance(m, nn.Linear):\n",
500
+ " # The 'is not' check correctly skips the final_linear layer,\n",
501
+ " # leaving its weights tied to the correctly initialized embeddings.\n",
502
+ " if m is not getattr(model, '_orig_mod', model).final_linear:\n",
503
+ " nn.init.xavier_uniform_(m.weight)\n",
504
+ " if m.bias is not None:\n",
505
+ " nn.init.zeros_(m.bias)\n",
506
+ "\n",
507
+ "\n"
508
+ ]
509
+ },
510
+ {
511
+ "cell_type": "code",
512
+ "source": [
513
+ "import json\n",
514
+ "import os\n",
515
+ "import subprocess\n",
516
+ "import torch\n",
517
+ "import hashlib\n",
518
+ "import sys\n",
519
+ "import shutil\n",
520
+ "\n",
521
+ "# This logger will be configured and used in the main training script\n",
522
+ "import logging\n",
523
+ "logger = logging.getLogger(__name__)\n",
524
+ "\n",
525
+ "\n",
526
+ "def log_to_run_specific_file(run_dir):\n",
527
+ " run_log_path = os.path.join(run_dir, \"run_log.txt\")\n",
528
+ " file_handler = logging.FileHandler(run_log_path)\n",
529
+ " file_handler.setFormatter(logging.Formatter('%(asctime)s [%(levelname)s] %(message)s'))\n",
530
+ " logger.addHandler(file_handler)\n",
531
+ " return file_handler\n",
532
+ "\n",
533
+ "def log_configurations(log_dir, config_vars):\n",
534
+ " # (Same as your provided function)\n",
535
+ " config_path = os.path.join(log_dir, \"config.json\")\n",
536
+ " try:\n",
537
+ " with open(config_path, 'w') as f:\n",
538
+ " serializable_configs = {k: v for k, v in config_vars.items() if isinstance(v, (int, float, str, bool, list, dict, type(None)))}\n",
539
+ " json.dump(serializable_configs, f, indent=4)\n",
540
+ " logger.info(f\"Configurations saved to {config_path}\")\n",
541
+ " except Exception as e:\n",
542
+ " logger.error(f\"Could not save configurations: {e}\")\n",
543
+ "\n",
544
+ "def log_environment(log_dir):\n",
545
+ " # (Same as your provided function)\n",
546
+ " env_path = os.path.join(log_dir, \"environment.txt\")\n",
547
+ " try:\n",
548
+ " with open(env_path, 'w') as f:\n",
549
+ " f.write(f\"--- Timestamp (UTC): {datetime.datetime.utcnow().isoformat()} ---\\n\")\n",
550
+ " f.write(f\"Python Version: {sys.version}\\n\")\n",
551
+ " f.write(f\"PyTorch Version: {torch.__version__}\\n\")\n",
552
+ " f.write(f\"CUDA Available: {torch.cuda.is_available()}\\n\")\n",
553
+ " if torch.cuda.is_available():\n",
554
+ " f.write(f\"CUDA Version: {torch.version.cuda}\\n\")\n",
555
+ " f.write(f\"CuDNN Version: {torch.backends.cudnn.version()}\\n\")\n",
556
+ " f.write(f\"Number of GPUs: {torch.cuda.device_count()}\\n\")\n",
557
+ " f.write(f\"GPU Name: {torch.cuda.get_device_name(0)}\\n\")\n",
558
+ " f.write(\"\\n--- Full pip freeze ---\\n\")\n",
559
+ " result = subprocess.run([sys.executable, '-m', 'pip', 'freeze'], stdout=subprocess.PIPE, text=True, check=True)\n",
560
+ " f.write(result.stdout)\n",
561
+ " logger.info(f\"Environment info saved to {env_path}\")\n",
562
+ " except Exception as e:\n",
563
+ " logger.error(f\"Could not save environment info: {e}\")\n",
564
+ "\n",
565
+ "def log_code_snapshot(log_dir, script_path):\n",
566
+ " # NOTE: In Colab, you must save your notebook as a .py file for this to work.\n",
567
+ " # For example, file -> \"Save a copy as .py\"\n",
568
+ " code_dir = os.path.join(log_dir, \"code_snapshot\")\n",
569
+ " os.makedirs(code_dir, exist_ok=True)\n",
570
+ " if script_path and os.path.exists(script_path):\n",
571
+ " try:\n",
572
+ " shutil.copy(script_path, os.path.join(code_dir, os.path.basename(script_path)))\n",
573
+ " logger.info(f\"Copied script '{script_path}' to snapshot directory for verification.\")\n",
574
+ " except Exception as e:\n",
575
+ " logger.error(f\"Could not copy script for snapshot: {e}\")\n",
576
+ " else:\n",
577
+ " logger.warning(f\"Code Snapshot: Script path '{script_path}' not found. SKIPPING.\")\n",
578
+ "\n",
579
+ "def get_file_hash(filepath):\n",
580
+ " # (Same as your provided function)\n",
581
+ " sha256_hash = hashlib.sha256()\n",
582
+ " try:\n",
583
+ " with open(filepath, \"rb\") as f:\n",
584
+ " for byte_block in iter(lambda: f.read(4096), b\"\"):\n",
585
+ " sha256_hash.update(byte_block)\n",
586
+ " return sha256_hash.hexdigest()\n",
587
+ " except Exception as e:\n",
588
+ " logger.error(f\"Could not generate hash for {filepath}: {e}\")\n",
589
+ " return None\n",
590
+ "\n",
591
+ "def create_checksum_file(run_dir, artifacts_dict):\n",
592
+ " checksum_file_path = os.path.join(run_dir, \"checksums.sha256\")\n",
593
+ " logger.info(f\"--- Creating digital fingerprints for key artifacts ---\")\n",
594
+ " with open(checksum_file_path, \"w\") as f:\n",
595
+ " f.write(f\"SHA256 Checksums for run: {os.path.basename(run_dir)}\\n\")\n",
596
+ " for name, path in artifacts_dict.items():\n",
597
+ " if path and os.path.exists(path):\n",
598
+ " file_hash = get_file_hash(path)\n",
599
+ " if file_hash:\n",
600
+ " log_message = f\" - {name} ({os.path.basename(path)}): {file_hash}\"\n",
601
+ " logger.info(log_message)\n",
602
+ " f.write(f\"{file_hash} {os.path.basename(path)}\\n\")\n",
603
+ " else:\n",
604
+ " logger.warning(f\" - Skipped hashing '{name}', file not found: {path}\")\n",
605
+ " logger.info(f\"Checksums saved to {checksum_file_path}\")\n",
606
+ "\n",
607
+ "def init_weights_kaiming(m):\n",
608
+ " \"\"\"\n",
609
+ " Applies Kaiming He initialization to Linear layers.\n",
610
+ " This is the standard, superior way to initialize deep Transformers.\n",
611
+ " NOTE: We will handle the Embedding layer separately.\n",
612
+ " \"\"\"\n",
613
+ "\n",
614
+ " if isinstance(m, nn.Linear):\n",
615
+ " nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5)) # a=sqrt(5) mimics default PyTorch for LeakyReLU\n",
616
+ " if m.bias is not None:\n",
617
+ " fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)\n",
618
+ " bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0\n",
619
+ " nn.init.uniform_(m.bias, -bound, bound)\n"
620
+ ],
621
+ "metadata": {
622
+ "id": "YwPXbSwR50I2"
623
+ },
624
+ "execution_count": null,
625
+ "outputs": []
626
+ },
627
+ {
628
+ "cell_type": "markdown",
629
+ "metadata": {
630
+ "id": "ijTUk5dHu494"
631
+ },
632
+ "source": [
633
+ "## Training Loop"
634
+ ]
635
+ },
636
+ {
637
+ "cell_type": "code",
638
+ "execution_count": null,
639
+ "metadata": {
640
+ "id": "pyHZ1moluyA2"
641
+ },
642
+ "outputs": [],
643
+ "source": [
644
+ "\n",
645
+ "if __name__ == '__main__':\n",
646
+ "\n",
647
+ " experiment_name = f\"{MODEL_CHOICE}\"\n",
648
+ " CURRENT_RUN_DIR = os.path.join(DRIVE_BASE_PATH, experiment_name) # Single run directory\n",
649
+ " SAVE_DIR = os.path.join(CURRENT_RUN_DIR, \"models\")\n",
650
+ " LOG_DIR_TENSORBOARD = os.path.join(CURRENT_RUN_DIR, \"tensorboard_logs\")\n",
651
+ " LOG_FILE_TXT = os.path.join(CURRENT_RUN_DIR, \"run_log.txt\")\n",
652
+ "\n",
653
+ " os.makedirs(SAVE_DIR, exist_ok=True)\n",
654
+ " os.makedirs(LOG_DIR_TENSORBOARD, exist_ok=True)\n",
655
+ "\n",
656
+ " logging.basicConfig(\n",
657
+ " level=logging.INFO,\n",
658
+ " format='%(asctime)s [%(levelname)s] %(message)s',\n",
659
+ " handlers=[\n",
660
+ " logging.FileHandler(LOG_FILE_TXT),\n",
661
+ " logging.StreamHandler(sys.stdout)\n",
662
+ " ],\n",
663
+ " force=True\n",
664
+ " )\n",
665
+ " logger = logging.getLogger(__name__)\n",
666
+ " writer = SummaryWriter(LOG_DIR_TENSORBOARD)\n",
667
+ "\n",
668
+ " logger.info(f\"--- LAUNCHING EXPERIMENT: {experiment_name} ---\")\n",
669
+ "\n",
670
+ " all_configs = {k: v for k, v in globals().items() if k.isupper()}\n",
671
+ " all_configs['TARGET_TRAINING_STEPS'] = TARGET_TRAINING_STEPS\n",
672
+ " all_configs['VALIDATION_SCHEDULE'] = VALIDATION_SCHEDULE\n",
673
+ " log_configurations(CURRENT_RUN_DIR, all_configs)\n",
674
+ " log_environment(CURRENT_RUN_DIR)\n",
675
+ " log_code_snapshot(CURRENT_RUN_DIR, \"your_notebook_name.ipynb\") # Remember to update this filename\n",
676
+ "\n",
677
+ " set_seed(SEED)\n",
678
+ " logger.info(f\"Reproducibility seed set to {SEED}\")\n",
679
+ "\n",
680
+ " logger.info(f\"--- Initializing StandardTransformer ---\")\n",
681
+ " model = StandardTransformer(\n",
682
+ " num_encoder_layers=NUM_ENCODER_LAYERS, num_decoder_layers=NUM_DECODER_LAYERS,\n",
683
+ " num_heads=NUM_HEADS, d_model=D_MODEL, dff=D_FF, vocab_size=VOCAB_SIZE,\n",
684
+ " max_length=MAX_LENGTH, dropout=DROPOUT\n",
685
+ " )\n",
686
+ "\n",
687
+ " # 3. WEIGHT INITIALIZATION STRATEGY\n",
688
+ " model.apply(init_weights_kaiming)\n",
689
+ " logger.info(\" Applied Kaiming Uniform initialization to all linear layers.\")\n",
690
+ "\n",
691
+ " # Removed the if/else logic, only the \"from-scratch\" path remains\n",
692
+ " logger.info(\"--- Initializing embedding layer from scratch ---\")\n",
693
+ " nn.init.normal_(model.embedding.weight, mean=0.0, std=0.02)\n",
694
+ " logger.info(\" Initialized embedding map with Normal(0, 0.02).\")\n",
695
+ "\n",
696
+ " # Tie weights AFTER all initialization is complete.\n",
697
+ " model.final_linear.weight = model.embedding.weight\n",
698
+ "\n",
699
+ " model.to(device)\n",
700
+ " logger.info(f\"Model is ready on {device}.\")\n",
701
+ "\n",
702
+ " # 4. SETUP OPTIMIZER, SCHEDULER, AND SCALER\n",
703
+ " optimizer = torch.optim.AdamW(model.parameters(), lr=PEAK_LEARNING_RATE, betas=(0.9, 0.98),\n",
704
+ " eps=1e-9, weight_decay=WEIGHT_DECAY)\n",
705
+ " scheduler = get_cosine_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=WARMUP_STEPS,\n",
706
+ " num_training_steps=TARGET_TRAINING_STEPS) # Use total steps\n",
707
+ " scaler = torch.cuda.amp.GradScaler()\n",
708
+ "\n",
709
+ " # 5. TRAINING LOOP\n",
710
+ " model.train()\n",
711
+ " global_step = 0 # Renamed from global_step_this_iteration\n",
712
+ " best_bleu = 0.0 # Renamed from best_bleu_this_iteration\n",
713
+ " LAST_CHECKPOINT_PATH = os.path.join(SAVE_DIR, \"last.pt\")\n",
714
+ " BEST_CHECKPOINT_PATH = os.path.join(SAVE_DIR, \"best.pt\")\n",
715
+ "\n",
716
+ " # Simplified progress bar\n",
717
+ " progress_bar = tqdm(total=TARGET_TRAINING_STEPS, desc=\"Total Progress\")\n",
718
+ " training_complete = False\n",
719
+ "\n",
720
+ " for epoch in range(200): # This can be a large number, the step check will stop it\n",
721
+ " if training_complete: break\n",
722
+ "\n",
723
+ " # --- Simplified generator seed ---\n",
724
+ " train_dataloader.generator.manual_seed(SEED + epoch)\n",
725
+ "\n",
726
+ " for batch in train_dataloader:\n",
727
+ " if global_step >= TARGET_TRAINING_STEPS: # Check against total steps\n",
728
+ " training_complete = True\n",
729
+ " break\n",
730
+ "\n",
731
+ " optimizer.zero_grad(set_to_none=True)\n",
732
+ " input_ids = batch['input_ids'].to(device, non_blocking=True)\n",
733
+ " labels = batch['labels'].to(device, non_blocking=True)\n",
734
+ " decoder_start_token = torch.full((labels.shape[0], 1), tokenizer.pad_token_id, dtype=torch.long, device=device)\n",
735
+ " decoder_input_ids = torch.cat([decoder_start_token, labels[:, :-1]], dim=1)\n",
736
+ " decoder_input_ids[decoder_input_ids == -100] = tokenizer.pad_token_id\n",
737
+ " target_labels = labels\n",
738
+ "\n",
739
+ " src_padding_mask, tgt_padding_mask, mem_key_padding_mask, tgt_mask = model.create_masks(input_ids, decoder_input_ids)\n",
740
+ " tgt_padding_mask[:, 0] = False\n",
741
+ "\n",
742
+ " with torch.autocast(device_type=\"cuda\", dtype=torch.float16):\n",
743
+ " model_outputs = model(src=input_ids, tgt=decoder_input_ids, src_padding_mask=src_padding_mask,\n",
744
+ " tgt_padding_mask=tgt_padding_mask, memory_key_padding_mask=mem_key_padding_mask,\n",
745
+ " tgt_mask=tgt_mask)\n",
746
+ " loss, loss_components = calculate_combined_loss(model_outputs, target_labels)\n",
747
+ "\n",
748
+ " scaler.scale(loss).backward()\n",
749
+ " scaler.unscale_(optimizer)\n",
750
+ " total_grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
751
+ " scaler.step(optimizer)\n",
752
+ " scaler.update()\n",
753
+ " scheduler.step()\n",
754
+ " global_step += 1 # Use main global_step\n",
755
+ " progress_bar.update(1)\n",
756
+ " lr = scheduler.get_last_lr()[0]\n",
757
+ "\n",
758
+ " if global_step % 20 == 0:\n",
759
+ " writer.add_scalar('train/loss', loss.item(), global_step)\n",
760
+ " writer.add_scalar('train/learning_rate', lr, global_step)\n",
761
+ " writer.add_scalar('train/gradient_norm', total_grad_norm.item(), global_step)\n",
762
+ " progress_bar.set_postfix(loss=loss.item(), grad_norm=f\"{total_grad_norm.item():.2f}\", lr=f\"{lr:.2e}\")\n",
763
+ "\n",
764
+ " if global_step in VALIDATION_SCHEDULE:\n",
765
+ " # --- Simplified logging message ---\n",
766
+ " logger.info(f\"\\n--- Validation at Step {global_step} ---\")\n",
767
+ " bleu_score = evaluate(model, val_dataloader, device)\n",
768
+ " writer.add_scalar('validation/bleu', bleu_score, global_step)\n",
769
+ " logger.info(f\"Validation BLEU: {bleu_score:.4f} (Best: {best_bleu:.4f})\")\n",
770
+ " generate_sample_translations(model, device, sample_sentences_de_for_tracking)\n",
771
+ "\n",
772
+ " if bleu_score > best_bleu:\n",
773
+ " best_bleu = bleu_score\n",
774
+ " logger.info(f\" New best BLEU! Saving best model...\")\n",
775
+ " torch.save(model.state_dict(), BEST_CHECKPOINT_PATH)\n",
776
+ "\n",
777
+ " model.train()\n",
778
+ "\n",
779
+ " progress_bar.close()\n",
780
+ " writer.close()\n",
781
+ " logger.info(f\"--- Training finished after {global_step} steps ---\")\n",
782
+ "\n",
783
+ " # --- 6. SAVE FINAL STATE ---\n",
784
+ " torch.save({'global_step': global_step, 'model_state_dict': model.state_dict()},\n",
785
+ " LAST_CHECKPOINT_PATH)\n",
786
+ " logger.info(f\"Saved final state to: {LAST_CHECKPOINT_PATH}\")\n",
787
+ "\n",
788
+ " # --- Removed the previous_iteration_checkpoint_path line ---\n",
789
+ "\n",
790
+ " # --- 7. CREATE DIGITAL FINGERPRINTS ---\n",
791
+ " logger.info(\"--- Creating digital fingerprints for key artifacts ---\")\n",
792
+ " files_to_hash = {\n",
793
+ " \"Last Model\": LAST_CHECKPOINT_PATH,\n",
794
+ " \"Best Model\": BEST_CHECKPOINT_PATH,\n",
795
+ " \"Text Log\": LOG_FILE_TXT,\n",
796
+ " }\n",
797
+ "\n",
798
+ " try:\n",
799
+ " tb_log_file = [f for f in os.listdir(LOG_DIR_TENSORBOARD) if 'tfevents' in f][0]\n",
800
+ " files_to_hash[\"TensorBoard Log\"] = os.path.join(LOG_DIR_TENSORBOARD, tb_log_file)\n",
801
+ " except IndexError:\n",
802
+ " logger.warning(\"Could not find TensorBoard events file to hash.\")\n",
803
+ "\n",
804
+ " checksum_file_path = os.path.join(CURRENT_RUN_DIR, \"checksums.sha256\")\n",
805
+ " with open(checksum_file_path, \"w\") as f:\n",
806
+ " # --- Simplified checksums file ---\n",
807
+ " f.write(f\"SHA256 Checksums for run: {experiment_name}\\n\")\n",
808
+ " f.write(\"=\"*50 + \"\\n\")\n",
809
+ " for name, path in files_to_hash.items():\n",
810
+ " if path and os.path.exists(path):\n",
811
+ " file_hash = get_file_hash(path)\n",
812
+ " if file_hash:\n",
813
+ " log_message = f\" - {name} ({os.path.basename(path)}): {file_hash}\"\n",
814
+ " logger.info(log_message)\n",
815
+ " f.write(f\"{file_hash} {os.path.basename(path)}\\n\")\n",
816
+ " else:\n",
817
+ " logger.warning(f\" - Skipped hashing for '{name}', file not found: {path}\")\n",
818
+ "\n",
819
+ " logger.info(f\"Checksums saved to {checksum_file_path}\")\n",
820
+ "\n",
821
+ " print(\"\\n\\n\" + \"*\"*80)\n",
822
+ " print(\" EXPERIMENT COMPLETE \")\n",
823
+ " print(\"*\"*80)"
824
+ ]
825
+ },
826
+ {
827
+ "cell_type": "code",
828
+ "execution_count": null,
829
+ "metadata": {
830
+ "id": "tqDiOyy18clU"
831
+ },
832
+ "outputs": [],
833
+ "source": [
834
+ "# TENSORBOARD VISUALIZATION\n",
835
+ "\n",
836
+ "%load_ext tensorboard\n",
837
+ "\n",
838
+ "TENSORBOARD_BASE_DIR = os.path.join(DRIVE_BASE_PATH)\n",
839
+ "\n",
840
+ "%tensorboard --logdir \"{TENSORBOARD_BASE_DIR}\""
841
+ ]
842
+ },
843
+ {
844
+ "cell_type": "markdown",
845
+ "metadata": {
846
+ "id": "eI0-qVlWVVpx"
847
+ },
848
+ "source": [
849
+ "## End"
850
+ ]
851
+ }
852
+ ],
853
+ "metadata": {
854
+ "accelerator": "GPU",
855
+ "colab": {
856
+ "gpuType": "A100",
857
+ "provenance": [],
858
+ "collapsed_sections": [
859
+ "cS4JvJGRhClv"
860
+ ]
861
+ },
862
+ "kernelspec": {
863
+ "display_name": "Python 3",
864
+ "name": "python3"
865
+ },
866
+ "language_info": {
867
+ "name": "python"
868
+ }
869
+ },
870
+ "nbformat": 4,
871
+ "nbformat_minor": 0
872
+ }