prism-lab committed on
Commit
d6e8aa8
·
verified ·
1 Parent(s): 3cdf87b

Upload 2 files


FNet and PRISM WMT14 training scripts.

FNet_Train_Last.ipynb ADDED
@@ -0,0 +1,1286 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {
7
+ "collapsed": true,
8
+ "id": "2s48Vmoo9EB5"
9
+ },
10
+ "outputs": [],
11
+ "source": [
12
+ "!pip install -q torchmetrics sacrebleu x-transformers"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "markdown",
17
+ "metadata": {
18
+ "id": "Lz8buKsjvA_w"
19
+ },
20
+ "source": [
21
+ "## CONFIG"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": null,
27
+ "metadata": {
28
+ "id": "df355sdDrNSb"
29
+ },
30
+ "outputs": [],
31
+ "source": [
32
+ "!pip install -q torchmetrics sacrebleu x-transformers\n",
33
+ "\n",
34
+ "## CONFIG\n",
35
+ "\n",
36
+ "# --- Data & Task Size ---\n",
37
+ "MAX_LENGTH = 128\n",
38
+ "\n",
39
+ "MODEL_CHOICE = \"Name_Your_Model\" # Renamed for clarity\n",
40
+ "\n",
41
+ "# --- Model Architecture Config ---\n",
42
+ "D_MODEL = 512\n",
43
+ "NUM_HEADS = 8\n",
44
+ "D_FF = 2048\n",
45
+ "DROPOUT = 0.1\n",
46
+ "\n",
47
+ "# --- Layer counts ---\n",
48
+ "NUM_ENCODER_LAYERS = 7\n",
49
+ "NUM_DECODER_LAYERS = 6\n",
50
+ "\n",
51
+ "# --- Training Config (ADJUSTED FOR FAIR COMPARISON) ---\n",
52
+ "\n",
53
+ "TARGET_TRAINING_STEPS = 100000\n",
54
+ "GRAD_ACCUMULATION_STEPS = 2\n",
55
+ "\n",
56
+ "\n",
57
+ "VALIDATION_SCHEDULE = [\n",
58
+ " 2000, 4000, 5000, 7500, 10000, 15000, 20000,\n",
59
+ " 25000, 30000, 35000, 42500, 50000, 57500, 65000, 72500, 90000, 100000\n",
60
+ "]\n",
61
+ "PEAK_LEARNING_RATE = 6e-4\n",
62
+ "WARMUP_STEPS = 600 # Warmup can stay similar or scale slightly, 600 is fine\n",
63
+ "WEIGHT_DECAY = 0.01\n",
64
+ "\n",
65
+ "# --- Regularization Config ---\n",
66
+ "LABEL_SMOOTHING_EPSILON = 0.1\n",
67
+ "\n",
68
+ "# --- Other Constants ---\n",
69
+ "DRIVE_BASE_PATH = \"/content/drive/MyDrive/AIAYN\"\n",
70
+ "ORIGINAL_BUCKETED_REPO_ID = \"prism-lab/wmt14-de-en-bucketed-w4\" # Use the bucketed one (we will ignore buckets)\n",
71
+ "MODEL_CHECKPOINT = \"Helsinki-NLP/opus-mt-de-en\""
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "markdown",
76
+ "metadata": {
77
+ "id": "W5l1HHRFXxPA"
78
+ },
79
+ "source": [
80
+ "## DATALOADERS"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": null,
86
+ "metadata": {
87
+ "collapsed": true,
88
+ "id": "FA5SqFzeMrpK"
89
+ },
90
+ "outputs": [],
91
+ "source": [
92
+ "import torch\n",
93
+ "import torch.nn as nn\n",
94
+ "from torch.utils.data import DataLoader\n",
95
+ "from transformers import AutoTokenizer\n",
96
+ "from datasets import load_dataset\n",
97
+ "import math\n",
98
+ "import os\n",
99
+ "from tqdm.auto import tqdm\n",
100
+ "from torch.utils.tensorboard import SummaryWriter\n",
101
+ "import random\n",
102
+ "import numpy as np\n",
104
+ "from transformers import get_cosine_schedule_with_warmup\n",
105
+ "from typing import List\n",
107
+ "from transformers import DataCollatorForSeq2Seq\n",
108
+ "\n",
109
+ "\n",
110
+ "def set_seed(seed_value=5):\n",
111
+ " \"\"\"Sets the seed for reproducibility.\"\"\"\n",
112
+ " random.seed(seed_value)\n",
113
+ " np.random.seed(seed_value)\n",
114
+ " torch.manual_seed(seed_value)\n",
115
+ " torch.cuda.manual_seed_all(seed_value)\n",
116
+ " torch.backends.cudnn.deterministic = True\n",
117
+ " torch.backends.cudnn.benchmark = False\n",
118
+ "\n",
119
+ "SEED = 117\n",
120
+ "set_seed(SEED)\n",
121
+ "print(f\"Reproducibility seed set to {SEED}\")\n",
122
+ "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n",
123
+ "\n",
124
+ "#torch.use_deterministic_algorithms(True)\n",
125
+ "\n",
126
+ "print(\"--- Loading Modernized Configuration ---\")\n",
127
+ "def seed_worker(worker_id):\n",
128
+ " worker_seed = torch.initial_seed() % 2**32\n",
129
+ " np.random.seed(worker_seed)\n",
130
+ " random.seed(worker_seed)\n",
131
+ "\n",
132
+ "torch.set_float32_matmul_precision('high')\n",
133
+ "print(\"✅ PyTorch matmul precision set to 'high'\")\n",
134
+ "\n",
135
+ "# --- Device Setup ---\n",
136
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
137
+ "print(f\"Using device: {device}\")\n",
138
+ "\n",
139
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)\n",
140
+ "\n",
141
+ "VOCAB_SIZE = len(tokenizer)\n",
142
+ "print(f\"Vocab size: {VOCAB_SIZE}\")\n",
143
+ "\n",
144
+ "\n",
145
+ "# DATA LOADING & PREPARATION\n",
146
+ "\n",
147
+ "# --- 1. DEFINE THE FNET COLLATOR (FORCE FIXED LENGTH) ---\n",
148
+ "# This is crucial. It forces every sentence to be exactly 128 tokens.\n",
149
+ "fnet_collator = DataCollatorForSeq2Seq(\n",
150
+ " tokenizer=tokenizer,\n",
151
+ " padding=\"max_length\", # <--- FORCE PADDING\n",
152
+ " max_length=MAX_LENGTH, # <--- 128 (defined in your config)\n",
153
+ " pad_to_multiple_of=None\n",
154
+ ")\n",
155
+ "\n",
156
+ "# --- 2. LOAD DATASET ---\n",
157
+ "print(f\"Loading original bucketed samples from: {ORIGINAL_BUCKETED_REPO_ID}\")\n",
158
+ "original_datasets = load_dataset(ORIGINAL_BUCKETED_REPO_ID)\n",
159
+ "\n",
160
+ "# --- 3. CREATE DATALOADERS (STANDARD FIXED SIZE) ---\n",
161
+ "FNET_PHYSICAL_BATCH_SIZE = 320\n",
162
+ "\n",
163
+ "g = torch.Generator()\n",
164
+ "g.manual_seed(SEED)\n",
165
+ "\n",
166
+ "train_dataloader = DataLoader(\n",
167
+ " original_datasets[\"train\"],\n",
168
+ " batch_size=FNET_PHYSICAL_BATCH_SIZE, # <--- FIXED BATCH SIZE (Safe from OOM)\n",
169
+ " shuffle=True, # <--- GLOBAL SHUFFLE\n",
170
+ " num_workers=8,\n",
171
+ " collate_fn=fnet_collator,\n",
172
+ " pin_memory=True,\n",
173
+ " worker_init_fn=seed_worker,\n",
174
+ " generator=g,\n",
175
+ ")\n",
176
+ "\n",
177
+ "val_dataloader = DataLoader(\n",
178
+ " original_datasets[\"validation\"],\n",
179
+ " batch_size=FNET_PHYSICAL_BATCH_SIZE,\n",
180
+ " collate_fn=fnet_collator,\n",
181
+ " num_workers=8,\n",
182
+ " pin_memory=True,\n",
183
+ " worker_init_fn=seed_worker,\n",
184
+ " generator=g,\n",
185
+ ")\n",
186
+ "\n",
187
+ "print(f\"Train Dataloader is now a STANDARD iterator.\")\n",
188
+ "print(f\"Physical Batch Size: {FNET_PHYSICAL_BATCH_SIZE}\")\n",
189
+ "print(f\"Gradient Accumulation: {GRAD_ACCUMULATION_STEPS}\")\n",
190
+ "print(f\"Effective Batch Size: {FNET_PHYSICAL_BATCH_SIZE * GRAD_ACCUMULATION_STEPS}\")\n",
191
+ "\n",
192
+ "# --- SANITY CHECK ---\n",
193
+ "print(\"\\n--- Running Sanity Check on new FNet DataLoader ---\")\n",
194
+ "train_dataloader.generator.manual_seed(SEED)\n",
195
+ "temp_iterator = iter(train_dataloader)\n",
196
+ "print(\"Shapes of first 3 batches (Should all be [64, 128]):\")\n",
197
+ "for i in range(3):\n",
198
+ " batch = next(temp_iterator)\n",
199
+ " print(f\" Batch {i+1}: input_ids shape = {batch['input_ids'].shape}\")\n",
200
+ "print(\"--- Sanity Check Complete ---\\n\")\n",
201
+ "# --- VERIFY SHUFFLE IS WORKING ---\n",
202
+ "print(\"🕵️ INSPECTING ONE BATCH 🕵️\")\n",
203
+ "\n",
204
+ "# Get one batch from your active train_dataloader\n",
205
+ "batch = next(iter(train_dataloader))\n",
206
+ "input_ids = batch['input_ids']\n",
207
+ "\n",
208
+ "# Calculate real lengths (ignoring padding)\n",
209
+ "# We count how many tokens are NOT the pad token (usually 0 or 58100)\n",
210
+ "real_lengths = (input_ids != tokenizer.pad_token_id).sum(dim=1)\n",
211
+ "\n",
212
+ "print(f\"Batch Shape: {input_ids.shape}\")\n",
213
+ "print(\"Random Sample of 20 lengths in this batch:\")\n",
214
+ "print(real_lengths[:20].tolist())\n",
215
+ "\n",
216
+ "# Check diversity\n",
217
+ "if real_lengths.float().std() < 5:\n",
218
+ " print(\"\\n⚠️ WARNING: LENGTHS LOOK CLUSTERED! (Bad shuffling)\")\n",
219
+ "else:\n",
220
+ " print(f\"\\n✅ PASSED: Lengths are highly variable (Std Dev: {real_lengths.float().std():.2f}). Shuffling is working.\")"
221
+ ]
222
+ },
223
+ {
224
+ "cell_type": "markdown",
225
+ "metadata": {
226
+ "id": "cS4JvJGRhClv"
227
+ },
228
+ "source": [
229
+ "## Models"
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "code",
234
+ "execution_count": null,
235
+ "metadata": {
236
+ "id": "SMhlM0YvO1A7"
237
+ },
238
+ "outputs": [],
239
+ "source": [
240
+ "import torch\n",
241
+ "import torch.nn as nn\n",
242
+ "import torch.nn.functional as F\n",
243
+ "import math\n",
244
+ "from x_transformers import Encoder, Decoder\n",
245
+ "\n",
246
+ "class RoPETransformer(nn.Module):\n",
247
+ " def __init__(self, num_encoder_layers, num_decoder_layers, num_heads, d_model, dff, vocab_size, max_length, dropout):\n",
248
+ " super().__init__()\n",
249
+ " self.d_model = d_model\n",
250
+ " self.embedding = nn.Embedding(vocab_size, d_model)\n",
251
+ "\n",
252
+ " # We REMOVE self.pos_encoder (RoPE handles position internally)\n",
253
+ " self.dropout_layer = nn.Dropout(dropout)\n",
254
+ "\n",
255
+ " # --- x-transformers Encoder ---\n",
256
+ " self.encoder = Encoder(\n",
257
+ " dim = d_model,\n",
258
+ " depth = num_encoder_layers,\n",
259
+ " heads = num_heads,\n",
260
+ " attn_dim_head = d_model // num_heads,\n",
261
+ " ff_mult = dff / d_model,\n",
262
+ " rotary_pos_emb = True,\n",
263
+ " attn_flash = True,\n",
264
+ " attn_dropout = dropout,\n",
265
+ " ff_dropout = dropout,\n",
266
+ " use_rmsnorm = True\n",
267
+ " )\n",
268
+ "\n",
269
+ " # --- x-transformers Decoder ---\n",
270
+ " self.decoder = Decoder(\n",
271
+ " dim = d_model,\n",
272
+ " depth = num_decoder_layers,\n",
273
+ " heads = num_heads,\n",
274
+ " attn_dim_head = d_model // num_heads,\n",
275
+ " ff_mult = dff / d_model,\n",
276
+ " rotary_pos_emb = True,\n",
277
+ " cross_attend = True,\n",
278
+ " attn_flash = True,\n",
279
+ " attn_dropout = dropout,\n",
280
+ " ff_dropout = dropout,\n",
281
+ " use_rmsnorm = True\n",
282
+ " )\n",
283
+ "\n",
284
+ " self.final_linear = nn.Linear(d_model, vocab_size)\n",
285
+ " self.final_linear.weight = self.embedding.weight\n",
286
+ "\n",
287
+ " def forward(self, src, tgt, src_padding_mask, tgt_padding_mask, memory_key_padding_mask, tgt_mask):\n",
288
+ " # 1. Embeddings (No Absolute Positional Encoding added!)\n",
289
+ " src_emb = self.embedding(src) * math.sqrt(self.d_model)\n",
290
+ " src_emb = self.dropout_layer(src_emb)\n",
291
+ "\n",
292
+ " tgt_emb = self.embedding(tgt) * math.sqrt(self.d_model)\n",
293
+ " tgt_emb = self.dropout_layer(tgt_emb)\n",
294
+ "\n",
295
+ " # 2. Mask Conversion\n",
296
+ " # User provides True=PAD. x-transformers wants True=KEEP.\n",
297
+ " # We invert the boolean mask using ~\n",
298
+ " enc_mask = ~src_padding_mask if src_padding_mask is not None else None\n",
299
+ " dec_mask = ~tgt_padding_mask if tgt_padding_mask is not None else None\n",
300
+ "\n",
301
+ " # Note: 'tgt_mask' (causal mask) is handled automatically by x-transformers Decoder!\n",
302
+ " # We do NOT pass the square causal mask manually.\n",
303
+ "\n",
304
+ " # 3. Encoder\n",
305
+ " # x-transformers takes embeddings directly\n",
306
+ " memory = self.encoder(src_emb, mask=enc_mask)\n",
307
+ "\n",
308
+ " # 4. Decoder\n",
309
+ " # context = memory (from encoder)\n",
310
+ " # context_mask = mask for memory (encoder mask)\n",
311
+ " decoder_output = self.decoder(\n",
312
+ " tgt_emb,\n",
313
+ " context=memory,\n",
314
+ " mask=dec_mask,\n",
315
+ " context_mask=enc_mask\n",
316
+ " )\n",
317
+ "\n",
318
+ " return self.final_linear(decoder_output)\n",
319
+ "\n",
320
+ " # Keep your existing create_masks (used for Data Processing mostly)\n",
321
+ " def create_masks(self, src, tgt):\n",
322
+ " src_padding_mask = (src == tokenizer.pad_token_id)\n",
323
+ " tgt_padding_mask = (tgt == tokenizer.pad_token_id)\n",
324
+ " # We still generate this for compatibility, though x-transformers handles causality internally\n",
325
+ " tgt_mask = nn.Transformer.generate_square_subsequent_mask(\n",
326
+ " sz=tgt.size(1), device=src.device, dtype=torch.bool\n",
327
+ " )\n",
328
+ " return src_padding_mask, tgt_padding_mask, src_padding_mask, tgt_mask\n",
329
+ "\n",
330
+ " @torch.no_grad()\n",
331
+ " def generate(self, src: torch.Tensor, max_length: int, num_beams: int = 5) -> torch.Tensor:\n",
332
+ " self.eval()\n",
333
+ " # Create Mask (True=PAD)\n",
334
+ " src_padding_mask = (src == tokenizer.pad_token_id)\n",
335
+ " # Invert for x-transformers (True=KEEP)\n",
336
+ " enc_mask = ~src_padding_mask\n",
337
+ "\n",
338
+ " # Encode\n",
339
+ " src_emb = self.embedding(src) * math.sqrt(self.d_model)\n",
340
+ " # No Pos Encoder\n",
341
+ " memory = self.encoder(self.dropout_layer(src_emb), mask=enc_mask)\n",
342
+ "\n",
343
+ " batch_size = src.shape[0]\n",
344
+ " # Expand for beams\n",
345
+ " memory = memory.repeat_interleave(num_beams, dim=0)\n",
346
+ " enc_mask = enc_mask.repeat_interleave(num_beams, dim=0)\n",
347
+ "\n",
348
+ " initial_token = tokenizer.pad_token_id\n",
349
+ " beams = torch.full((batch_size * num_beams, 1), initial_token, dtype=torch.long, device=src.device)\n",
350
+ " beam_scores = torch.zeros(batch_size * num_beams, device=src.device)\n",
351
+ " finished_beams = torch.zeros(batch_size * num_beams, dtype=torch.bool, device=src.device)\n",
352
+ "\n",
353
+ " for _ in range(max_length - 1):\n",
354
+ " if finished_beams.all(): break\n",
355
+ "\n",
356
+ " # Embed beams\n",
357
+ " tgt_emb = self.embedding(beams) * math.sqrt(self.d_model)\n",
358
+ " # No Pos Encoder\n",
359
+ "\n",
360
+ " # Decode\n",
361
+ " # x-transformers automatically handles the causal masking for the sequence length of tgt_emb\n",
362
+ " decoder_output = self.decoder(\n",
363
+ " self.dropout_layer(tgt_emb),\n",
364
+ " context=memory,\n",
365
+ " context_mask=enc_mask\n",
366
+ " )\n",
367
+ "\n",
368
+ " logits = self.final_linear(decoder_output[:, -1, :])\n",
369
+ " log_probs = F.log_softmax(logits, dim=-1)\n",
370
+ "\n",
371
+ " # ... (Rest of your Beam Search Logic remains identical) ...\n",
372
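+ " # Never generate pad; finished beams extend with EOS at zero cost so their scores freeze.\n",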
+ " log_probs[:, tokenizer.pad_token_id] = -torch.inf\n",
373
+ " if finished_beams.any(): log_probs[finished_beams, tokenizer.eos_token_id] = 0\n",
374
+ "\n",
375
+ " total_scores = beam_scores.unsqueeze(1) + log_probs\n",
376
+ " if _ == 0:\n",
377
+ " total_scores = total_scores.view(batch_size, num_beams, -1)\n",
378
+ " total_scores[:, 1:, :] = -torch.inf\n",
379
+ " total_scores = total_scores.view(batch_size * num_beams, -1)\n",
380
+ " else:\n",
381
+ " total_scores = beam_scores.unsqueeze(1) + log_probs\n",
382
+ "\n",
383
+ " total_scores = total_scores.view(batch_size, -1)\n",
384
+ " top_scores, top_indices = torch.topk(total_scores, k=num_beams, dim=1)\n",
385
+ "\n",
386
+ " beam_indices = top_indices // log_probs.shape[-1]\n",
387
+ " token_indices = top_indices % log_probs.shape[-1]\n",
388
+ "\n",
389
+ " batch_indices = torch.arange(batch_size, device=src.device).unsqueeze(1)\n",
390
+ " effective_indices = (batch_indices * num_beams + beam_indices).view(-1)\n",
391
+ "\n",
392
+ " beams = beams[effective_indices]\n",
393
+ " beams = torch.cat([beams, token_indices.view(-1, 1)], dim=1)\n",
394
+ " beam_scores = top_scores.view(-1)\n",
395
+ " finished_beams = finished_beams | (beams[:, -1] == tokenizer.eos_token_id)\n",
396
+ "\n",
397
+ " final_beams = beams.view(batch_size, num_beams, -1)\n",
398
+ " final_scores = beam_scores.view(batch_size, num_beams)\n",
399
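+ " # Length-normalize: divide each beam's cumulative log-prob by its token count so longer hypotheses aren't penalized.\n",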
+ " normalized_scores = final_scores / (final_beams != tokenizer.pad_token_id).sum(-1).float().clamp(min=1)\n",
400
+ " best_beams = final_beams[torch.arange(batch_size), normalized_scores.argmax(1), :]\n",
401
+ " self.train()\n",
402
+ " return best_beams\n",
403
+ "\n",
404
+ "class RMSNorm(nn.Module):\n",
405
+ " def __init__(self, dim, eps=1e-8):\n",
406
+ " super().__init__()\n",
407
+ " self.eps = eps\n",
408
+ " self.gamma = nn.Parameter(torch.ones(dim))\n",
409
+ "\n",
410
+ " def forward(self, x):\n",
411
+ " # 1. Calculate the mean of the squares\n",
412
+ " mean_square = x.pow(2).mean(dim=-1, keepdim=True)\n",
413
+ "\n",
414
+ " # 2. Calculate the inverse square root (1 / RMS)\n",
415
+ " # We add eps before the sqrt for stability\n",
416
+ " inv_rms = torch.rsqrt(mean_square + self.eps)\n",
417
+ "\n",
418
+ " # 3. Normalize and scale\n",
419
+ " return x * inv_rms * self.gamma\n",
420
+ "\n",
421
+ "\n",
422
+ "class FNetBlock(nn.Module):\n",
423
+ " def __init__(self, d_model, d_ff, dropout):\n",
424
+ " super().__init__()\n",
425
+ " self.norm_mix = nn.LayerNorm(d_model) # LayerNorm is safer for FNet than RMSNorm\n",
426
+ " self.norm_ff = nn.LayerNorm(d_model)\n",
427
+ "\n",
428
+ " self.ff = nn.Sequential(\n",
429
+ " nn.Linear(d_model, d_ff),\n",
430
+ " nn.GELU(),\n",
431
+ " nn.Dropout(dropout),\n",
432
+ " nn.Linear(d_ff, d_model),\n",
433
+ " nn.Dropout(dropout)\n",
434
+ " )\n",
435
+ "\n",
436
+ " def forward(self, x):\n",
437
+ " # 1. Fourier Mixing Branch\n",
438
+ " residual = x\n",
439
+ " x = self.norm_mix(x)\n",
440
+ "\n",
441
+ " # --- THE FIX ---\n",
442
+ " with torch.cuda.amp.autocast(enabled=False):\n",
443
+ " x = x.float()\n",
444
+ " # norm='ortho' makes the FFT energy-preserving.\n",
445
+ " # Output magnitude will match input magnitude (~1).\n",
446
+ " x = torch.fft.fftn(x, dim=(-2, -1), norm='ortho').real\n",
447
+ " x = x.to(dtype=residual.dtype)\n",
448
+ " # ---------------\n",
449
+ "\n",
450
+ " # Now 'x' and 'residual' have roughly same magnitude.\n",
451
+ " # The skip connection works again.\n",
452
+ " x = x + residual\n",
453
+ "\n",
454
+ " # 2. Feed Forward Branch\n",
455
+ " residual = x\n",
456
+ " x = self.norm_ff(x)\n",
457
+ " x = self.ff(x)\n",
458
+ " return x + residual\n",
459
+ "\n",
460
+ "\n",
461
+ "class FNetEncoder(nn.Module):\n",
462
+ " def __init__(self, depth, d_model, d_ff, dropout):\n",
463
+ " super().__init__()\n",
464
+ " self.layers = nn.ModuleList([\n",
465
+ " FNetBlock(d_model, d_ff, dropout) for _ in range(depth)\n",
466
+ " ])\n",
467
+ " # [FIX] Use LayerNorm here to match the blocks\n",
468
+ " self.norm_out = nn.LayerNorm(d_model)\n",
469
+ "\n",
470
+ " def forward(self, x):\n",
471
+ " for layer in self.layers:\n",
472
+ " x = layer(x)\n",
473
+ " return self.norm_out(x)\n",
474
+ "\n",
475
+ "# --- Main Hybrid Model ---\n",
476
+ "\n",
477
+ "class FNetHybridTransformer(nn.Module):\n",
478
+ " def __init__(self, num_encoder_layers, num_decoder_layers, num_heads, d_model, dff, vocab_size, max_length, dropout):\n",
479
+ " super().__init__()\n",
480
+ " self.d_model = d_model\n",
481
+ "\n",
482
+ " # Shared Embeddings\n",
483
+ " # padding_idx=tokenizer.pad_token_id forces the vector at this index to be strict ZEROS.\n",
484
+ " # It does not have gradients, it stays zero forever.\n",
485
+ " self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=tokenizer.pad_token_id)\n",
486
+ "\n",
487
+ " # FNet REQUIRES Absolute Positional Embeddings because FFT mixes information\n",
488
+ " # but doesn't inherently understand sequence order like RoPE/RNNs do initially.\n",
489
+ " self.pos_embedding = nn.Embedding(max_length, d_model)\n",
490
+ "\n",
491
+ " self.dropout_layer = nn.Dropout(dropout)\n",
492
+ "\n",
493
+ " # --- Custom FNet Encoder ---\n",
494
+ " self.encoder = FNetEncoder(\n",
495
+ " depth=num_encoder_layers,\n",
496
+ " d_model=d_model,\n",
497
+ " d_ff=dff,\n",
498
+ " dropout=dropout\n",
499
+ " )\n",
500
+ "\n",
501
+ " # --- x-transformers Decoder (Retains RoPE) ---\n",
502
+ " self.decoder = Decoder(\n",
503
+ " dim=d_model,\n",
504
+ " depth=num_decoder_layers,\n",
505
+ " heads=num_heads,\n",
506
+ " attn_dim_head=d_model // num_heads,\n",
507
+ " ff_mult=dff / d_model,\n",
508
+ " rotary_pos_emb=True, # Decoder still uses RoPE\n",
509
+ " cross_attend=True,\n",
510
+ " attn_flash=True,\n",
511
+ " attn_dropout=dropout,\n",
512
+ " ff_dropout=dropout,\n",
513
+ " use_rmsnorm=True\n",
514
+ " )\n",
515
+ "\n",
516
+ " self.final_linear = nn.Linear(d_model, vocab_size)\n",
517
+ " self.final_linear.weight = self.embedding.weight\n",
518
+ "\n",
519
+ " def forward(self, src, tgt, src_padding_mask, tgt_padding_mask, memory_key_padding_mask, tgt_mask):\n",
520
+ " # 1. Embeddings\n",
521
+ " # Source (Encoder) gets Absolute Positional Embeddings\n",
522
+ " B, L_src = src.shape\n",
523
+ " pos_ids = torch.arange(L_src, device=src.device).unsqueeze(0)\n",
524
+ " src_emb = self.embedding(src) * math.sqrt(self.d_model)\n",
525
+ " src_emb = src_emb + self.pos_embedding(pos_ids)\n",
526
+ " src_emb = self.dropout_layer(src_emb)\n",
527
+ "\n",
528
+ " # Target (Decoder) gets NO Positional Embeddings here (RoPE handles it inside Decoder)\n",
529
+ " tgt_emb = self.embedding(tgt) * math.sqrt(self.d_model)\n",
530
+ " tgt_emb = self.dropout_layer(tgt_emb)\n",
531
+ "\n",
532
+ " # 2. Prepare Masks\n",
533
+ " # x-transformers requires True = Keep, False = Mask\n",
534
+ " # Your dataloader provides True = Pad\n",
535
+ " enc_mask = ~src_padding_mask if src_padding_mask is not None else None\n",
536
+ " dec_mask = ~tgt_padding_mask if tgt_padding_mask is not None else None\n",
537
+ "\n",
538
+ " # 3. FNet Encoder\n",
539
+ " # Note: FNet mixes ALL tokens (including padding).\n",
540
+ " memory = self.encoder(src_emb)\n",
541
+ "\n",
542
+ " # CRITICAL: Zero out padding positions in encoder output so Decoder doesn't attend to them.\n",
543
+ " if src_padding_mask is not None:\n",
544
+ " memory = memory.masked_fill(src_padding_mask.unsqueeze(-1), 0.0)\n",
545
+ "\n",
546
+ " # 4. RoPE Decoder\n",
547
+ " # The decoder uses RoPE for self-attention on 'tgt',\n",
548
+ " # and standard cross-attention to 'memory' (FNet output).\n",
549
+ " decoder_output = self.decoder(\n",
550
+ " tgt_emb,\n",
551
+ " context=memory,\n",
552
+ " mask=dec_mask,\n",
553
+ " context_mask=enc_mask\n",
554
+ " )\n",
555
+ "\n",
556
+ " return self.final_linear(decoder_output)\n",
557
+ "\n",
558
+ " def create_masks(self, src, tgt):\n",
559
+ " # Standard mask creation (Same as your original)\n",
560
+ " src_padding_mask = (src == tokenizer.pad_token_id)\n",
561
+ " tgt_padding_mask = (tgt == tokenizer.pad_token_id)\n",
562
+ " tgt_mask = nn.Transformer.generate_square_subsequent_mask(\n",
563
+ " sz=tgt.size(1), device=src.device, dtype=torch.bool\n",
564
+ " )\n",
565
+ " return src_padding_mask, tgt_padding_mask, src_padding_mask, tgt_mask\n",
566
+ "\n",
567
+ " @torch.no_grad()\n",
568
+ " def generate(self, src: torch.Tensor, max_length: int, num_beams: int = 5) -> torch.Tensor:\n",
569
+ " self.eval()\n",
570
+ " B, L_src = src.shape\n",
571
+ "\n",
572
+ " # 1. Encode with FNet\n",
573
+ " pos_ids = torch.arange(L_src, device=src.device).unsqueeze(0)\n",
574
+ " src_emb = self.embedding(src) * math.sqrt(self.d_model)\n",
575
+ " src_emb = src_emb + self.pos_embedding(pos_ids)\n",
576
+ "\n",
577
+ " memory = self.encoder(self.dropout_layer(src_emb))\n",
578
+ "\n",
579
+ " # Masking padding in memory\n",
580
+ " src_padding_mask = (src == tokenizer.pad_token_id)\n",
581
+ " memory = memory.masked_fill(src_padding_mask.unsqueeze(-1), 0.0)\n",
582
+ "\n",
583
+ " # Prepare for Decoder (x-transformers style mask: True=Keep)\n",
584
+ " enc_mask = ~src_padding_mask\n",
585
+ "\n",
586
+ " # --- BEAM SEARCH SETUP ---\n",
587
+ " # Expand memory for beams\n",
588
+ " memory = memory.repeat_interleave(num_beams, dim=0)\n",
589
+ " enc_mask = enc_mask.repeat_interleave(num_beams, dim=0)\n",
590
+ "\n",
591
+ " initial_token = tokenizer.pad_token_id\n",
592
+ " beams = torch.full((B * num_beams, 1), initial_token, dtype=torch.long, device=src.device)\n",
593
+ " beam_scores = torch.zeros(B * num_beams, device=src.device)\n",
594
+ " finished_beams = torch.zeros(B * num_beams, dtype=torch.bool, device=src.device)\n",
595
+ "\n",
596
+ " for _ in range(max_length - 1):\n",
597
+ " if finished_beams.all(): break\n",
598
+ "\n",
599
+ " # Decoder Step (RoPE handled internally)\n",
600
+ " tgt_emb = self.embedding(beams) * math.sqrt(self.d_model)\n",
601
+ "\n",
602
+ " decoder_output = self.decoder(\n",
603
+ " self.dropout_layer(tgt_emb),\n",
604
+ " context=memory,\n",
605
+ " context_mask=enc_mask\n",
606
+ " )\n",
607
+ "\n",
608
+ " logits = self.final_linear(decoder_output[:, -1, :])\n",
609
+ " log_probs = F.log_softmax(logits, dim=-1)\n",
610
+ "\n",
611
+ " # --- STANDARD BEAM LOGIC (No changes needed here) ---\n",
612
+ " log_probs[:, tokenizer.pad_token_id] = -torch.inf\n",
613
+ " if finished_beams.any(): log_probs[finished_beams, tokenizer.eos_token_id] = 0\n",
614
+ "\n",
615
+ " total_scores = beam_scores.unsqueeze(1) + log_probs\n",
616
+ " if _ == 0:\n",
617
+ " total_scores = total_scores.view(B, num_beams, -1)\n",
618
+ " total_scores[:, 1:, :] = -torch.inf\n",
619
+ " total_scores = total_scores.view(B * num_beams, -1)\n",
620
+ " else:\n",
621
+ " total_scores = beam_scores.unsqueeze(1) + log_probs\n",
622
+ "\n",
623
+ " total_scores = total_scores.view(B, -1)\n",
624
+ " top_scores, top_indices = torch.topk(total_scores, k=num_beams, dim=1)\n",
625
+ "\n",
626
+ " beam_indices = top_indices // log_probs.shape[-1]\n",
627
+ " token_indices = top_indices % log_probs.shape[-1]\n",
628
+ "\n",
629
+ " batch_indices = torch.arange(B, device=src.device).unsqueeze(1)\n",
630
+ " effective_indices = (batch_indices * num_beams + beam_indices).view(-1)\n",
631
+ "\n",
632
+ " beams = beams[effective_indices]\n",
633
+ " beams = torch.cat([beams, token_indices.view(-1, 1)], dim=1)\n",
634
+ " beam_scores = top_scores.view(-1)\n",
635
+ " finished_beams = finished_beams | (beams[:, -1] == tokenizer.eos_token_id)\n",
636
+ "\n",
637
+ " final_beams = beams.view(B, num_beams, -1)\n",
638
+ " final_scores = beam_scores.view(B, num_beams)\n",
639
+ " normalized_scores = final_scores / (final_beams != tokenizer.pad_token_id).sum(-1).float().clamp(min=1)\n",
640
+ " best_beams = final_beams[torch.arange(B), normalized_scores.argmax(1), :]\n",
641
+ " self.train()\n",
642
+ " return best_beams"
643
+ ]
644
+ },
645
+ {
646
+ "cell_type": "code",
647
+ "source": [
648
+ "def count_parameters(model):\n",
650
+ " total_params = 0\n",
651
+ " trainable_params = 0\n",
652
+ "\n",
653
+ " # 1. Global Counts\n",
654
+ " for p in model.parameters():\n",
655
+ " total_params += p.numel()\n",
656
+ " if p.requires_grad:\n",
657
+ " trainable_params += p.numel()\n",
658
+ "\n",
659
+ " print(\"=\"*40)\n",
660
+ " print(f\"📊 MODEL STATISTICS\")\n",
661
+ " print(\"=\"*40)\n",
662
+ " print(f\"Total Parameters: {total_params:,} ({total_params/1e6:.2f}M)\")\n",
663
+ " print(f\"Trainable Parameters: {trainable_params:,} ({trainable_params/1e6:.2f}M)\")\n",
664
+ " print(\"-\" * 40)\n",
665
+ "\n",
666
+ " # 2. Section Breakdown\n",
667
+ " def get_params(module):\n",
668
+ " return sum(p.numel() for p in module.parameters())\n",
669
+ "\n",
670
+ " if hasattr(model, 'encoder'):\n",
671
+ " enc_p = get_params(model.encoder)\n",
672
+ " print(f\" • Encoder (FNet): {enc_p:,} ({enc_p/1e6:.2f}M)\")\n",
673
+ "\n",
674
+ " if hasattr(model, 'decoder'):\n",
675
+ " dec_p = get_params(model.decoder)\n",
676
+ " print(f\" • Decoder (RoPE): {dec_p:,} ({dec_p/1e6:.2f}M)\")\n",
677
+ "\n",
678
+ " if hasattr(model, 'embedding'):\n",
679
+ " emb_p = get_params(model.embedding)\n",
680
+ " print(f\" • Embeddings: {emb_p:,} ({emb_p/1e6:.2f}M)\")\n",
681
+ "\n",
682
+ " print(\"=\"*40)\n",
683
+ "\n"
684
+ ],
685
+ "metadata": {
686
+ "id": "wpmz-H9Slko1"
687
+ },
688
+ "execution_count": null,
689
+ "outputs": []
690
+ },
691
+ {
692
+ "cell_type": "markdown",
693
+ "metadata": {
694
+ "id": "Zd3AFTmhrCJq"
695
+ },
696
+ "source": [
697
+ "## Functions (Loss, Eval etc)"
698
+ ]
699
+ },
700
+ {
701
+ "cell_type": "code",
702
+ "execution_count": null,
703
+ "metadata": {
704
+ "id": "Te1qTyUKrDEd"
705
+ },
706
+ "outputs": [],
707
+ "source": [
708
+ "\n",
709
+ "translation_loss_fn = nn.CrossEntropyLoss(\n",
710
+ " ignore_index=-100, # We don't calculate loss for pad tokens. Pad tokens are replaced with -100 by DataCollatorForSeq2Seq.\n",
711
+ " label_smoothing=LABEL_SMOOTHING_EPSILON\n",
712
+ ")\n",
713
+ "def calculate_combined_loss(model_outputs, target_labels):\n",
714
+ " \"\"\"Calculates the loss based on the model's output structure.\"\"\"\n",
715
+ " logits = model_outputs\n",
716
+ " translation_loss = translation_loss_fn(logits.reshape(-1, logits.shape[-1]), target_labels.reshape(-1))\n",
717
+ " loss_dict = {'total': translation_loss.item()}\n",
718
+ " return translation_loss, loss_dict\n",
719
+ "\n",
720
+ "from torchmetrics.text import SacreBLEUScore\n",
721
+ "\n",
722
+ "def evaluate(model, dataloader, device):\n",
723
+ " # Use SacreBLEUScore (defaults to '13a' tokenizer, the WMT standard)\n",
724
+ " metric = SacreBLEUScore().to(device)\n",
725
+ "\n",
726
+ " model.eval()\n",
727
+ "\n",
728
+ " # Use no_grad to save memory and speed up validation\n",
729
+ " with torch.no_grad():\n",
730
+ " for batch in tqdm(dataloader, desc=\"Evaluating\", leave=False):\n",
731
+ " input_ids = batch['input_ids'].to(device)\n",
732
+ " labels = batch['labels']\n",
733
+ "\n",
734
+ " # Generate predictions\n",
735
+ " generated_ids = model.generate(input_ids, max_length=MAX_LENGTH, num_beams=5)\n",
736
+ "\n",
737
+ " # Decode predictions\n",
738
+ " pred_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)\n",
739
+ "\n",
740
+ " # Decode labels (Fixing -100 padding)\n",
741
+ " labels[labels == -100] = tokenizer.pad_token_id\n",
742
+ " ref_texts = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
743
+ "\n",
744
+ " # Update Metric\n",
745
+ " # SacreBLEU expects references as a list of lists: [[ref1], [ref2], ...]\n",
746
+ " formatted_refs = [[ref] for ref in ref_texts]\n",
747
+ " metric.update(pred_texts, formatted_refs)\n",
748
+ "\n",
749
+ " model.train()\n",
750
+ "\n",
751
+ " # Compute returns a tensor, .item() converts it to a standard python float\n",
752
+ " return metric.compute().item()\n",
753
+ "\n",
754
+ "\n",
755
+ "\n",
756
+ "## WARNING! THIS CAN'T BE USED FOR FNET\n",
757
+ "def generate_sample_translations(model, device, sentences_de):\n",
758
+ " \"\"\"Generates and prints sample translations using beam search.\"\"\"\n",
759
+ " print(\"\\n--- Generating Sample Translations (with Beam Search) ---\")\n",
760
+ " orig_model = getattr(model, '_orig_mod', model)\n",
761
+ " orig_model.eval()\n",
762
+ "\n",
763
+ " inputs = tokenizer(sentences_de, return_tensors=\"pt\", padding=True, truncation=True, max_length=MAX_LENGTH)\n",
764
+ " input_ids = inputs.input_ids.to(device)\n",
765
+ " generated_ids = orig_model.generate(input_ids, max_length=MAX_LENGTH, num_beams=5)\n",
766
+ "\n",
767
+ " translations = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)\n",
768
+ " for src, out in zip(sentences_de, translations):\n",
769
+ " print(f\" DE Source: {src}\")\n",
770
+ " print(f\" EN Output: {out}\")\n",
771
+ " print(\"-\" * 20)\n",
772
+ " orig_model.train()\n",
773
+ "\n",
774
+ "sample_sentences_de_for_tracking = [\n",
775
+ " \"Eine Katze sitzt auf der Matte.\",\n",
776
+ " \"Ein Mann in einem roten Hemd liest ein Buch.\",\n",
777
+ " \"Was ist die Hauptstadt von Deutschland?\",\n",
778
+ " \"Ich gehe ins Kino, weil der Film sehr gut ist.\",\n",
779
+ "]\n",
780
+ "\n",
781
+ "def init_other_linear_weights(m):\n",
782
+ " if isinstance(m, nn.Linear):\n",
783
+ " # The 'is not' check correctly skips the final_linear layer,\n",
784
+ " # leaving its weights tied to the correctly initialized embeddings.\n",
785
+ " if m is not getattr(model, '_orig_mod', model).final_linear:\n",
786
+ " nn.init.xavier_uniform_(m.weight)\n",
787
+ " if m.bias is not None:\n",
788
+ " nn.init.zeros_(m.bias)\n",
789
+ "\n",
790
+ "\n"
791
+ ]
792
+ },
793
+ {
794
+ "cell_type": "code",
795
+ "execution_count": null,
796
+ "metadata": {
797
+ "id": "YwPXbSwR50I2"
798
+ },
799
+ "outputs": [],
800
+ "source": [
801
+ "import json\n",
802
+ "import os\n",
803
+ "import subprocess\n",
804
+ "import torch\n",
805
+ "import hashlib\n",
806
+ "import sys\n",
807
+ "import shutil\n",
808
+ "\n",
809
+ "# This logger will be configured and used in the main training script\n",
810
+ "import logging\n",
811
+ "logger = logging.getLogger(__name__)\n",
812
+ "\n",
813
+ "\n",
814
+ "def log_to_run_specific_file(run_dir):\n",
815
+ " run_log_path = os.path.join(run_dir, \"run_log.txt\")\n",
816
+ " file_handler = logging.FileHandler(run_log_path)\n",
817
+ " file_handler.setFormatter(logging.Formatter('%(asctime)s [%(levelname)s] %(message)s'))\n",
818
+ " logger.addHandler(file_handler)\n",
819
+ " return file_handler\n",
820
+ "\n",
821
+ "def log_configurations(log_dir, config_vars):\n",
822
+ " # (Same as your provided function)\n",
823
+ " config_path = os.path.join(log_dir, \"config.json\")\n",
824
+ " try:\n",
825
+ " with open(config_path, 'w') as f:\n",
826
+ " serializable_configs = {k: v for k, v in config_vars.items() if isinstance(v, (int, float, str, bool, list, dict, type(None)))}\n",
827
+ " json.dump(serializable_configs, f, indent=4)\n",
828
+ " logger.info(f\"Configurations saved to {config_path}\")\n",
829
+ " except Exception as e:\n",
830
+ " logger.error(f\"Could not save configurations: {e}\")\n",
831
+ "\n",
832
+ "def log_environment(log_dir):\n",
833
+ " # (Same as your provided function)\n",
834
+ " env_path = os.path.join(log_dir, \"environment.txt\")\n",
835
+ " try:\n",
836
+ " with open(env_path, 'w') as f:\n",
837
+ " f.write(f\"--- Timestamp (UTC): {datetime.datetime.utcnow().isoformat()} ---\\n\")\n",
838
+ " f.write(f\"Python Version: {sys.version}\\n\")\n",
839
+ " f.write(f\"PyTorch Version: {torch.__version__}\\n\")\n",
840
+ " f.write(f\"CUDA Available: {torch.cuda.is_available()}\\n\")\n",
841
+ " if torch.cuda.is_available():\n",
842
+ " f.write(f\"CUDA Version: {torch.version.cuda}\\n\")\n",
843
+ " f.write(f\"CuDNN Version: {torch.backends.cudnn.version()}\\n\")\n",
844
+ " f.write(f\"Number of GPUs: {torch.cuda.device_count()}\\n\")\n",
845
+ " f.write(f\"GPU Name: {torch.cuda.get_device_name(0)}\\n\")\n",
846
+ " f.write(\"\\n--- Full pip freeze ---\\n\")\n",
847
+ " result = subprocess.run([sys.executable, '-m', 'pip', 'freeze'], stdout=subprocess.PIPE, text=True, check=True)\n",
848
+ " f.write(result.stdout)\n",
849
+ " logger.info(f\"Environment info saved to {env_path}\")\n",
850
+ " except Exception as e:\n",
851
+ " logger.error(f\"Could not save environment info: {e}\")\n",
852
+ "\n",
853
+ "def log_code_snapshot(log_dir, script_path):\n",
854
+ " # NOTE: In Colab, you must save your notebook as a .py file for this to work.\n",
855
+ " # For example, file -> \"Save a copy as .py\"\n",
856
+ " code_dir = os.path.join(log_dir, \"code_snapshot\")\n",
857
+ " os.makedirs(code_dir, exist_ok=True)\n",
858
+ " if script_path and os.path.exists(script_path):\n",
859
+ " try:\n",
860
+ " shutil.copy(script_path, os.path.join(code_dir, os.path.basename(script_path)))\n",
861
+ " logger.info(f\"Copied script '{script_path}' to snapshot directory for verification.\")\n",
862
+ " except Exception as e:\n",
863
+ " logger.error(f\"Could not copy script for snapshot: {e}\")\n",
864
+ " else:\n",
865
+ " logger.warning(f\"Code Snapshot: Script path '{script_path}' not found. SKIPPING.\")\n",
866
+ "\n",
867
+ "def get_file_hash(filepath):\n",
868
+ " # (Same as your provided function)\n",
869
+ " sha256_hash = hashlib.sha256()\n",
870
+ " try:\n",
871
+ " with open(filepath, \"rb\") as f:\n",
872
+ " for byte_block in iter(lambda: f.read(4096), b\"\"):\n",
873
+ " sha256_hash.update(byte_block)\n",
874
+ " return sha256_hash.hexdigest()\n",
875
+ " except Exception as e:\n",
876
+ " logger.error(f\"Could not generate hash for {filepath}: {e}\")\n",
877
+ " return None\n",
878
+ "\n",
879
+ "def create_checksum_file(run_dir, artifacts_dict):\n",
880
+ " checksum_file_path = os.path.join(run_dir, \"checksums.sha256\")\n",
881
+ " logger.info(f\"--- Creating digital fingerprints for key artifacts ---\")\n",
882
+ " with open(checksum_file_path, \"w\") as f:\n",
883
+ " f.write(f\"SHA256 Checksums for run: {os.path.basename(run_dir)}\\n\")\n",
884
+ " for name, path in artifacts_dict.items():\n",
885
+ " if path and os.path.exists(path):\n",
886
+ " file_hash = get_file_hash(path)\n",
887
+ " if file_hash:\n",
888
+ " log_message = f\" - {name} ({os.path.basename(path)}): {file_hash}\"\n",
889
+ " logger.info(log_message)\n",
890
+ " f.write(f\"{file_hash} {os.path.basename(path)}\\n\")\n",
891
+ " else:\n",
892
+ " logger.warning(f\" - Skipped hashing '{name}', file not found: {path}\")\n",
893
+ " logger.info(f\"Checksums saved to {checksum_file_path}\")\n",
894
+ "\n",
895
+ "def init_weights_kaiming(m):\n",
896
+ " \"\"\"\n",
897
+ " Applies Kaiming He initialization to Linear layers.\n",
898
+ " This is the standard, superior way to initialize deep Transformers.\n",
899
+ " NOTE: We will handle the Embedding layer separately.\n",
900
+ " \"\"\"\n",
901
+ "\n",
902
+ " if isinstance(m, nn.Linear):\n",
903
+ " nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5)) # a=sqrt(5) mimics default PyTorch for LeakyReLU\n",
904
+ " if m.bias is not None:\n",
905
+ " fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)\n",
906
+ " bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0\n",
907
+ " nn.init.uniform_(m.bias, -bound, bound)\n",
908
+ "\n",
909
+ "\n",
910
+ "def init_weights_fnet(m):\n",
911
+ " \"\"\"\n",
912
+ " Specific initialization for FNet Hybrid.\n",
913
+ " FNet is essentially a BERT-like encoder, so we use BERT-style initialization\n",
914
+ " (Truncated Normal or Xavier) rather than Kaiming.\n",
915
+ " \"\"\"\n",
916
+ " if isinstance(m, nn.Linear):\n",
917
+ " # Xavier (Glorot) Uniform is the standard for Transformer/FNet attention/FFN layers\n",
918
+ " nn.init.xavier_uniform_(m.weight)\n",
919
+ " if m.bias is not None:\n",
920
+ " nn.init.zeros_(m.bias)\n",
921
+ "\n",
922
+ " elif isinstance(m, nn.Embedding):\n",
923
+ " # Critical: Keep embedding variance low (0.02)\n",
924
+ " nn.init.normal_(m.weight, mean=0.0, std=0.02)\n",
925
+ "\n",
926
+ " # Handle the RMSNorms if they have learnable parameters\n",
927
+ " elif isinstance(m, (nn.LayerNorm, RMSNorm)):\n",
928
+ " if hasattr(m, 'weight') and m.weight is not None:\n",
929
+ " nn.init.ones_(m.weight)\n",
930
+ " if hasattr(m, 'bias') and m.bias is not None:\n",
931
+ " nn.init.zeros_(m.bias)\n",
932
+ "\n"
933
+ ]
934
+ },
935
+ {
936
+ "cell_type": "markdown",
937
+ "metadata": {
938
+ "id": "ijTUk5dHu494"
939
+ },
940
+ "source": [
941
+ "## Training Loop"
942
+ ]
943
+ },
944
+ {
945
+ "cell_type": "code",
946
+ "execution_count": null,
947
+ "metadata": {
948
+ "id": "pyHZ1moluyA2"
949
+ },
950
+ "outputs": [],
951
+ "source": [
952
+ "if __name__ == '__main__':\n",
953
+ "\n",
954
+ " experiment_name = f\"{MODEL_CHOICE}\"\n",
955
+ " CURRENT_RUN_DIR = os.path.join(DRIVE_BASE_PATH, experiment_name)\n",
956
+ " SAVE_DIR = os.path.join(CURRENT_RUN_DIR, \"models\")\n",
957
+ " LOG_DIR_TENSORBOARD = os.path.join(CURRENT_RUN_DIR, \"tensorboard_logs\")\n",
958
+ " LOG_FILE_TXT = os.path.join(CURRENT_RUN_DIR, \"run_log.txt\")\n",
959
+ "\n",
960
+ " os.makedirs(SAVE_DIR, exist_ok=True)\n",
961
+ " os.makedirs(LOG_DIR_TENSORBOARD, exist_ok=True)\n",
962
+ "\n",
963
+ " logging.basicConfig(\n",
964
+ " level=logging.INFO,\n",
965
+ " format='%(asctime)s [%(levelname)s] %(message)s',\n",
966
+ " handlers=[logging.FileHandler(LOG_FILE_TXT), logging.StreamHandler(sys.stdout)],\n",
967
+ " force=True\n",
968
+ " )\n",
969
+ " logger = logging.getLogger(__name__)\n",
970
+ " writer = SummaryWriter(LOG_DIR_TENSORBOARD)\n",
971
+ "\n",
972
+ " logger.info(f\"--- LAUNCHING EXPERIMENT: {experiment_name} ---\")\n",
973
+ "\n",
974
+ " all_configs = {k: v for k, v in globals().items() if k.isupper()}\n",
975
+ " log_configurations(CURRENT_RUN_DIR, all_configs)\n",
976
+ " log_environment(CURRENT_RUN_DIR)\n",
977
+ "\n",
978
+ " logger.info(f\"--- Initializing FNetHybridTransformer ---\")\n",
979
+ " model = FNetHybridTransformer(\n",
980
+ " num_encoder_layers=NUM_ENCODER_LAYERS,\n",
981
+ " num_decoder_layers=NUM_DECODER_LAYERS,\n",
982
+ " num_heads=NUM_HEADS,\n",
983
+ " d_model=D_MODEL,\n",
984
+ " dff=D_FF,\n",
985
+ " vocab_size=VOCAB_SIZE,\n",
986
+ " max_length=MAX_LENGTH,\n",
987
+ " dropout=DROPOUT\n",
988
+ " )\n",
989
+ "\n",
990
+ " model.apply(init_weights_fnet)\n",
991
+ " nn.init.normal_(model.pos_embedding.weight, mean=0.0, std=0.02)\n",
992
+ " model.final_linear.weight = model.embedding.weight\n",
993
+ "\n",
994
+ " model.to(device)\n",
995
+ " count_parameters(model)\n",
996
+ "\n",
997
+ " # 4. SETUP OPTIMIZER\n",
998
+ " optimizer = torch.optim.AdamW(model.parameters(), lr=PEAK_LEARNING_RATE, betas=(0.9, 0.98),\n",
999
+ " eps=1e-9, weight_decay=WEIGHT_DECAY)\n",
1000
+ "\n",
1001
+ " # Scheduler\n",
1002
+ " scheduler = get_cosine_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=WARMUP_STEPS,\n",
1003
+ " num_training_steps=TARGET_TRAINING_STEPS)\n",
1004
+ " scaler = torch.cuda.amp.GradScaler()\n",
1005
+ "\n",
1006
+ "# --- AUTO-RESUME LOGIC (SMARTER VERSION) ---\n",
1007
+ " global_step = 0\n",
1008
+ " best_bleu = 0.0\n",
1009
+ " LAST_CHECKPOINT_PATH = os.path.join(SAVE_DIR, \"last.pt\")\n",
1010
+ " BEST_CHECKPOINT_PATH = os.path.join(SAVE_DIR, \"best.pt\")\n",
1011
+ "\n",
1012
+ " # 1. Try to find the latest checkpoint (if it exists)\n",
1013
+ " if os.path.exists(LAST_CHECKPOINT_PATH):\n",
1014
+ " logger.info(f\"🔄 Found checkpoint at {LAST_CHECKPOINT_PATH}. Resuming...\")\n",
1015
+ " checkpoint = torch.load(LAST_CHECKPOINT_PATH, map_location=device)\n",
1016
+ "\n",
1017
+ " model.load_state_dict(checkpoint['model_state_dict'])\n",
1018
+ " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
1019
+ " scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n",
1020
+ " scaler.load_state_dict(checkpoint['scaler_state_dict'])\n",
1021
+ "\n",
1022
+ " global_step = checkpoint['global_step']\n",
1023
+ " best_bleu = checkpoint.get('best_bleu', 0.0)\n",
1024
+ " logger.info(f\" ✅ Resumed from Step {global_step} (LAST)\")\n",
1025
+ "\n",
1026
+ " # 2. If no LAST, try to find the BEST checkpoint (Fall back to this!)\n",
1027
+ " elif os.path.exists(BEST_CHECKPOINT_PATH):\n",
1028
+ " logger.info(f\"🔙 'last.pt' not found. Falling back to BEST checkpoint: {BEST_CHECKPOINT_PATH}\")\n",
1029
+ " checkpoint = torch.load(BEST_CHECKPOINT_PATH, map_location=device)\n",
1030
+ "\n",
1031
+ " model.load_state_dict(checkpoint['model_state_dict'])\n",
1032
+ " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
1033
+ " scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n",
1034
+ " scaler.load_state_dict(checkpoint['scaler_state_dict'])\n",
1035
+ "\n",
1036
+ " global_step = checkpoint['global_step']\n",
1037
+ " best_bleu = checkpoint.get('best_bleu', 0.0)\n",
1038
+ " logger.info(f\" ✅ Resumed from Step {global_step} (BEST)\")\n",
1039
+ "\n",
1040
+ " # 3. Start Fresh\n",
1041
+ " else:\n",
1042
+ " logger.info(\"🆕 No checkpoint found. Starting fresh training.\")\n",
1043
+ " # 5. TRAINING LOOP\n",
1044
+ " model.train()\n",
1045
+ "\n",
1046
+ " # Resume progress bar from global_step\n",
1047
+ " progress_bar = tqdm(total=TARGET_TRAINING_STEPS, initial=global_step, desc=\"Training Steps\")\n",
1048
+ " training_complete = False\n",
1049
+ "\n",
1050
+ " # Initialize gradients\n",
1051
+ " optimizer.zero_grad(set_to_none=True)\n",
1052
+ "\n",
1053
+ " # We iterate until global_step reaches the target\n",
1054
+ " epoch = 0\n",
1055
+ " while not training_complete:\n",
1056
+ " train_dataloader.generator.manual_seed(SEED + epoch)\n",
1057
+ " epoch += 1\n",
1058
+ "\n",
1059
+ " for batch_idx, batch in enumerate(train_dataloader):\n",
1060
+ " if global_step >= TARGET_TRAINING_STEPS:\n",
1061
+ " training_complete = True\n",
1062
+ " break\n",
1063
+ "\n",
1064
+ " input_ids = batch['input_ids'].to(device, non_blocking=True)\n",
1065
+ " labels = batch['labels'].to(device, non_blocking=True)\n",
1066
+ "\n",
1067
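+ " # Teacher forcing: decoder input is the target shifted right, [<pad>, y1 ... y_{n-1}],\n",
+ " # predicting [y1 ... y_n]. Marian-style models use <pad> as the decoder start token.\n",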
+ " decoder_start_token = torch.full((labels.shape[0], 1), tokenizer.pad_token_id, dtype=torch.long, device=device)\n",
1068
+ " decoder_input_ids = torch.cat([decoder_start_token, labels[:, :-1]], dim=1)\n",
1069
+ " decoder_input_ids[decoder_input_ids == -100] = tokenizer.pad_token_id\n",
1070
+ " target_labels = labels\n",
1071
+ "\n",
1072
+ " src_padding_mask, tgt_padding_mask, mem_key_padding_mask, tgt_mask = model.create_masks(input_ids, decoder_input_ids)\n",
1073
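+ " # The decoder start token IS the pad token, so un-mask position 0; otherwise the first query row would be fully masked.\n",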
+ " tgt_padding_mask[:, 0] = False\n",
1074
+ "\n",
1075
+ " with torch.autocast(device_type=\"cuda\", dtype=torch.float16):\n",
1076
+ " model_outputs = model(src=input_ids, tgt=decoder_input_ids, src_padding_mask=src_padding_mask,\n",
1077
+ " tgt_padding_mask=tgt_padding_mask, memory_key_padding_mask=mem_key_padding_mask,\n",
1078
+ " tgt_mask=tgt_mask)\n",
1079
+ " loss, loss_components = calculate_combined_loss(model_outputs, target_labels)\n",
1080
+ "\n",
1081
+ " # --- GRADIENT ACCUMULATION SCALING ---\n",
1082
+ " loss = loss / GRAD_ACCUMULATION_STEPS\n",
1083
+ "\n",
1084
+ " # Accumulate gradients (no optimizer step yet)\n",
1085
+ " scaler.scale(loss).backward()\n",
1086
+ "\n",
1087
+ " # --- OPTIMIZER STEP (Conditional) ---\n",
1088
+ " if (batch_idx + 1) % GRAD_ACCUMULATION_STEPS == 0:\n",
1089
+ " scaler.unscale_(optimizer)\n",
1090
+ " total_grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
1091
+ "\n",
1092
+ " scaler.step(optimizer)\n",
1093
+ " scaler.update()\n",
1094
+ " scheduler.step()\n",
1095
+ "\n",
1096
+ " # Reset gradients\n",
1097
+ " optimizer.zero_grad(set_to_none=True)\n",
1098
+ "\n",
1099
+ " global_step += 1\n",
1100
+ " progress_bar.update(1)\n",
1101
+ " lr = scheduler.get_last_lr()[0]\n",
1102
+ "\n",
1103
+ " if global_step % 20 == 0:\n",
1104
+ " # Scale loss back up for logging purposes\n",
1105
+ " logged_loss = loss.item() * GRAD_ACCUMULATION_STEPS\n",
1106
+ " writer.add_scalar('train/loss', logged_loss, global_step)\n",
1107
+ " writer.add_scalar('train/learning_rate', lr, global_step)\n",
1108
+ " writer.add_scalar('train/gradient_norm', total_grad_norm.item(), global_step)\n",
1109
+ " progress_bar.set_postfix(\n",
1110
+ " loss=f\"{logged_loss:.2f}\",\n",
1111
+ " lr=f\"{lr:.2e}\",\n",
1112
+ " grad=f\"{total_grad_norm.item():.2f}\" # Showing Gradients\n",
1113
+ " )\n",
1114
+ "\n",
1115
+ " # --- PERIODIC SAVING (Every 500 Steps) ---\n",
1116
+ " # Saves you if Colab crashes mid-epoch\n",
1117
+ " if global_step % 500 == 0:\n",
1118
+ " torch.save({\n",
1119
+ " 'global_step': global_step,\n",
1120
+ " 'model_state_dict': model.state_dict(),\n",
1121
+ " 'optimizer_state_dict': optimizer.state_dict(),\n",
1122
+ " 'scheduler_state_dict': scheduler.state_dict(),\n",
1123
+ " 'scaler_state_dict': scaler.state_dict(),\n",
1124
+ " 'best_bleu': best_bleu\n",
1125
+ " }, LAST_CHECKPOINT_PATH)\n",
1126
+ "\n",
1127
+ " # --- VALIDATION CHECK ---\n",
1128
+ " if global_step in VALIDATION_SCHEDULE:\n",
1129
+ " logger.info(f\"\\n--- Validation at Step {global_step} ---\")\n",
1130
+ " bleu_score = evaluate(model, val_dataloader, device)\n",
1131
+ " writer.add_scalar('validation/bleu', bleu_score, global_step)\n",
1132
+ " logger.info(f\"Validation BLEU: {bleu_score:.4f} (Best: {best_bleu:.4f})\")\n",
1133
+ " #generate_sample_translations(model, device, sample_sentences_de_for_tracking)\n",
1134
+ "\n",
1135
+ " if bleu_score > best_bleu:\n",
1136
+ " best_bleu = bleu_score\n",
1137
+ " logger.info(f\" New best BLEU! Saving best model...\")\n",
1138
+ " # Save EVERYTHING so you can resume even from best model\n",
1139
+ " torch.save({\n",
1140
+ " 'global_step': global_step,\n",
1141
+ " 'model_state_dict': model.state_dict(),\n",
1142
+ " 'optimizer_state_dict': optimizer.state_dict(),\n",
1143
+ " 'scheduler_state_dict': scheduler.state_dict(),\n",
1144
+ " 'scaler_state_dict': scaler.state_dict(),\n",
1145
+ " 'best_bleu': best_bleu\n",
1146
+ " }, BEST_CHECKPOINT_PATH)\n",
1147
+ "\n",
1148
+ " model.train()\n",
1149
+ "\n",
1150
+ " progress_bar.close()\n",
1151
+ " writer.close()\n",
1152
+ "\n",
1153
+ " # Save Final (With States)\n",
1154
+ " torch.save({\n",
1155
+ " 'global_step': global_step,\n",
1156
+ " 'model_state_dict': model.state_dict(),\n",
1157
+ " 'optimizer_state_dict': optimizer.state_dict(),\n",
1158
+ " 'scheduler_state_dict': scheduler.state_dict(),\n",
1159
+ " 'scaler_state_dict': scaler.state_dict(),\n",
1160
+ " 'best_bleu': best_bleu\n",
1161
+ " }, LAST_CHECKPOINT_PATH)\n",
1162
+ "\n",
1163
+ " print(\"\\n\" + \"*\"*80)\n",
1164
+ " print(\" EXPERIMENT COMPLETE \")\n",
1165
+ " print(\"*\"*80)"
1166
+ ]
1167
+ },
1168
+ {
1169
+ "cell_type": "code",
1170
+ "execution_count": null,
1171
+ "metadata": {
1172
+ "id": "UsS6qhLtJaMF"
1173
+ },
1174
+ "outputs": [],
1175
+ "source": [
1176
+ "import os\n",
1177
+ "import sys\n",
1178
+ "import torch\n",
1179
+ "import transformers\n",
1180
+ "import datasets\n",
1181
+ "import torchmetrics\n",
1182
+ "import numpy\n",
1183
+ "import pkg_resources\n",
1184
+ "\n",
1185
+ "def log_environment_separate(log_dir):\n",
1186
+ " # Define the separate file path\n",
1187
+ " meta_file = os.path.join(log_dir, \"system_metadata.txt\")\n",
1188
+ "\n",
1189
+ " with open(meta_file, \"w\") as f:\n",
1190
+ " # --- PART 1: SUMMARY ---\n",
1191
+ " f.write(\"=\"*40 + \"\\n\")\n",
1192
+ " f.write(\"CORE ENVIRONMENT SUMMARY\\n\")\n",
1193
+ " f.write(\"=\"*40 + \"\\n\")\n",
1194
+ " f.write(f\"Python: {sys.version.split()[0]}\\n\")\n",
1195
+ " f.write(f\"PyTorch: {torch.__version__}\\n\")\n",
1196
+ " f.write(f\"Transformers: {transformers.__version__}\\n\")\n",
1197
+ " f.write(f\"Datasets: {datasets.__version__}\\n\")\n",
1198
+ " f.write(f\"TorchMetrics: {torchmetrics.__version__}\\n\")\n",
1199
+ " f.write(f\"NumPy: {numpy.__version__}\\n\")\n",
1200
+ "\n",
1201
+ " try:\n",
1202
+ " import sacrebleu\n",
1203
+ " f.write(f\"SacreBLEU: {sacrebleu.__version__}\\n\")\n",
1204
+ " except ImportError:\n",
1205
+ " f.write(\"SacreBLEU: Not Installed\\n\")\n",
1206
+ "\n",
1207
+ " if torch.cuda.is_available():\n",
1208
+ " f.write(f\"GPU Name: {torch.cuda.get_device_name(0)}\\n\")\n",
1209
+ " f.write(f\"CUDA Ver: {torch.version.cuda}\\n\")\n",
1210
+ " f.write(f\"Capability: {torch.cuda.get_device_capability(0)}\\n\")\n",
1211
+ " else:\n",
1212
+ " f.write(\"GPU: None (CPU Only)\\n\")\n",
1213
+ "\n",
1214
+ " # --- PART 2: FULL FREEZE ---\n",
1215
+ " f.write(\"\\n\" + \"=\"*40 + \"\\n\")\n",
1216
+ " f.write(\"FULL LIBRARY DEPENDENCIES (PIP FREEZE)\\n\")\n",
1217
+ " f.write(\"=\"*40 + \"\\n\")\n",
1218
+ "\n",
1219
+ " installed_packages = {d.project_name: d.version for d in pkg_resources.working_set}\n",
1220
+ " for package, version in sorted(installed_packages.items()):\n",
1221
+ " f.write(f\"{package}=={version}\\n\")\n",
1222
+ "\n",
1223
+ " print(f\"✅ Environment details saved SEPARATELY to: {meta_file}\")\n",
1224
+ "\n",
1225
+ "# Execute\n",
1226
+ "# Assumes CURRENT_RUN_DIR is defined from your config\n",
1227
+ "log_environment_separate(CURRENT_RUN_DIR)"
1228
+ ]
1229
+ },
1230
+ {
1231
+ "cell_type": "code",
1232
+ "execution_count": null,
1233
+ "metadata": {
1234
+ "id": "tqDiOyy18clU"
1235
+ },
1236
+ "outputs": [],
1237
+ "source": [
1238
+ "# TENSORBOARD VISUALIZATION\n",
1239
+ "\n",
1240
+ "%load_ext tensorboard\n",
1241
+ "\n",
1242
+ "TENSORBOARD_BASE_DIR = os.path.join(DRIVE_BASE_PATH)\n",
1243
+ "\n",
1244
+ "%tensorboard --logdir \"{TENSORBOARD_BASE_DIR}\""
1245
+ ]
1246
+ },
1247
+ {
1248
+ "cell_type": "code",
1249
+ "execution_count": null,
1250
+ "metadata": {
1251
+ "id": "AmOcgwNnJqOj"
1252
+ },
1253
+ "outputs": [],
1254
+ "source": [
1255
+ "from google.colab import runtime\n",
1256
+ "runtime.unassign()"
1257
+ ]
1258
+ },
1259
+ {
1260
+ "cell_type": "markdown",
1261
+ "metadata": {
1262
+ "id": "eI0-qVlWVVpx"
1263
+ },
1264
+ "source": [
1265
+ "## End"
1266
+ ]
1267
+ }
1268
+ ],
1269
+ "metadata": {
1270
+ "accelerator": "GPU",
1271
+ "colab": {
1272
+ "gpuType": "A100",
1273
+ "provenance": [],
1274
+ "machine_shape": "hm"
1275
+ },
1276
+ "kernelspec": {
1277
+ "display_name": "Python 3",
1278
+ "name": "python3"
1279
+ },
1280
+ "language_info": {
1281
+ "name": "python"
1282
+ }
1283
+ },
1284
+ "nbformat": 4,
1285
+ "nbformat_minor": 0
1286
+ }
Gated_PRISM_train_hybrid_RoPE.ipynb ADDED
@@ -0,0 +1,694 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {
7
+ "id": "f2CVny1CrxQc"
8
+ },
9
+ "outputs": [],
10
+ "source": [
11
+ "!pip install -q torchmetrics sacrebleu x-transformers\n",
12
+ "\n",
13
+ "# ==============================================================================\n",
14
+ "# 1. CONFIGURATION\n",
15
+ "# ==============================================================================\n",
16
+ "import os\n",
17
+ "import torch\n",
18
+ "import torch.nn as nn\n",
19
+ "import torch.nn.functional as F\n",
20
+ "import torch.fft\n",
21
+ "from torch.utils.data import DataLoader\n",
22
+ "from transformers import AutoTokenizer, DataCollatorForSeq2Seq, get_cosine_schedule_with_warmup\n",
23
+ "from datasets import load_dataset\n",
24
+ "import math, sys, logging, datetime, json, random\n",
25
+ "import numpy as np\n",
26
+ "from tqdm.auto import tqdm\n",
27
+ "from torch.utils.tensorboard import SummaryWriter\n",
28
+ "from typing import List\n",
29
+ "\n",
30
+ "# --- Hardware Speedups ---\n",
31
+ "torch.set_float32_matmul_precision('medium')\n",
32
+ "\n",
33
+ "# --- Data & Task Size ---\n",
34
+ "MAX_LENGTH = 128\n",
35
+ "MODEL_CHOICE = \"Molecule_100k_1\"\n",
36
+ "\n",
37
+ "# --- Model Architecture Config ---\n",
38
+ "D_MODEL = 512\n",
39
+ "NUM_HEADS = 8\n",
40
+ "D_FF = 2048\n",
41
+ "DROPOUT = 0.1\n",
42
+ "NUM_ENCODER_LAYERS = 6 # PRISM LAYERS\n",
43
+ "NUM_REFINING_LAYERS = 0 #\n",
44
+ "NUM_DECODER_LAYERS = 6\n",
45
+ "\n",
46
+ "# --- Training Config ---\n",
47
+ "TARGET_TRAINING_STEPS = 100000\n",
48
+ "VALIDATION_SCHEDULE = [\n",
49
+ " 2000, 4000, 5000, 7500, 10000, 15000, 20000,\n",
50
+ " 25000, 30000, 35000, 42500, 50000, 57500, 65000, 72500, 90000, 100000\n",
51
+ "]\n",
52
+ "\n",
53
+ "PEAK_LEARNING_RATE = 8e-4\n",
54
+ "WARMUP_STEPS = 600\n",
55
+ "WEIGHT_DECAY = 0.01\n",
56
+ "LABEL_SMOOTHING_EPSILON = 0.1\n",
57
+ "\n",
58
+ "# --- Paths ---\n",
59
+ "DRIVE_BASE_PATH = \"/content/drive/MyDrive/PRISM\"\n",
60
+ "PREBATCHED_REPO_ID = \"prism-lab/wmt14-de-en-prebatched-w4\"\n",
61
+ "ORIGINAL_BUCKETED_REPO_ID = \"prism-lab/wmt14-de-en-bucketed-w4\"\n",
62
+ "MODEL_CHECKPOINT = \"Helsinki-NLP/opus-mt-de-en\"\n"
63
+ ]
64
+ },
65
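A quick sanity check on the decoder dimensions this config implies (a sketch; the assertion simply restates the constraint the model class relies on below):

# D_MODEL must split evenly across heads for attn_dim_head = d_model // num_heads.
assert D_MODEL % NUM_HEADS == 0, "head dimension would be fractional"
print("attn_dim_head:", D_MODEL // NUM_HEADS)   # 512 // 8 = 64
print("ff_mult:", D_FF / D_MODEL)               # 2048 / 512 = 4.0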
+ {
66
+ "cell_type": "code",
67
+ "execution_count": null,
68
+ "metadata": {
69
+ "id": "2VuaI43WDGoA"
70
+ },
71
+ "outputs": [],
72
+ "source": [
73
+ "\n",
74
+ "# ==============================================================================\n",
75
+ "# 2. IMPORTS & SETUP\n",
76
+ "# ==============================================================================\n",
77
+ "from x_transformers import Decoder\n",
78
+ "\n",
79
+ "def set_seed(seed_value=116):\n",
80
+ " random.seed(seed_value)\n",
81
+ " np.random.seed(seed_value)\n",
82
+ " torch.manual_seed(seed_value)\n",
83
+ " if torch.cuda.is_available():\n",
84
+ " torch.cuda.manual_seed_all(seed_value)\n",
85
+ " torch.backends.cudnn.deterministic = True\n",
86
+ " torch.backends.cudnn.benchmark = True\n",
87
+ "\n",
88
+ "set_seed()\n",
89
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
90
+ "\n",
91
+ "# --- Logging Setup ---\n",
92
+ "experiment_name = f\"{MODEL_CHOICE}_{datetime.datetime.now().strftime('%Y%m%d_%H%M')}\"\n",
93
+ "CURRENT_RUN_DIR = os.path.join(DRIVE_BASE_PATH, experiment_name)\n",
94
+ "SAVE_DIR = os.path.join(CURRENT_RUN_DIR, \"models\")\n",
95
+ "LOG_DIR_TENSORBOARD = os.path.join(CURRENT_RUN_DIR, \"tensorboard_logs\")\n",
96
+ "LOG_FILE_TXT = os.path.join(CURRENT_RUN_DIR, \"run_log.txt\")\n",
97
+ "\n",
98
+ "os.makedirs(SAVE_DIR, exist_ok=True)\n",
99
+ "os.makedirs(LOG_DIR_TENSORBOARD, exist_ok=True)\n",
100
+ "\n",
101
+ "logging.basicConfig(\n",
102
+ " level=logging.INFO,\n",
103
+ " format='%(asctime)s | %(message)s',\n",
104
+ " handlers=[logging.FileHandler(LOG_FILE_TXT), logging.StreamHandler(sys.stdout)],\n",
105
+ " force=True\n",
106
+ ")\n",
107
+ "logger = logging.getLogger(__name__)\n",
108
+ "writer = SummaryWriter(LOG_DIR_TENSORBOARD)\n",
109
+ "\n",
110
+ "# ==============================================================================\n",
111
+ "# 3. DATA LOADING\n",
112
+ "# ==============================================================================\n",
113
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)\n",
114
+ "VOCAB_SIZE = len(tokenizer)\n",
115
+ "standard_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)\n",
116
+ "\n",
117
+ "class PreBatchedCollator:\n",
118
+ " def __init__(self, original_dataset_split):\n",
119
+ " self.original_dataset = original_dataset_split\n",
120
+ " def __call__(self, features: List[dict]) -> dict:\n",
121
+ " batch_indices = features[0]['batch_indices']\n",
122
+ " dict_of_lists = self.original_dataset[batch_indices]\n",
123
+ " list_of_dicts = []\n",
124
+ " keys = dict_of_lists.keys()\n",
125
+ " num_samples = len(dict_of_lists['input_ids'])\n",
126
+ " for i in range(num_samples):\n",
127
+ " list_of_dicts.append({key: dict_of_lists[key][i] for key in keys})\n",
128
+ " return standard_collator(list_of_dicts)\n",
129
+ "\n",
130
+ "logger.info(f\"Loading datasets...\")\n",
131
+ "prebatched_datasets = load_dataset(PREBATCHED_REPO_ID)\n",
132
+ "original_datasets = load_dataset(ORIGINAL_BUCKETED_REPO_ID)\n",
133
+ "train_collator = PreBatchedCollator(original_datasets[\"train\"])\n",
134
+ "\n",
135
+ "train_dataloader = DataLoader(\n",
136
+ " prebatched_datasets[\"train\"], batch_size=1, shuffle=True,\n",
137
+ " collate_fn=train_collator, num_workers=2, pin_memory=True, prefetch_factor=2\n",
138
+ ")\n",
139
+ "val_dataloader = DataLoader(\n",
140
+ " original_datasets[\"validation\"], batch_size=64,\n",
141
+ " collate_fn=standard_collator, num_workers=2\n",
142
+ ")\n",
143
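A toy walk-through of the collator's indirection (made-up token ids, not real WMT14 rows): each pre-batched row stores only the indices of a length-bucketed batch, and the collator materializes those rows from the original split before the standard Seq2Seq padding runs.

# Hypothetical miniature of the dict-of-lists -> list-of-dicts conversion above.
original = {"input_ids": [[5, 6], [7, 8, 9], [10]],
            "labels":    [[2],    [3, 4],    [5]]}
batch_indices = [0, 2]                        # one pre-batched "row"

dict_of_lists = {k: [v[i] for i in batch_indices] for k, v in original.items()}
list_of_dicts = [{k: dict_of_lists[k][i] for k in dict_of_lists}
                 for i in range(len(batch_indices))]
print(list_of_dicts)  # [{'input_ids': [5, 6], 'labels': [2]}, {'input_ids': [10], 'labels': [5]}]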
+ "# ==============================================================================\n",
144
+ "# 4. PRISM ARCHITECTURE (FIXED: COMPLEX DROPOUT & PADDING)\n",
145
+ "# ==============================================================================\n",
146
+ "\n",
147
+ "# ==============================================================================\n",
148
+ "# 4. PRISM ARCHITECTURE (CLEAN & CORRECTED)\n",
149
+ "# ==============================================================================\n",
150
+ "\n",
151
+ "class ComplexDropout(nn.Module):\n",
152
+ " def __init__(self, p=0.5):\n",
153
+ " super().__init__()\n",
154
+ " self.p = p\n",
155
+ "\n",
156
+ " def forward(self, z):\n",
157
+ " if not self.training or self.p == 0.0:\n",
158
+ " return z\n",
159
+ " mask = torch.ones_like(z.real)\n",
160
+ " mask = F.dropout(mask, self.p, self.training, inplace=False)\n",
161
+ " return z * mask\n",
162
+ "\n",
163
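A small check of the dropout semantics (a sketch using the torch/F imports above, not part of training): F.dropout rescales surviving mask entries by 1/(1-p), so a dropped unit zeroes its real and imaginary parts together while the expected magnitude is preserved.

z = torch.complex(torch.randn(2, 4), torch.randn(2, 4))
mask = F.dropout(torch.ones_like(z.real), p=0.5, training=True)
out = z * mask
print(mask)            # entries are 0.0 or 2.0 (survivors scaled by 1/(1-p))
print(out[mask == 0])  # dropped positions are exactly 0+0j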
+ "class PhasePreservingLayerNorm(nn.Module):\n",
164
+ " def __init__(self, d_model, eps=1e-5):\n",
165
+ " super().__init__()\n",
166
+ " self.layernorm = nn.LayerNorm(d_model, eps=eps)\n",
167
+ " self.eps = eps\n",
168
+ "\n",
169
+ " def forward(self, x):\n",
170
+ " mag = torch.abs(x)\n",
171
+ " mag_norm = self.layernorm(mag)\n",
172
+ " return mag_norm.to(x.dtype) * (x / (mag + self.eps))\n",
173
+ "\n",
174
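One property worth noting (checked in the sketch below): LayerNorm can push a normalized magnitude negative, in which case that unit's phase is flipped by pi; the "phase preserving" guarantee holds wherever the normalized magnitude stays positive.

norm = PhasePreservingLayerNorm(16)
z = torch.complex(torch.randn(2, 8, 16), torch.randn(2, 8, 16))
out = norm(z)

keep = norm.layernorm(torch.abs(z)) > 0   # positions whose scaled magnitude stayed positive
print(torch.allclose(torch.angle(out)[keep], torch.angle(z)[keep], atol=1e-3))  # True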
+ "class HarmonicEmbedding(nn.Module):\n",
175
+ " def __init__(self, num_embeddings, embedding_dim, max_period=10000.0):\n",
176
+ " super().__init__()\n",
177
+ " self.embedding_dim = embedding_dim\n",
178
+ " self.complex_embedding = nn.Embedding(num_embeddings, embedding_dim * 2)\n",
179
+ " freqs = torch.exp(torch.arange(0, embedding_dim, dtype=torch.float32) * -(math.log(max_period) / embedding_dim))\n",
180
+ " self.register_buffer('freqs', freqs)\n",
181
+ "\n",
182
+ " def forward(self, input_ids):\n",
183
+ " raw_embeds = self.complex_embedding(input_ids)\n",
184
+ " real = raw_embeds[..., :self.embedding_dim]\n",
185
+ " imag = raw_embeds[..., self.embedding_dim:]\n",
186
+ " content_z = torch.complex(real, imag)\n",
187
+ " seq_len = input_ids.shape[1]\n",
188
+ " positions = torch.arange(seq_len, device=input_ids.device).float()\n",
189
+ " angles = torch.outer(positions, self.freqs)\n",
190
+ " pos_rotation = torch.polar(torch.ones_like(angles), angles).unsqueeze(0)\n",
191
+ " return content_z * pos_rotation\n",
192
+ "\n",
193
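The positional factor is a unit-magnitude rotation, which gives the embedding a RoPE-like relative-position property: the ratio between two positions' embeddings of the same token depends only on their offset. A quick check (sketch, using the class defined above):

emb = HarmonicEmbedding(num_embeddings=100, embedding_dim=16)
ids = torch.tensor([[7, 7, 7, 7]])            # the same token at four positions
z = emb(ids)[0]

print(torch.allclose(z[0].abs(), z[3].abs(), atol=1e-6))    # rotation leaves |z| unchanged
print(torch.allclose(z[1] / z[0], z[3] / z[2], atol=1e-5))  # offset-1 rotation is position-independent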
+ "class ModReLU(nn.Module):\n",
194
+ " def __init__(self, features):\n",
195
+ " super().__init__()\n",
196
+ " self.b = nn.Parameter(torch.zeros(features))\n",
197
+ " def forward(self, z):\n",
198
+ " mag = torch.abs(z)\n",
199
+ " new_mag = F.relu(mag + self.b)\n",
200
+ " phase = z / (mag + 1e-6)\n",
201
+ " return new_mag * phase\n",
202
+ "\n",
203
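ModReLU in one picture (a sketch with hand-picked values): the learned bias b acts as a magnitude threshold, killing small activations while passing larger ones with their phase intact.

act = ModReLU(features=3)
with torch.no_grad():
    act.b.fill_(-0.5)                         # drop anything with magnitude < 0.5
z = torch.complex(torch.tensor([0.2, 0.6, 2.0]), torch.zeros(3))
print(act(z))                                 # ~[0.0, 0.1, 1.5]; survivors keep their phase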
+ "# --- THE CORRECT LAYER (Cartesian Gated) ---\n",
204
+ "class PRISMLayer(nn.Module):\n",
205
+ " def __init__(self, d_model, max_len=5000, dropout=0.1):\n",
206
+ " super().__init__()\n",
207
+ " self.d_model = d_model\n",
208
+ " self.filter_len = max_len\n",
209
+ "\n",
210
+ " # 1. THE GATE (Data Dependency)\n",
211
+ " self.gate_proj = nn.Linear(d_model * 2, d_model * 2)\n",
212
+ "\n",
213
+ " # 2. THE FILTER (Global Pattern)\n",
214
+ " self.global_filter = nn.Parameter(torch.randn(d_model, max_len, dtype=torch.cfloat) * 0.02)\n",
215
+ "\n",
216
+ " # 3. INPUT MIXING\n",
217
+ " self.mix_real = nn.Linear(d_model, d_model)\n",
218
+ " self.mix_imag = nn.Linear(d_model, d_model)\n",
219
+ "\n",
220
+ " # 4. OUTPUT PROJECTION\n",
221
+ " self.out_real = nn.Linear(d_model, d_model)\n",
222
+ " self.out_imag = nn.Linear(d_model, d_model)\n",
223
+ "\n",
224
+ " self.activation = ModReLU(d_model)\n",
225
+ " self.norm = PhasePreservingLayerNorm(d_model)\n",
226
+ " self.dropout = ComplexDropout(dropout)\n",
227
+ "\n",
228
+ " def complex_linear(self, x, l_real, l_imag):\n",
229
+ " r, i = x.real, x.imag\n",
230
+ " new_r = l_real(r) - l_imag(i)\n",
231
+ " new_i = l_real(i) + l_imag(r)\n",
232
+ " return torch.complex(new_r, new_i)\n",
233
+ "\n",
234
+ " def forward(self, x, src_mask=None):\n",
235
+ " if x is None: return None\n",
236
+ " residual = x\n",
237
+ " x_norm = self.norm(x)\n",
238
+ "\n",
239
+ " if src_mask is not None:\n",
240
+ " x_norm = x_norm.masked_fill(src_mask.unsqueeze(-1), 0.0)\n",
241
+ "\n",
242
+ " # A. GATE\n",
243
+ " x_cat = torch.cat([x_norm.real, x_norm.imag], dim=-1)\n",
244
+ " gates = torch.sigmoid(self.gate_proj(x_cat))\n",
245
+ " gate_r, gate_i = gates.chunk(2, dim=-1)\n",
246
+ "\n",
247
+ " # B. FILTER\n",
248
+ " B, L, D = x_norm.shape\n",
249
+ " x_freq = torch.fft.fft(x_norm, n=self.filter_len, dim=1)\n",
250
+ " x_filtered = x_freq * self.global_filter.transpose(-1, -2)\n",
251
+ " x_time = torch.fft.ifft(x_filtered, n=self.filter_len, dim=1)\n",
252
+ " x_time = x_time[:, :L, :]\n",
253
+ "\n",
254
+ " # C. APPLY GATE\n",
255
+ " gated_r = x_time.real * gate_r\n",
256
+ " gated_i = x_time.imag * gate_i\n",
257
+ " x_gated = torch.complex(gated_r, gated_i)\n",
258
+ "\n",
259
+ " # D. OUT\n",
260
+ " x_mixed = self.complex_linear(x_gated, self.mix_real, self.mix_imag)\n",
261
+ " x_act = self.activation(x_mixed)\n",
262
+ " out = self.complex_linear(x_act, self.out_real, self.out_imag)\n",
263
+ " return self.dropout(out) + residual\n",
264
+ "\n",
265
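By the convolution theorem, step B above (FFT, per-frequency multiply, iFFT) is a circular convolution of each channel with a learned length-max_len kernel whose Fourier coefficients are stored directly in global_filter. A toy identity check on plain tensors (a sketch, not the layer itself):

L = 8
x = torch.randn(L, dtype=torch.cfloat)
h = torch.randn(L, dtype=torch.cfloat)        # stand-in for one channel's kernel

# Pointwise product in frequency space...
y_fft = torch.fft.ifft(torch.fft.fft(x) * torch.fft.fft(h))
# ...equals the explicit circular convolution sum_k h[k] * x[(n - k) mod L].
y_ref = torch.stack([sum(h[k] * x[(n - k) % L] for k in range(L)) for n in range(L)])
print(torch.allclose(y_fft, y_ref, atol=1e-5))  # True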
+ "# --- ENCODER MUST BE DEFINED AFTER LAYER ---\n",
266
+ "class PRISMEncoder(nn.Module):\n",
267
+ " def __init__(self, num_layers, d_model, max_len, dropout=0.1):\n",
268
+ " super().__init__()\n",
269
+ " self.layers = nn.ModuleList([PRISMLayer(d_model, max_len, dropout) for _ in range(num_layers)])\n",
270
+ " self.final_norm = PhasePreservingLayerNorm(d_model)\n",
271
+ "\n",
272
+ " def forward(self, x, src_mask=None):\n",
273
+ " for layer in self.layers:\n",
274
+ " x = layer(x, src_mask)\n",
275
+ " return self.final_norm(x)\n",
276
+ "\n",
277
+ "# --- THE CORRECT BRIDGE (Cartesian) ---\n",
278
+ "class ComplexToRealBridge(nn.Module):\n",
279
+ " def __init__(self, d_model):\n",
280
+ " super().__init__()\n",
281
+ " self.proj = nn.Linear(d_model * 2, d_model)\n",
282
+ " self.norm = nn.LayerNorm(d_model)\n",
283
+ "\n",
284
+ " def forward(self, x_complex):\n",
285
+ " if x_complex is None: raise ValueError(\"Bridge None\")\n",
286
+ " cat = torch.cat([x_complex.real, x_complex.imag], dim=-1)\n",
287
+ " return self.norm(self.proj(cat))\n",
288
+ "\n",
289
+ "class PRISMHybrid_RoPE(nn.Module):\n",
290
+ " def __init__(self, num_encoder_layers, num_refining_layers, num_decoder_layers,\n",
291
+ " num_heads, d_model, dff, vocab_size, max_length, dropout):\n",
292
+ " super().__init__()\n",
293
+ " self.d_model = d_model\n",
294
+ " self.harmonic_embedding = HarmonicEmbedding(vocab_size, d_model)\n",
295
+ " self.tgt_embedding = nn.Embedding(vocab_size, d_model)\n",
296
+ " self.dropout = nn.Dropout(dropout)\n",
297
+ "\n",
298
+ " if num_encoder_layers > 0:\n",
299
+ " self.prism_encoder = PRISMEncoder(num_encoder_layers, d_model, max_length, dropout)\n",
300
+ " else:\n",
301
+ " self.prism_encoder = None\n",
302
+ "\n",
303
+ " self.bridge = ComplexToRealBridge(d_model)\n",
304
+ "\n",
305
+ " if num_refining_layers > 0:\n",
306
+ " refining_layer = nn.TransformerEncoderLayer(\n",
307
+ " d_model, num_heads, dff, dropout,\n",
308
+ " batch_first=True, norm_first=True\n",
309
+ " )\n",
310
+ " self.reasoning_encoder = nn.TransformerEncoder(refining_layer, num_layers=num_refining_layers)\n",
311
+ " else:\n",
312
+ " self.reasoning_encoder = None\n",
313
+ "\n",
314
+ " self.decoder = Decoder(\n",
315
+ " dim = d_model, depth = num_decoder_layers, heads = num_heads, attn_dim_head = d_model // num_heads,\n",
316
+ " ff_mult = dff / d_model, rotary_pos_emb = True, cross_attend = True, attn_flash = True,\n",
317
+ " attn_dropout = dropout, ff_dropout = dropout, use_rmsnorm = True\n",
318
+ " )\n",
319
+ " self.final_linear = nn.Linear(d_model, vocab_size)\n",
320
+ " self.final_linear.weight = self.tgt_embedding.weight\n",
321
+ "\n",
322
+ " def create_masks(self, src, tgt):\n",
323
+ " src_padding_mask = (src == tokenizer.pad_token_id)\n",
324
+ " tgt_padding_mask = (tgt == tokenizer.pad_token_id)\n",
325
+ " tgt_mask = nn.Transformer.generate_square_subsequent_mask(sz=tgt.size(1), device=src.device, dtype=torch.bool)\n",
326
+ " return src_padding_mask, tgt_padding_mask, src_padding_mask, tgt_mask\n",
327
+ "\n",
328
+ " def forward(self, src, tgt, src_mask, tgt_pad, mem_pad, tgt_mask):\n",
329
+ " src_harmonic = self.harmonic_embedding(src)\n",
330
+ " if src_mask is not None:\n",
331
+ " src_harmonic = src_harmonic.masked_fill(src_mask.unsqueeze(-1), 0.0)\n",
332
+ "\n",
333
+ " if self.prism_encoder is not None:\n",
334
+ " if self.training:\n",
335
+ " src_harmonic.requires_grad_(True)\n",
336
+ " encoded_complex = torch.utils.checkpoint.checkpoint(\n",
337
+ " self.prism_encoder.forward, # Safest\n",
338
+ " src_harmonic, src_mask, use_reentrant=False\n",
339
+ " )\n",
340
+ " else:\n",
341
+ " encoded_complex = self.prism_encoder(src_harmonic, src_mask)\n",
342
+ " else:\n",
343
+ " encoded_complex = src_harmonic\n",
344
+ "\n",
345
+ " coarse_memory = self.bridge(encoded_complex)\n",
346
+ " if self.reasoning_encoder is not None:\n",
347
+ " refined_memory = self.reasoning_encoder(coarse_memory, src_key_padding_mask=mem_pad)\n",
348
+ " else:\n",
349
+ " refined_memory = coarse_memory\n",
350
+ "\n",
351
+ " tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.d_model)\n",
352
+ " tgt_emb = self.dropout(tgt_emb)\n",
353
+ " context_mask = ~mem_pad if mem_pad is not None else None\n",
354
+ " decoder_mask = ~tgt_pad if tgt_pad is not None else None\n",
355
+ "\n",
356
+ " if self.training:\n",
357
+ " tgt_emb.requires_grad_(True)\n",
358
+ " output = torch.utils.checkpoint.checkpoint(\n",
359
+ " self.decoder, tgt_emb, context=refined_memory, mask=decoder_mask, context_mask=context_mask, use_reentrant=False\n",
360
+ " )\n",
361
+ " else:\n",
362
+ " output = self.decoder(tgt_emb, context=refined_memory, mask=decoder_mask, context_mask=context_mask)\n",
363
+ "\n",
364
+ " return self.final_linear(output)\n",
365
+ "\n",
366
+ " # ... (generate function remains the same) ...\n",
367
+ " @torch.no_grad()\n",
368
+ " def generate(self, src, max_length, num_beams=5):\n",
369
+ " self.eval()\n",
370
+ " src_mask = (src == tokenizer.pad_token_id)\n",
371
+ " context_mask = ~src_mask\n",
372
+ " src_harmonic = self.harmonic_embedding(src)\n",
373
+ " if src_mask is not None:\n",
374
+ " src_harmonic = src_harmonic.masked_fill(src_mask.unsqueeze(-1), 0.0)\n",
375
+ "\n",
376
+ " if self.prism_encoder is not None:\n",
377
+ " encoded_complex = self.prism_encoder(src_harmonic, src_mask)\n",
378
+ " else:\n",
379
+ " encoded_complex = src_harmonic\n",
380
+ "\n",
381
+ " coarse_memory = self.bridge(encoded_complex)\n",
382
+ "\n",
383
+ " if self.reasoning_encoder is not None:\n",
384
+ " memory = self.reasoning_encoder(coarse_memory, src_key_padding_mask=src_mask)\n",
385
+ " else:\n",
386
+ " memory = coarse_memory\n",
387
+ "\n",
388
+ " batch_size = src.shape[0]\n",
389
+ " memory = memory.repeat_interleave(num_beams, dim=0)\n",
390
+ " context_mask = context_mask.repeat_interleave(num_beams, dim=0)\n",
391
+ "\n",
392
+ " beams = torch.full((batch_size * num_beams, 1), tokenizer.pad_token_id, dtype=torch.long, device=src.device)\n",
393
+ " beam_scores = torch.zeros(batch_size * num_beams, device=src.device)\n",
394
+ " finished_beams = torch.zeros(batch_size * num_beams, dtype=torch.bool, device=src.device)\n",
395
+ "\n",
396
+ " for _ in range(max_length - 1):\n",
397
+ " if finished_beams.all(): break\n",
398
+ " tgt_emb = self.tgt_embedding(beams) * math.sqrt(self.d_model)\n",
399
+ " tgt_emb = self.dropout(tgt_emb)\n",
400
+ "\n",
401
+ " # Decoder\n",
402
+ " decoder_output = self.decoder(tgt_emb, context=memory, context_mask=context_mask)\n",
403
+ " logits = self.final_linear(decoder_output[:, -1, :])\n",
404
+ " log_probs = F.log_softmax(logits, dim=-1)\n",
405
+ "\n",
406
+ " # Masking\n",
407
+ " log_probs[:, tokenizer.pad_token_id] = -torch.inf\n",
408
+ " if finished_beams.any(): log_probs[finished_beams, tokenizer.eos_token_id] = 0\n",
409
+ "\n",
410
+ " # --- BEAM SEARCH LOGIC FIX ---\n",
411
+ " if _ == 0:\n",
412
+ " # First Step: Expand from the first beam only (since all are identical start tokens)\n",
413
+ " # Reshape to (batch, beams, vocab)\n",
414
+ " total = (beam_scores.unsqueeze(1) + log_probs).view(batch_size, num_beams, -1)\n",
415
+ " # Mask out all beams except the first one (-inf)\n",
416
+ " total[:, 1:, :] = -torch.inf\n",
417
+ " # Flatten back to (batch, beams*vocab) to pick top k\n",
418
+ " total = total.view(batch_size, -1)\n",
419
+ " else:\n",
420
+ " # Subsequent Steps: Standard Flatten\n",
421
+ " total = (beam_scores.unsqueeze(1) + log_probs).view(batch_size, -1)\n",
422
+ "\n",
423
+ " top_scores, top_indices = torch.topk(total, k=num_beams, dim=1)\n",
424
+ "\n",
425
+ " beam_indices = top_indices // log_probs.shape[-1]\n",
426
+ " token_indices = top_indices % log_probs.shape[-1]\n",
427
+ "\n",
428
+ " # Now dimensions match: (batch_size, 1) + (batch_size, k)\n",
429
+ " effective = (torch.arange(batch_size, device=src.device).unsqueeze(1) * num_beams + beam_indices).view(-1)\n",
430
+ " beams = torch.cat([beams[effective], token_indices.view(-1, 1)], dim=1)\n",
431
+ " beam_scores = top_scores.view(-1)\n",
432
+ " finished_beams = finished_beams | (beams[:, -1] == tokenizer.eos_token_id)\n",
433
+ "\n",
434
+ " final_beams = beams.view(batch_size, num_beams, -1)\n",
435
+ " best_beams = final_beams[:, 0, :]\n",
436
+ " self.train()\n",
437
+ " return best_beams"
438
+ ]
439
+ },
440
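A minimal usage sketch for the beam-search decoder (assumes the model, tokenizer, and device from the cells above; the German input is just an illustration):

model.eval()
batch = tokenizer(["Das ist ein Test."], return_tensors="pt", padding=True).to(device)
out_ids = model.generate(batch["input_ids"], max_length=MAX_LENGTH, num_beams=5)
print(tokenizer.batch_decode(out_ids, skip_special_tokens=True))
# Note: generate() flips the model back to train mode when it returns.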
+ {
441
+ "cell_type": "code",
442
+ "execution_count": null,
443
+ "metadata": {
444
+ "id": "NFiIvRiyDg8K"
445
+ },
446
+ "outputs": [],
447
+ "source": [
448
+ "from torchmetrics.text import SacreBLEUScore\n",
449
+ "\n",
450
+ "def evaluate(model, dataloader, device):\n",
451
+ " # Use SacreBLEUScore (defaults to '13a' tokenizer, the WMT standard)\n",
452
+ " metric = SacreBLEUScore().to(device)\n",
453
+ "\n",
454
+ " model.eval()\n",
455
+ "\n",
456
+ " # Use no_grad to save memory and speed up validation\n",
457
+ " with torch.no_grad():\n",
458
+ " for batch in tqdm(dataloader, desc=\"Evaluating\", leave=False):\n",
459
+ " input_ids = batch['input_ids'].to(device)\n",
460
+ " labels = batch['labels']\n",
461
+ "\n",
462
+ " # Generate predictions\n",
463
+ " generated_ids = model.generate(input_ids, max_length=MAX_LENGTH, num_beams=5)\n",
464
+ "\n",
465
+ " # Decode predictions\n",
466
+ " pred_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)\n",
467
+ "\n",
468
+ " # Decode labels (Fixing -100 padding)\n",
469
+ " labels[labels == -100] = tokenizer.pad_token_id\n",
470
+ " ref_texts = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
471
+ "\n",
472
+ " # Update Metric\n",
473
+ " # SacreBLEU expects references as a list of lists: [[ref1], [ref2], ...]\n",
474
+ " formatted_refs = [[ref] for ref in ref_texts]\n",
475
+ " metric.update(pred_texts, formatted_refs)\n",
476
+ "\n",
477
+ " model.train()\n",
478
+ "\n",
479
+ " # Compute returns a tensor, .item() converts it to a standard python float\n",
480
+ " return metric.compute().item()\n"
481
+ ]
482
+ },
483
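A self-contained sanity check of the metric plumbing (made-up strings; note that torchmetrics reports BLEU on a 0-1 scale, unlike the 0-100 scale of the sacrebleu CLI):

from torchmetrics.text import SacreBLEUScore

metric = SacreBLEUScore()
preds = ["the cat sat on the mat"]
refs = [["the cat sat on the mat"]]   # one reference list per prediction, as in evaluate()
metric.update(preds, refs)
print(metric.compute())               # tensor(1.) for an exact match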
+ {
484
+ "cell_type": "code",
485
+ "execution_count": null,
486
+ "metadata": {
487
+ "id": "8M_uQNvKolVF"
488
+ },
489
+ "outputs": [],
490
+ "source": [
491
+ "\n",
492
+ "# ==============================================================================\n",
493
+ "# 5. TRAINING LOOP\n",
494
+ "# ==============================================================================\n",
495
+ "if __name__ == '__main__':\n",
496
+ " experiment_name = f\"PRISM_Hybrid_RoPE_{datetime.datetime.now().strftime('%Y%m%d_%H%M')}\"\n",
497
+ " config_state = {\"model\": MODEL_CHOICE, \"d_model\": D_MODEL, \"layers\": NUM_ENCODER_LAYERS,\n",
498
+ " \"lr\": PEAK_LEARNING_RATE, \"seed\": 116}\n",
499
+ "\n",
500
+ " logger.info(\"Initializing PRISM...\")\n",
501
+ "\n",
502
+ " model = PRISMHybrid_RoPE(\n",
503
+ " num_encoder_layers=NUM_ENCODER_LAYERS,\n",
504
+ " num_refining_layers=NUM_REFINING_LAYERS,\n",
505
+ " num_decoder_layers=NUM_DECODER_LAYERS,\n",
506
+ " num_heads=NUM_HEADS,\n",
507
+ " d_model=D_MODEL,\n",
508
+ " dff=D_FF,\n",
509
+ " vocab_size=VOCAB_SIZE,\n",
510
+ " max_length=MAX_LENGTH,\n",
511
+ " dropout=DROPOUT\n",
512
+ " )\n",
513
+ "\n",
514
+ " model.to(device)\n",
515
+ " print(model)\n",
516
+ " # FIX: Robust Initialization that respects the Gate Bias\n",
517
+ "\n",
518
+ " def init_weights_PRISM(m):\n",
519
+ " # 1. Linear Layers (Gate, Mix, Output)\n",
520
+ " if isinstance(m, nn.Linear):\n",
521
+ " nn.init.xavier_uniform_(m.weight)\n",
522
+ " if m.bias is not None:\n",
523
+ " nn.init.zeros_(m.bias)\n",
524
+ "\n",
525
+ " # 2. Embeddings\n",
526
+ " # Keep this! It helps the complex signal stay strong against noise.\n",
527
+ " elif isinstance(m, nn.Embedding):\n",
528
+ " std = 1.0 / math.sqrt(D_MODEL)\n",
529
+ " nn.init.normal_(m.weight, mean=0.0, std=std)\n",
530
+ " if m.padding_idx is not None:\n",
531
+ " nn.init.constant_(m.weight[m.padding_idx], 0.0)\n",
532
+ "\n",
533
+ " # 3. Global Filter\n",
534
+ " elif hasattr(m, 'global_filter'):\n",
535
+ " nn.init.normal_(m.global_filter, mean=0.0, std=0.02)\n",
536
+ "\n",
537
+ "\n",
538
+ " # A. Apply the generic initialization\n",
539
+ " model.apply(init_weights_PRISM)\n",
540
+ "\n",
541
+ "\n",
542
+ " logger.info(\"✅ Initialization Complete.\")\n",
543
+ "\n",
544
+ "\n",
545
+ " optimizer = torch.optim.AdamW(model.parameters(), lr=PEAK_LEARNING_RATE, weight_decay=WEIGHT_DECAY)\n",
546
+ " scheduler = get_cosine_schedule_with_warmup(optimizer, WARMUP_STEPS, TARGET_TRAINING_STEPS)\n",
547
+ " loss_fn = nn.CrossEntropyLoss(ignore_index=-100, label_smoothing=LABEL_SMOOTHING_EPSILON)\n",
548
+ "\n",
549
+ " logger.info(f\"STARTING MARATHON ({TARGET_TRAINING_STEPS} steps)\")\n",
550
+ " model.train()\n",
551
+ " global_step = 0\n",
552
+ " best_bleu = 0.0\n",
553
+ " progress = tqdm(total=TARGET_TRAINING_STEPS)\n",
554
+ "\n",
555
+ " while global_step < TARGET_TRAINING_STEPS:\n",
556
+ " for batch in train_dataloader:\n",
557
+ " if global_step >= TARGET_TRAINING_STEPS: break\n",
558
+ " optimizer.zero_grad()\n",
559
+ " input_ids = batch['input_ids'].to(device, non_blocking=True)\n",
560
+ " labels = batch['labels'].to(device, non_blocking=True)\n",
561
+ "\n",
562
+ " dec_in = torch.cat([torch.full((labels.size(0), 1), tokenizer.pad_token_id, device=device), labels[:, :-1]], dim=1)\n",
563
+ " dec_in[dec_in == -100] = tokenizer.pad_token_id\n",
564
+ "\n",
565
+ " src_mask, tgt_pad, mem_pad, tgt_mask = model.create_masks(input_ids, dec_in)\n",
566
+ " tgt_pad[:, 0] = False\n",
567
+ "\n",
568
+ " out = model(input_ids, dec_in, src_mask, tgt_pad, mem_pad, tgt_mask)\n",
569
+ " loss = loss_fn(out.view(-1, VOCAB_SIZE), labels.view(-1))\n",
570
+ "\n",
571
+ " loss.backward()\n",
572
+ "\n",
573
+ " # --- MODIFICATION START ---\n",
574
+ " # clip_grad_norm_ returns the norm calculated BEFORE clipping\n",
575
+ " grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
576
+ " # --- MODIFICATION END ---\n",
577
+ "\n",
578
+ " optimizer.step()\n",
579
+ " scheduler.step()\n",
580
+ " global_step += 1\n",
581
+ " progress.update(1)\n",
582
+ "\n",
583
+ " if global_step % 50 == 0:\n",
584
+ " writer.add_scalar('train/loss', loss.item(), global_step)\n",
585
+ " writer.add_scalar('train/grad_norm', grad_norm.item(), global_step) # Log to TensorBoard\n",
586
+ "\n",
587
+ " # Add 'gnorm' to the progress bar (formatted to 2 decimal places)\n",
588
+ " progress.set_postfix(loss=loss.item(), gnorm=f\"{grad_norm.item():.2f}\")\n",
589
+ "\n",
590
+ " if global_step in VALIDATION_SCHEDULE:\n",
591
+ " logger.info(f\"Validating at step {global_step}...\")\n",
592
+ " current_bleu = evaluate(model, val_dataloader, device)\n",
593
+ " writer.add_scalar('val/bleu', current_bleu, global_step)\n",
594
+ " logger.info(f\"Step {global_step} | BLEU: {current_bleu:.4f}\")\n",
595
+ " if current_bleu > best_bleu:\n",
596
+ " best_bleu = current_bleu\n",
597
+ " torch.save(model.state_dict(), os.path.join(SAVE_DIR, \"best_model.pt\"))\n",
598
+ "\n",
599
+ " torch.save(model.state_dict(), os.path.join(SAVE_DIR, \"marathon_model.pt\"))\n",
600
+ " logger.info(f\"Marathon Complete. Best BLEU: {best_bleu:.4f}\")"
601
+ ]
602
+ },
603
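For reference, a toy illustration of the decoder-input shift used in the loop above (the pad id 58100 is a stand-in value; Marian-style models reuse the pad token as the decoder start token, and -100 label padding is mapped back to pad):

PAD = 58100                                   # stand-in pad/BOS id
labels = torch.tensor([[11, 12, 13, 0, -100, -100]])

start = torch.full((labels.size(0), 1), PAD)
dec_in = torch.cat([start, labels[:, :-1]], dim=1)
dec_in[dec_in == -100] = PAD
print(dec_in)  # tensor([[58100, 11, 12, 13, 0, 58100]])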
+ {
604
+ "cell_type": "code",
605
+ "execution_count": null,
606
+ "metadata": {
607
+ "id": "0xsWfDkeWp5-"
608
+ },
609
+ "outputs": [],
610
+ "source": [
611
+ "import os\n",
612
+ "import sys\n",
613
+ "import torch\n",
614
+ "import transformers\n",
615
+ "import datasets\n",
616
+ "import torchmetrics\n",
617
+ "import numpy\n",
618
+ "import pkg_resources\n",
619
+ "\n",
620
+ "def log_environment_separate(log_dir):\n",
621
+ " # Define the separate file path\n",
622
+ " meta_file = os.path.join(log_dir, \"system_metadata.txt\")\n",
623
+ "\n",
624
+ " with open(meta_file, \"w\") as f:\n",
625
+ " # --- PART 1: SUMMARY ---\n",
626
+ " f.write(\"=\"*40 + \"\\n\")\n",
627
+ " f.write(\"CORE ENVIRONMENT SUMMARY\\n\")\n",
628
+ " f.write(\"=\"*40 + \"\\n\")\n",
629
+ " f.write(f\"Python: {sys.version.split()[0]}\\n\")\n",
630
+ " f.write(f\"PyTorch: {torch.__version__}\\n\")\n",
631
+ " f.write(f\"Transformers: {transformers.__version__}\\n\")\n",
632
+ " f.write(f\"Datasets: {datasets.__version__}\\n\")\n",
633
+ " f.write(f\"TorchMetrics: {torchmetrics.__version__}\\n\")\n",
634
+ " f.write(f\"NumPy: {numpy.__version__}\\n\")\n",
635
+ "\n",
636
+ " try:\n",
637
+ " import sacrebleu\n",
638
+ " f.write(f\"SacreBLEU: {sacrebleu.__version__}\\n\")\n",
639
+ " except ImportError:\n",
640
+ " f.write(\"SacreBLEU: Not Installed\\n\")\n",
641
+ "\n",
642
+ " if torch.cuda.is_available():\n",
643
+ " f.write(f\"GPU Name: {torch.cuda.get_device_name(0)}\\n\")\n",
644
+ " f.write(f\"CUDA Ver: {torch.version.cuda}\\n\")\n",
645
+ " f.write(f\"Capability: {torch.cuda.get_device_capability(0)}\\n\")\n",
646
+ " else:\n",
647
+ " f.write(\"GPU: None (CPU Only)\\n\")\n",
648
+ "\n",
649
+ " # --- PART 2: FULL FREEZE ---\n",
650
+ " f.write(\"\\n\" + \"=\"*40 + \"\\n\")\n",
651
+ " f.write(\"FULL LIBRARY DEPENDENCIES (PIP FREEZE)\\n\")\n",
652
+ " f.write(\"=\"*40 + \"\\n\")\n",
653
+ "\n",
654
+ " installed_packages = {d.project_name: d.version for d in pkg_resources.working_set}\n",
655
+ " for package, version in sorted(installed_packages.items()):\n",
656
+ " f.write(f\"{package}=={version}\\n\")\n",
657
+ "\n",
658
+ " print(f\"✅ Environment details saved SEPARATELY to: {meta_file}\")\n",
659
+ "\n",
660
+ "# Execute\n",
661
+ "# Assumes CURRENT_RUN_DIR is defined from your config\n",
662
+ "log_environment_separate(CURRENT_RUN_DIR)"
663
+ ]
664
+ },
665
+ {
666
+ "cell_type": "code",
667
+ "source": [
668
+ "from google.colab import runtime\n",
669
+ "runtime.unassign()"
670
+ ],
671
+ "metadata": {
672
+ "id": "w7bFIVCLfCdT"
673
+ },
674
+ "execution_count": null,
675
+ "outputs": []
676
+ }
677
+ ],
678
+ "metadata": {
679
+ "accelerator": "GPU",
680
+ "colab": {
681
+ "gpuType": "A100",
682
+ "provenance": []
683
+ },
684
+ "kernelspec": {
685
+ "display_name": "Python 3",
686
+ "name": "python3"
687
+ },
688
+ "language_info": {
689
+ "name": "python"
690
+ }
691
+ },
692
+ "nbformat": 4,
693
+ "nbformat_minor": 0
694
+ }