{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true, "id": "2s48Vmoo9EB5" }, "outputs": [], "source": [ "!pip install -q torchmetrics sacrebleu x-transformers" ] }, { "cell_type": "markdown", "metadata": { "id": "Lz8buKsjvA_w" }, "source": [ "## CONFIG" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "df355sdDrNSb" }, "outputs": [], "source": [ "!pip install -q torchmetrics sacrebleu x-transformers\n", "\n", "## CONFIG\n", "\n", "# --- Data & Task Size ---\n", "MAX_LENGTH = 128\n", "\n", "MODEL_CHOICE = \"Name_Your_Model\" # Renamed for clarity\n", "\n", "# --- Model Architecture Config ---\n", "D_MODEL = 512\n", "NUM_HEADS = 8\n", "D_FF = 2048\n", "DROPOUT = 0.1\n", "\n", "# --- Layer counts ---\n", "NUM_ENCODER_LAYERS = 7\n", "NUM_DECODER_LAYERS = 6\n", "\n", "# --- Training Config (ADJUSTED FOR FAIR COMPARISON) ---\n", "\n", "TARGET_TRAINING_STEPS = 100000\n", "GRAD_ACCUMULATION_STEPS = 2\n", "\n", "\n", "VALIDATION_SCHEDULE = [\n", " 2000, 4000, 5000, 7500, 10000, 15000, 20000,\n", " 25000, 30000, 35000, 42500, 50000, 57500, 65000, 72500, 90000, 100000\n", "]\n", "PEAK_LEARNING_RATE = 6e-4\n", "WARMUP_STEPS = 600 # Warmup can stay similar or scale slightly, 600 is fine\n", "WEIGHT_DECAY = 0.01\n", "\n", "# --- Regularization Config ---\n", "LABEL_SMOOTHING_EPSILON = 0.1\n", "\n", "# --- Other Constants ---\n", "DRIVE_BASE_PATH = \"/content/drive/MyDrive/AIAYN\"\n", "ORIGINAL_BUCKETED_REPO_ID = \"prism-lab/wmt14-de-en-bucketed-w4\" # Use the bucketed one (we will ignore buckets)\n", "MODEL_CHECKPOINT = \"Helsinki-NLP/opus-mt-de-en\"" ] }, { "cell_type": "markdown", "metadata": { "id": "W5l1HHRFXxPA" }, "source": [ "## DATALOADERS" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true, "id": "FA5SqFzeMrpK" }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.utils.data import DataLoader\n", "from transformers import AutoTokenizer\n", "from datasets import load_dataset\n", "import math\n", "import os\n", "from tqdm.auto import tqdm\n", "from torch.utils.tensorboard import SummaryWriter\n", "import random\n", "import numpy as np\n", "import torch\n", "from transformers import get_cosine_schedule_with_warmup\n", "from typing import List\n", "from transformers import AutoModel\n", "from transformers import DataCollatorForSeq2Seq\n", "\n", "\n", "def set_seed(seed_value=5):\n", " \"\"\"Sets the seed for reproducibility.\"\"\"\n", " random.seed(seed_value)\n", " np.random.seed(seed_value)\n", " torch.manual_seed(seed_value)\n", " torch.cuda.manual_seed_all(seed_value)\n", " torch.backends.cudnn.deterministic = True\n", " torch.backends.cudnn.benchmark = False\n", "\n", "SEED = 117\n", "set_seed(SEED)\n", "print(f\"Reproducibility seed set to {SEED}\")\n", "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n", "\n", "#torch.use_deterministic_algorithms(True)\n", "\n", "print(\"--- Loading Modernized Configuration ---\")\n", "def seed_worker(worker_id):\n", " worker_seed = torch.initial_seed() % 2**32\n", " np.random.seed(worker_seed)\n", " random.seed(worker_seed)\n", "\n", "torch.set_float32_matmul_precision('high')\n", "print(\"āœ… PyTorch matmul precision set to 'high'\")\n", "\n", "# --- Device Setup ---\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(f\"Using device: {device}\")\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)\n", "\n", "VOCAB_SIZE = len(tokenizer)\n", "print(f\"Vocab size: 
{VOCAB_SIZE}\")\n", "\n", "\n", "# DATA LOADING & PREPARATION\n", "\n", "# --- 1. DEFINE THE FNET COLLATOR (FORCE FIXED LENGTH) ---\n", "# This is crucial. It forces every sentence to be exactly 128 tokens.\n", "fnet_collator = DataCollatorForSeq2Seq(\n", " tokenizer=tokenizer,\n", " padding=\"max_length\", # <--- FORCE PADDING\n", " max_length=MAX_LENGTH, # <--- 128 (defined in your config)\n", " pad_to_multiple_of=None\n", ")\n", "\n", "# --- 2. LOAD DATASET ---\n", "print(f\"Loading original bucketed samples from: {ORIGINAL_BUCKETED_REPO_ID}\")\n", "original_datasets = load_dataset(ORIGINAL_BUCKETED_REPO_ID)\n", "\n", "# --- 3. CREATE DATALOADERS (STANDARD FIXED SIZE) ---\n", "FNET_PHYSICAL_BATCH_SIZE = 320\n", "\n", "g = torch.Generator()\n", "g.manual_seed(SEED)\n", "\n", "train_dataloader = DataLoader(\n", " original_datasets[\"train\"],\n", " batch_size=FNET_PHYSICAL_BATCH_SIZE, # <--- FIXED BATCH SIZE (Safe from OOM)\n", " shuffle=True, # <--- GLOBAL SHUFFLE\n", " num_workers=8,\n", " collate_fn=fnet_collator,\n", " pin_memory=True,\n", " worker_init_fn=seed_worker,\n", " generator=g,\n", ")\n", "\n", "val_dataloader = DataLoader(\n", " original_datasets[\"validation\"],\n", " batch_size=FNET_PHYSICAL_BATCH_SIZE,\n", " collate_fn=fnet_collator,\n", " num_workers=8,\n", " pin_memory=True,\n", " worker_init_fn=seed_worker,\n", " generator=g,\n", ")\n", "\n", "print(f\"Train Dataloader is now a STANDARD iterator.\")\n", "print(f\"Physical Batch Size: {FNET_PHYSICAL_BATCH_SIZE}\")\n", "print(f\"Gradient Accumulation: {GRAD_ACCUMULATION_STEPS}\")\n", "print(f\"Effective Batch Size: {FNET_PHYSICAL_BATCH_SIZE * GRAD_ACCUMULATION_STEPS}\")\n", "\n", "# --- SANITY CHECK ---\n", "print(\"\\n--- Running Sanity Check on new FNet DataLoader ---\")\n", "train_dataloader.generator.manual_seed(SEED)\n", "temp_iterator = iter(train_dataloader)\n", "print(\"Shapes of first 3 batches (Should all be [64, 128]):\")\n", "for i in range(3):\n", " batch = next(temp_iterator)\n", " print(f\" Batch {i+1}: input_ids shape = {batch['input_ids'].shape}\")\n", "print(\"--- Sanity Check Complete ---\\n\")\n", "# --- VERIFY SHUFFLE IS WORKING ---\n", "print(\"šŸ•µļø INSPECTING ONE BATCH šŸ•µļø\")\n", "\n", "# Get one batch from your active train_dataloader\n", "batch = next(iter(train_dataloader))\n", "input_ids = batch['input_ids']\n", "\n", "# Calculate real lengths (ignoring padding)\n", "# We count how many tokens are NOT the pad token (usually 0 or 58100)\n", "real_lengths = (input_ids != tokenizer.pad_token_id).sum(dim=1)\n", "\n", "print(f\"Batch Shape: {input_ids.shape}\")\n", "print(\"Random Sample of 20 lengths in this batch:\")\n", "print(real_lengths[:20].tolist())\n", "\n", "# Check diversity\n", "if real_lengths.float().std() < 5:\n", " print(\"\\nāš ļø WARNING: LENGTHS LOOK CLUSTERED! (Bad shuffling)\")\n", "else:\n", " print(f\"\\nāœ… PASSED: Lengths are highly variable (Std Dev: {real_lengths.float().std():.2f}). 
Shuffling is working.\")" ] }, { "cell_type": "markdown", "metadata": { "id": "cS4JvJGRhClv" }, "source": [ "## Models" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SMhlM0YvO1A7" }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import math\n", "from x_transformers import Encoder, Decoder\n", "\n", "class RoPETransformer(nn.Module):\n", " def __init__(self, num_encoder_layers, num_decoder_layers, num_heads, d_model, dff, vocab_size, max_length, dropout):\n", " super().__init__()\n", " self.d_model = d_model\n", " self.embedding = nn.Embedding(vocab_size, d_model)\n", "\n", " # We REMOVE self.pos_encoder (RoPE handles position internally)\n", " self.dropout_layer = nn.Dropout(dropout)\n", "\n", " # --- x-transformers Encoder ---\n", " self.encoder = Encoder(\n", " dim = d_model,\n", " depth = num_encoder_layers,\n", " heads = num_heads,\n", " attn_dim_head = d_model // num_heads,\n", " ff_mult = dff / d_model,\n", " rotary_pos_emb = True,\n", " attn_flash = True,\n", " attn_dropout = dropout,\n", " ff_dropout = dropout,\n", " use_rmsnorm = True\n", " )\n", "\n", " # --- x-transformers Decoder ---\n", " self.decoder = Decoder(\n", " dim = d_model,\n", " depth = num_decoder_layers,\n", " heads = num_heads,\n", " attn_dim_head = d_model // num_heads,\n", " ff_mult = dff / d_model,\n", " rotary_pos_emb = True,\n", " cross_attend = True,\n", " attn_flash = True,\n", " attn_dropout = dropout,\n", " ff_dropout = dropout,\n", " use_rmsnorm = True\n", " )\n", "\n", " self.final_linear = nn.Linear(d_model, vocab_size)\n", " self.final_linear.weight = self.embedding.weight\n", "\n", " def forward(self, src, tgt, src_padding_mask, tgt_padding_mask, memory_key_padding_mask, tgt_mask):\n", " # 1. Embeddings (No Absolute Positional Encoding added!)\n", " src_emb = self.embedding(src) * math.sqrt(self.d_model)\n", " src_emb = self.dropout_layer(src_emb)\n", "\n", " tgt_emb = self.embedding(tgt) * math.sqrt(self.d_model)\n", " tgt_emb = self.dropout_layer(tgt_emb)\n", "\n", " # 2. Mask Conversion\n", " # User provides True=PAD. x-transformers wants True=KEEP.\n", " # We invert the boolean mask using ~\n", " enc_mask = ~src_padding_mask if src_padding_mask is not None else None\n", " dec_mask = ~tgt_padding_mask if tgt_padding_mask is not None else None\n", "\n", " # Note: 'tgt_mask' (causal mask) is handled automatically by x-transformers Decoder!\n", " # We do NOT pass the square causal mask manually.\n", "\n", " # 3. Encoder\n", " # x-transformers takes embeddings directly\n", " memory = self.encoder(src_emb, mask=enc_mask)\n", "\n", " # 4. 
Decoder\n",
"        # context = memory (from encoder)\n",
"        # context_mask = mask for memory (encoder mask)\n",
"        decoder_output = self.decoder(\n",
"            tgt_emb,\n",
"            context=memory,\n",
"            mask=dec_mask,\n",
"            context_mask=enc_mask\n",
"        )\n",
"\n",
"        return self.final_linear(decoder_output)\n",
"\n",
"    # create_masks is retained mainly for data-processing compatibility\n",
"    def create_masks(self, src, tgt):\n",
"        src_padding_mask = (src == tokenizer.pad_token_id)\n",
"        tgt_padding_mask = (tgt == tokenizer.pad_token_id)\n",
"        # Still generated for compatibility, though x-transformers handles causality internally\n",
"        tgt_mask = nn.Transformer.generate_square_subsequent_mask(\n",
"            sz=tgt.size(1), device=src.device, dtype=torch.bool\n",
"        )\n",
"        return src_padding_mask, tgt_padding_mask, src_padding_mask, tgt_mask\n",
"\n",
"    @torch.no_grad()\n",
"    def generate(self, src: torch.Tensor, max_length: int, num_beams: int = 5) -> torch.Tensor:\n",
"        self.eval()\n",
"        # Create Mask (True=PAD)\n",
"        src_padding_mask = (src == tokenizer.pad_token_id)\n",
"        # Invert for x-transformers (True=KEEP)\n",
"        enc_mask = ~src_padding_mask\n",
"\n",
"        # Encode\n",
"        src_emb = self.embedding(src) * math.sqrt(self.d_model)\n",
"        # No Pos Encoder\n",
"        memory = self.encoder(self.dropout_layer(src_emb), mask=enc_mask)\n",
"\n",
"        batch_size = src.shape[0]\n",
"        # Expand for beams\n",
"        memory = memory.repeat_interleave(num_beams, dim=0)\n",
"        enc_mask = enc_mask.repeat_interleave(num_beams, dim=0)\n",
"\n",
"        initial_token = tokenizer.pad_token_id\n",
"        beams = torch.full((batch_size * num_beams, 1), initial_token, dtype=torch.long, device=src.device)\n",
"        beam_scores = torch.zeros(batch_size * num_beams, device=src.device)\n",
"        finished_beams = torch.zeros(batch_size * num_beams, dtype=torch.bool, device=src.device)\n",
"\n",
"        for step in range(max_length - 1):\n",
"            if finished_beams.all(): break\n",
"\n",
"            # Embed beams\n",
"            tgt_emb = self.embedding(beams) * math.sqrt(self.d_model)\n",
"            # No Pos Encoder\n",
"\n",
"            # Decode\n",
"            # x-transformers handles causal masking automatically for the length of tgt_emb\n",
"            decoder_output = self.decoder(\n",
"                self.dropout_layer(tgt_emb),\n",
"                context=memory,\n",
"                context_mask=enc_mask\n",
"            )\n",
"\n",
"            logits = self.final_linear(decoder_output[:, -1, :])\n",
"            log_probs = F.log_softmax(logits, dim=-1)\n",
"\n",
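"            # --- Standard beam-search bookkeeping ---\n",
"            # Running beam score = sum of token log-probs. PAD is banned outright;\n",
"            # finished beams repeat EOS at log-prob 0 so their score stays frozen.\n",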
"            log_probs[:, tokenizer.pad_token_id] = -torch.inf\n",
"            if finished_beams.any(): log_probs[finished_beams, tokenizer.eos_token_id] = 0\n",
"\n",
"            total_scores = beam_scores.unsqueeze(1) + log_probs\n",
"            if step == 0:\n",
"                # All beams start identical; keep only beam 0 per batch item so the\n",
"                # top-k below doesn't select num_beams copies of the same hypothesis.\n",
"                total_scores = total_scores.view(batch_size, num_beams, -1)\n",
"                total_scores[:, 1:, :] = -torch.inf\n",
"                total_scores = total_scores.view(batch_size * num_beams, -1)\n",
"\n",
"            total_scores = total_scores.view(batch_size, -1)\n",
"            top_scores, top_indices = torch.topk(total_scores, k=num_beams, dim=1)\n",
"\n",
"            beam_indices = top_indices // log_probs.shape[-1]\n",
"            token_indices = top_indices % log_probs.shape[-1]\n",
"\n",
"            batch_indices = torch.arange(batch_size, device=src.device).unsqueeze(1)\n",
"            effective_indices = (batch_indices * num_beams + beam_indices).view(-1)\n",
"\n",
"            beams = beams[effective_indices]\n",
"            beams = torch.cat([beams, token_indices.view(-1, 1)], dim=1)\n",
"            beam_scores = top_scores.view(-1)\n",
"            finished_beams = finished_beams | (beams[:, -1] == tokenizer.eos_token_id)\n",
"\n",
"        final_beams = beams.view(batch_size, num_beams, -1)\n",
"        final_scores = beam_scores.view(batch_size, num_beams)\n",
"        normalized_scores = final_scores / (final_beams != tokenizer.pad_token_id).sum(-1).float().clamp(min=1)\n",
"        best_beams = final_beams[torch.arange(batch_size), normalized_scores.argmax(1), :]\n",
"        self.train()\n",
"        return best_beams\n",
"\n",
"class RMSNorm(nn.Module):\n",
"    def __init__(self, dim, eps=1e-8):\n",
"        super().__init__()\n",
"        self.eps = eps\n",
"        self.gamma = nn.Parameter(torch.ones(dim))\n",
"\n",
"    def forward(self, x):\n",
"        # 1. Calculate the mean of the squares\n",
"        mean_square = x.pow(2).mean(dim=-1, keepdim=True)\n",
"\n",
"        # 2. Calculate the inverse square root (1 / RMS)\n",
"        # We add eps before the sqrt for stability\n",
"        inv_rms = torch.rsqrt(mean_square + self.eps)\n",
"\n",
"        # 3. Normalize and scale\n",
"        return x * inv_rms * self.gamma\n",
"\n",
"\n",
"class FNetBlock(nn.Module):\n",
"    def __init__(self, d_model, d_ff, dropout):\n",
"        super().__init__()\n",
"        self.norm_mix = nn.LayerNorm(d_model)  # LayerNorm is safer for FNet than RMSNorm\n",
"        self.norm_ff = nn.LayerNorm(d_model)\n",
"\n",
"        self.ff = nn.Sequential(\n",
"            nn.Linear(d_model, d_ff),\n",
"            nn.GELU(),\n",
"            nn.Dropout(dropout),\n",
"            nn.Linear(d_ff, d_model),\n",
"            nn.Dropout(dropout)\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        # 1. Fourier Mixing Branch\n",
"        residual = x\n",
"        x = self.norm_mix(x)\n",
"\n",
"        # --- FFT in float32 (half-precision FFTs are unstable) ---\n",
"        with torch.autocast(device_type='cuda', enabled=False):\n",
"            x = x.float()\n",
"            # norm='ortho' makes the FFT energy-preserving, so the output\n",
"            # magnitude matches the input magnitude (~1).\n",
"            x = torch.fft.fftn(x, dim=(-2, -1), norm='ortho').real\n",
"            x = x.to(dtype=residual.dtype)\n",
"        # ----------------------------------------------------------\n",
"\n",
"        # 'x' and 'residual' now have roughly the same magnitude,\n",
"        # so the skip connection behaves sensibly.\n",
"        x = x + residual\n",
"\n",
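"        # Note: the 2-D DFT above (over the sequence and hidden dims, keeping\n",
"        # only the real part) is the parameter-free token mixing from the FNet\n",
"        # paper (Lee-Thorp et al., 2021); it replaces self-attention in this block.\n",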
"        # 2. Feed Forward Branch\n",
"        residual = x\n",
"        x = self.norm_ff(x)\n",
"        x = self.ff(x)\n",
"        return x + residual\n",
"\n",
"\n",
"class FNetEncoder(nn.Module):\n",
"    def __init__(self, depth, d_model, d_ff, dropout):\n",
"        super().__init__()\n",
"        self.layers = nn.ModuleList([\n",
"            FNetBlock(d_model, d_ff, dropout) for _ in range(depth)\n",
"        ])\n",
"        # Use LayerNorm here to match the blocks\n",
"        self.norm_out = nn.LayerNorm(d_model)\n",
"\n",
"    def forward(self, x):\n",
"        for layer in self.layers:\n",
"            x = layer(x)\n",
"        return self.norm_out(x)\n",
"\n",
"# --- Main Hybrid Model ---\n",
"\n",
"class FNetHybridTransformer(nn.Module):\n",
"    def __init__(self, num_encoder_layers, num_decoder_layers, num_heads, d_model, dff, vocab_size, max_length, dropout):\n",
"        super().__init__()\n",
"        self.d_model = d_model\n",
"\n",
"        # Shared Embeddings\n",
"        # padding_idx=tokenizer.pad_token_id initializes the PAD row to zeros and\n",
"        # blocks its gradient from embedding lookups. (Note: because this weight\n",
"        # is tied to final_linear below, the output projection can still update it.)\n",
"        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=tokenizer.pad_token_id)\n",
"\n",
"        # FNet requires absolute positional embeddings: the FFT mixes tokens but\n",
"        # carries no inherent notion of sequence order (unlike RoPE in the decoder).\n",
"        self.pos_embedding = nn.Embedding(max_length, d_model)\n",
"\n",
"        self.dropout_layer = nn.Dropout(dropout)\n",
"\n",
"        # --- Custom FNet Encoder ---\n",
"        self.encoder = FNetEncoder(\n",
"            depth=num_encoder_layers,\n",
"            d_model=d_model,\n",
"            d_ff=dff,\n",
"            dropout=dropout\n",
"        )\n",
"\n",
"        # --- x-transformers Decoder (Retains RoPE) ---\n",
"        self.decoder = Decoder(\n",
"            dim=d_model,\n",
"            depth=num_decoder_layers,\n",
"            heads=num_heads,\n",
"            attn_dim_head=d_model // num_heads,\n",
"            ff_mult=dff / d_model,\n",
"            rotary_pos_emb=True,  # Decoder still uses RoPE\n",
"            cross_attend=True,\n",
"            attn_flash=True,\n",
"            attn_dropout=dropout,\n",
"            ff_dropout=dropout,\n",
"            use_rmsnorm=True\n",
"        )\n",
"\n",
"        self.final_linear = nn.Linear(d_model, vocab_size)\n",
"        self.final_linear.weight = self.embedding.weight\n",
"\n",
"    def forward(self, src, tgt, src_padding_mask, tgt_padding_mask, memory_key_padding_mask, tgt_mask):\n",
"        # 1. Embeddings\n",
"        # Source (Encoder) gets Absolute Positional Embeddings\n",
"        B, L_src = src.shape\n",
"        pos_ids = torch.arange(L_src, device=src.device).unsqueeze(0)\n",
"        src_emb = self.embedding(src) * math.sqrt(self.d_model)\n",
"        src_emb = src_emb + self.pos_embedding(pos_ids)\n",
"        src_emb = self.dropout_layer(src_emb)\n",
"\n",
"        # Target (Decoder) gets NO positional embeddings here (RoPE handles it inside the Decoder)\n",
"        tgt_emb = self.embedding(tgt) * math.sqrt(self.d_model)\n",
"        tgt_emb = self.dropout_layer(tgt_emb)\n",
"\n",
"        # 2. Prepare Masks\n",
"        # x-transformers requires True = Keep, False = Mask;\n",
"        # the dataloader provides True = Pad, so invert.\n",
"        enc_mask = ~src_padding_mask if src_padding_mask is not None else None\n",
"        dec_mask = ~tgt_padding_mask if tgt_padding_mask is not None else None\n",
"\n",
"        # 3. FNet Encoder\n",
"        # Note: FNet mixes ALL tokens (including padding).\n",
"        memory = self.encoder(src_emb)\n",
"\n",
"        # CRITICAL: Zero out padding positions in the encoder output so the\n",
"        # Decoder doesn't attend to them.\n",
"        if src_padding_mask is not None:\n",
"            memory = memory.masked_fill(src_padding_mask.unsqueeze(-1), 0.0)\n",
"\n",
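"        # Because the FFT mixes padding into every position, training and\n",
"        # inference use the same fixed 128-token padding (see fnet_collator),\n",
"        # keeping the encoder's input distribution consistent.\n",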
"        # 4. RoPE Decoder\n",
"        # The decoder uses RoPE for self-attention on 'tgt',\n",
"        # and standard cross-attention to 'memory' (FNet output).\n",
"        decoder_output = self.decoder(\n",
"            tgt_emb,\n",
"            context=memory,\n",
"            mask=dec_mask,\n",
"            context_mask=enc_mask\n",
"        )\n",
"\n",
"        return self.final_linear(decoder_output)\n",
"\n",
"    def create_masks(self, src, tgt):\n",
"        # Standard mask creation\n",
"        src_padding_mask = (src == tokenizer.pad_token_id)\n",
"        tgt_padding_mask = (tgt == tokenizer.pad_token_id)\n",
"        tgt_mask = nn.Transformer.generate_square_subsequent_mask(\n",
"            sz=tgt.size(1), device=src.device, dtype=torch.bool\n",
"        )\n",
"        return src_padding_mask, tgt_padding_mask, src_padding_mask, tgt_mask\n",
"\n",
"    @torch.no_grad()\n",
"    def generate(self, src: torch.Tensor, max_length: int, num_beams: int = 5) -> torch.Tensor:\n",
"        self.eval()\n",
"        B, L_src = src.shape\n",
"\n",
"        # 1. Encode with FNet\n",
"        pos_ids = torch.arange(L_src, device=src.device).unsqueeze(0)\n",
"        src_emb = self.embedding(src) * math.sqrt(self.d_model)\n",
"        src_emb = src_emb + self.pos_embedding(pos_ids)\n",
"\n",
"        memory = self.encoder(self.dropout_layer(src_emb))\n",
"\n",
"        # Mask padding in memory\n",
"        src_padding_mask = (src == tokenizer.pad_token_id)\n",
"        memory = memory.masked_fill(src_padding_mask.unsqueeze(-1), 0.0)\n",
"\n",
"        # Prepare for Decoder (x-transformers style mask: True=Keep)\n",
"        enc_mask = ~src_padding_mask\n",
"\n",
"        # --- BEAM SEARCH SETUP ---\n",
"        # Expand memory for beams\n",
"        memory = memory.repeat_interleave(num_beams, dim=0)\n",
"        enc_mask = enc_mask.repeat_interleave(num_beams, dim=0)\n",
"\n",
"        initial_token = tokenizer.pad_token_id\n",
"        beams = torch.full((B * num_beams, 1), initial_token, dtype=torch.long, device=src.device)\n",
"        beam_scores = torch.zeros(B * num_beams, device=src.device)\n",
"        finished_beams = torch.zeros(B * num_beams, dtype=torch.bool, device=src.device)\n",
"\n",
"        for step in range(max_length - 1):\n",
"            if finished_beams.all(): break\n",
"\n",
"            # Decoder Step (RoPE handled internally)\n",
"            tgt_emb = self.embedding(beams) * math.sqrt(self.d_model)\n",
"\n",
"            decoder_output = self.decoder(\n",
"                self.dropout_layer(tgt_emb),\n",
"                context=memory,\n",
"                context_mask=enc_mask\n",
"            )\n",
"\n",
"            logits = self.final_linear(decoder_output[:, -1, :])\n",
"            log_probs = F.log_softmax(logits, dim=-1)\n",
"\n",
"            # --- Standard beam-search bookkeeping ---\n",
"            log_probs[:, tokenizer.pad_token_id] = -torch.inf\n",
"            if finished_beams.any(): log_probs[finished_beams, tokenizer.eos_token_id] = 0\n",
"\n",
"            total_scores = beam_scores.unsqueeze(1) + log_probs\n",
"            if step == 0:\n",
"                # All beams start identical; keep only beam 0 per batch item.\n",
"                total_scores = total_scores.view(B, num_beams, -1)\n",
"                total_scores[:, 1:, :] = -torch.inf\n",
"                total_scores = total_scores.view(B * num_beams, -1)\n",
"\n",
"            total_scores = total_scores.view(B, -1)\n",
"            top_scores, top_indices = torch.topk(total_scores, k=num_beams, dim=1)\n",
"\n",
"            beam_indices = top_indices // log_probs.shape[-1]\n",
"            token_indices = top_indices % log_probs.shape[-1]\n",
"\n",
"            batch_indices = torch.arange(B, device=src.device).unsqueeze(1)\n",
"            effective_indices = (batch_indices * num_beams + beam_indices).view(-1)\n",
"\n",
"            beams = beams[effective_indices]\n",
"            beams = torch.cat([beams, token_indices.view(-1, 1)], dim=1)\n",
"            beam_scores = top_scores.view(-1)\n",
"            finished_beams = finished_beams | (beams[:, -1] == tokenizer.eos_token_id)\n",
"\n",
"        final_beams = beams.view(B, num_beams, -1)\n",
"        final_scores = beam_scores.view(B, num_beams)\n",
"        normalized_scores = final_scores / (final_beams != tokenizer.pad_token_id).sum(-1).float().clamp(min=1)\n",
"        best_beams = final_beams[torch.arange(B), normalized_scores.argmax(1), :]\n",
"        self.train()\n",
"        return best_beams" ] }, { "cell_type": "code", "source": [ "def count_parameters(model):\n",
"    total_params = 0\n",
"    trainable_params = 0\n",
"\n",
"    # 1. Global Counts\n",
"    for p in model.parameters():\n",
"        total_params += p.numel()\n",
"        if p.requires_grad:\n",
"            trainable_params += p.numel()\n",
"\n",
"    print(\"=\"*40)\n",
"    print(\"šŸ“Š MODEL STATISTICS\")\n",
"    print(\"=\"*40)\n",
"    print(f\"Total Parameters: {total_params:,} ({total_params/1e6:.2f}M)\")\n",
"    print(f\"Trainable Parameters: {trainable_params:,} ({trainable_params/1e6:.2f}M)\")\n",
"    print(\"-\" * 40)\n",
"\n",
"    # 2. Section Breakdown\n",
"    def get_params(module):\n",
"        return sum(p.numel() for p in module.parameters())\n",
"\n",
"    if hasattr(model, 'encoder'):\n",
"        enc_p = get_params(model.encoder)\n",
"        print(f\"  • Encoder (FNet): {enc_p:,} ({enc_p/1e6:.2f}M)\")\n",
"\n",
"    if hasattr(model, 'decoder'):\n",
"        dec_p = get_params(model.decoder)\n",
"        print(f\"  • Decoder (RoPE): {dec_p:,} ({dec_p/1e6:.2f}M)\")\n",
"\n",
"    if hasattr(model, 'embedding'):\n",
"        emb_p = get_params(model.embedding)\n",
"        print(f\"  • Embeddings: {emb_p:,} ({emb_p/1e6:.2f}M)\")\n",
"\n",
"    print(\"=\"*40)\n",
"\n"
], "metadata": { "id": "wpmz-H9Slko1" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "Zd3AFTmhrCJq" }, "source": [ "## Functions (Loss, Eval etc)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Te1qTyUKrDEd" }, "outputs": [], "source": [ "\n",
"translation_loss_fn = nn.CrossEntropyLoss(\n",
"    # Positions labeled -100 are ignored: DataCollatorForSeq2Seq replaces pad\n",
"    # tokens in the labels with -100, so no loss is computed on padding.\n",
"    ignore_index=-100,\n",
"    label_smoothing=LABEL_SMOOTHING_EPSILON\n",
")\n",
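"\n",
"# [Illustrative sketch — these tensors are made up, not project data] A label\n",
"# of -100 contributes nothing to the loss, while real labels do:\n",
"_demo_logits = torch.randn(3, VOCAB_SIZE)\n",
"_demo_labels = torch.tensor([5, -100, 7])  # middle position is ignored\n",
"print(\"demo loss:\", translation_loss_fn(_demo_logits, _demo_labels).item())\n",
"\n",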
"def calculate_combined_loss(model_outputs, target_labels):\n",
"    \"\"\"Calculates the loss based on the model's output structure.\"\"\"\n",
"    logits = model_outputs\n",
"    translation_loss = translation_loss_fn(logits.reshape(-1, logits.shape[-1]), target_labels.reshape(-1))\n",
"    loss_dict = {'total': translation_loss.item()}\n",
"    return translation_loss, loss_dict\n",
"\n",
"from torchmetrics.text import SacreBLEUScore\n",
"\n",
"def evaluate(model, dataloader, device):\n",
"    # Use SacreBLEUScore (defaults to the '13a' tokenizer, the WMT standard)\n",
"    metric = SacreBLEUScore().to(device)\n",
"\n",
"    model.eval()\n",
"\n",
"    # Use no_grad to save memory and speed up validation\n",
"    with torch.no_grad():\n",
"        for batch in tqdm(dataloader, desc=\"Evaluating\", leave=False):\n",
"            input_ids = batch['input_ids'].to(device)\n",
"            labels = batch['labels']\n",
"\n",
"            # Generate predictions\n",
"            generated_ids = model.generate(input_ids, max_length=MAX_LENGTH, num_beams=5)\n",
"\n",
"            # Decode predictions\n",
"            pred_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)\n",
"\n",
"            # Decode labels (replacing -100 with the pad token first)\n",
"            labels[labels == -100] = tokenizer.pad_token_id\n",
"            ref_texts = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
"\n",
"            # Update Metric\n",
"            # SacreBLEU expects references as a list of lists: [[ref1], [ref2], ...]\n",
"            formatted_refs = [[ref] for ref in ref_texts]\n",
"            metric.update(pred_texts, formatted_refs)\n",
"\n",
"    model.train()\n",
"\n",
"    # compute() returns a tensor; .item() converts it to a plain Python float\n",
"    return metric.compute().item()\n",
"\n",
"\n",
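"# [Usage sketch, assuming a trained `model` is in scope] spot-check BLEU:\n",
"#   bleu = evaluate(model, val_dataloader, device)\n",
"#   print(f\"BLEU: {bleu:.4f}\")\n",
"\n",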
"## WARNING: not suitable for the FNet model — this helper pads dynamically,\n",
"## while the FNet encoder is trained on fixed MAX_LENGTH (128-token) inputs.\n",
"def generate_sample_translations(model, device, sentences_de):\n",
"    \"\"\"Generates and prints sample translations using beam search.\"\"\"\n",
"    print(\"\\n--- Generating Sample Translations (with Beam Search) ---\")\n",
"    orig_model = getattr(model, '_orig_mod', model)\n",
"    orig_model.eval()\n",
"\n",
"    inputs = tokenizer(sentences_de, return_tensors=\"pt\", padding=True, truncation=True, max_length=MAX_LENGTH)\n",
"    input_ids = inputs.input_ids.to(device)\n",
"    generated_ids = orig_model.generate(input_ids, max_length=MAX_LENGTH, num_beams=5)\n",
"\n",
"    translations = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)\n",
"    for src, out in zip(sentences_de, translations):\n",
"        print(f\"  DE Source: {src}\")\n",
"        print(f\"  EN Output: {out}\")\n",
"        print(\"-\" * 20)\n",
"    orig_model.train()\n",
"\n",
"sample_sentences_de_for_tracking = [\n",
"    \"Eine Katze sitzt auf der Matte.\",\n",
"    \"Ein Mann in einem roten Hemd liest ein Buch.\",\n",
"    \"Was ist die Hauptstadt von Deutschland?\",\n",
"    \"Ich gehe ins Kino, weil der Film sehr gut ist.\",\n",
"]\n",
"\n",
"def init_other_linear_weights(m):\n",
"    if isinstance(m, nn.Linear):\n",
"        # The 'is not' check skips the final_linear layer, leaving its weights\n",
"        # tied to the correctly initialized embeddings.\n",
"        if m is not getattr(model, '_orig_mod', model).final_linear:\n",
"            nn.init.xavier_uniform_(m.weight)\n",
"            if m.bias is not None:\n",
"                nn.init.zeros_(m.bias)\n",
"\n",
"\n"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YwPXbSwR50I2" }, "outputs": [], "source": [ "import json\n",
"import os\n",
"import subprocess\n",
"import torch\n",
"import hashlib\n",
"import sys\n",
"import shutil\n",
"import datetime  # needed by log_environment below\n",
"\n",
"# This logger will be configured and used in the main training script\n",
"import logging\n",
"logger = logging.getLogger(__name__)\n",
"\n",
"\n",
"def log_to_run_specific_file(run_dir):\n",
"    run_log_path = os.path.join(run_dir, \"run_log.txt\")\n",
"    file_handler = logging.FileHandler(run_log_path)\n",
"    file_handler.setFormatter(logging.Formatter('%(asctime)s [%(levelname)s] %(message)s'))\n",
"    logger.addHandler(file_handler)\n",
"    return file_handler\n",
"\n",
"def log_configurations(log_dir, config_vars):\n",
"    config_path = os.path.join(log_dir, \"config.json\")\n",
"    try:\n",
"        with open(config_path, 'w') as f:\n",
"            serializable_configs = {k: v for k, v in config_vars.items() if isinstance(v, (int, float, str, bool, list, dict, type(None)))}\n",
"            json.dump(serializable_configs, f, indent=4)\n",
"        logger.info(f\"Configurations saved to {config_path}\")\n",
"    except Exception as e:\n",
"        logger.error(f\"Could not save configurations: {e}\")\n",
"\n",
"def log_environment(log_dir):\n",
"    env_path = os.path.join(log_dir, \"environment.txt\")\n",
"    try:\n",
"        with open(env_path, 'w') as f:\n",
"            f.write(f\"--- Timestamp (UTC): {datetime.datetime.utcnow().isoformat()} ---\\n\")\n",
"            f.write(f\"Python Version: {sys.version}\\n\")\n",
"            f.write(f\"PyTorch Version: {torch.__version__}\\n\")\n",
"            f.write(f\"CUDA Available: {torch.cuda.is_available()}\\n\")\n",
"            if torch.cuda.is_available():\n",
"                f.write(f\"CUDA Version: {torch.version.cuda}\\n\")\n",
"                f.write(f\"CuDNN Version: {torch.backends.cudnn.version()}\\n\")\n",
"                f.write(f\"Number of GPUs: {torch.cuda.device_count()}\\n\")\n",
"                f.write(f\"GPU Name: {torch.cuda.get_device_name(0)}\\n\")\n",
"            f.write(\"\\n--- Full pip freeze ---\\n\")\n",
"            result = subprocess.run([sys.executable, '-m', 'pip', 'freeze'], stdout=subprocess.PIPE, text=True, check=True)\n",
"            f.write(result.stdout)\n",
"        logger.info(f\"Environment info saved to {env_path}\")\n",
"    except Exception as e:\n",
"        logger.error(f\"Could not save environment info: {e}\")\n",
"\n",
"def log_code_snapshot(log_dir, script_path):\n",
"    # NOTE: In Colab, you must save your notebook as a .py file for this to work\n",
"    # (File -> \"Save a copy as .py\").\n",
"    code_dir = os.path.join(log_dir, \"code_snapshot\")\n",
"    os.makedirs(code_dir, exist_ok=True)\n",
"    if script_path and os.path.exists(script_path):\n",
"        try:\n",
"            shutil.copy(script_path, os.path.join(code_dir, os.path.basename(script_path)))\n",
"            logger.info(f\"Copied script '{script_path}' to snapshot directory for verification.\")\n",
"        except Exception as e:\n",
"            logger.error(f\"Could not copy script for snapshot: {e}\")\n",
"    else:\n",
"        logger.warning(f\"Code Snapshot: Script path '{script_path}' not found. SKIPPING.\")\n",
"\n",
"def get_file_hash(filepath):\n",
"    sha256_hash = hashlib.sha256()\n",
"    try:\n",
"        with open(filepath, \"rb\") as f:\n",
"            for byte_block in iter(lambda: f.read(4096), b\"\"):\n",
"                sha256_hash.update(byte_block)\n",
"        return sha256_hash.hexdigest()\n",
"    except Exception as e:\n",
"        logger.error(f\"Could not generate hash for {filepath}: {e}\")\n",
"        return None\n",
"\n",
"def create_checksum_file(run_dir, artifacts_dict):\n",
"    checksum_file_path = os.path.join(run_dir, \"checksums.sha256\")\n",
"    logger.info(\"--- Creating digital fingerprints for key artifacts ---\")\n",
"    with open(checksum_file_path, \"w\") as f:\n",
"        f.write(f\"SHA256 Checksums for run: {os.path.basename(run_dir)}\\n\")\n",
"        for name, path in artifacts_dict.items():\n",
"            if path and os.path.exists(path):\n",
"                file_hash = get_file_hash(path)\n",
"                if file_hash:\n",
"                    log_message = f\"  - {name} ({os.path.basename(path)}): {file_hash}\"\n",
"                    logger.info(log_message)\n",
"                    f.write(f\"{file_hash}  {os.path.basename(path)}\\n\")\n",
"            else:\n",
"                logger.warning(f\"  - Skipped hashing '{name}', file not found: {path}\")\n",
"    logger.info(f\"Checksums saved to {checksum_file_path}\")\n",
"\n",
"def init_weights_kaiming(m):\n",
"    \"\"\"\n",
"    Applies Kaiming He initialization to Linear layers — a standard choice for\n",
"    deep Transformer stacks.\n",
"    NOTE: The Embedding layer is handled separately.\n",
"    \"\"\"\n",
"    if isinstance(m, nn.Linear):\n",
"        nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))  # a=sqrt(5) mimics PyTorch's default (LeakyReLU)\n",
"        if m.bias is not None:\n",
"            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)\n",
"            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0\n",
"            nn.init.uniform_(m.bias, -bound, bound)\n",
"\n",
"\n",
"def init_weights_fnet(m):\n",
"    \"\"\"\n",
"    Initialization for the FNet hybrid. FNet is essentially a BERT-like encoder,\n",
"    so we use BERT-style initialization (Xavier / truncated normal) rather than Kaiming.\n",
"    \"\"\"\n",
"    if isinstance(m, nn.Linear):\n",
"        # Xavier (Glorot) Uniform is the standard for Transformer/FNet FFN layers\n",
"        nn.init.xavier_uniform_(m.weight)\n",
"        if m.bias is not None:\n",
"            nn.init.zeros_(m.bias)\n",
"\n",
"    elif isinstance(m, nn.Embedding):\n",
"        # Critical: keep embedding variance low (std 0.02)\n",
"        nn.init.normal_(m.weight, mean=0.0, std=0.02)\n",
"\n",
"    # Handle the norm layers if they have learnable parameters\n",
"    elif isinstance(m, (nn.LayerNorm, RMSNorm)):\n",
"        if hasattr(m, 'weight') and m.weight is not None:\n",
"            nn.init.ones_(m.weight)\n",
"        if hasattr(m, 'bias') and m.bias is not None:\n",
"            nn.init.zeros_(m.bias)\n",
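"\n",
"# [Illustrative sketch, not part of the training pipeline] Probe the shape of\n",
"# the warmup+cosine schedule configured above; assumes the returned LambdaLR\n",
"# exposes its lr_lambdas list. The probe optimizer is a throwaway.\n",
"_probe_sched = get_cosine_schedule_with_warmup(\n",
"    torch.optim.AdamW([nn.Parameter(torch.zeros(1))], lr=PEAK_LEARNING_RATE),\n",
"    num_warmup_steps=WARMUP_STEPS, num_training_steps=TARGET_TRAINING_STEPS)\n",
"for _s in (0, WARMUP_STEPS, TARGET_TRAINING_STEPS // 2, TARGET_TRAINING_STEPS):\n",
"    print(f\"step {_s:>6}: lr = {PEAK_LEARNING_RATE * _probe_sched.lr_lambdas[0](_s):.2e}\")\n",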
"\n"
] }, { "cell_type": "markdown", "metadata": { "id": "ijTUk5dHu494" }, "source": [ "## Training Loop" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pyHZ1moluyA2" }, "outputs": [], "source": [ "if __name__ == '__main__':\n",
"\n",
"    experiment_name = MODEL_CHOICE\n",
"    CURRENT_RUN_DIR = os.path.join(DRIVE_BASE_PATH, experiment_name)\n",
"    SAVE_DIR = os.path.join(CURRENT_RUN_DIR, \"models\")\n",
"    LOG_DIR_TENSORBOARD = os.path.join(CURRENT_RUN_DIR, \"tensorboard_logs\")\n",
"    LOG_FILE_TXT = os.path.join(CURRENT_RUN_DIR, \"run_log.txt\")\n",
"\n",
"    os.makedirs(SAVE_DIR, exist_ok=True)\n",
"    os.makedirs(LOG_DIR_TENSORBOARD, exist_ok=True)\n",
"\n",
"    logging.basicConfig(\n",
"        level=logging.INFO,\n",
"        format='%(asctime)s [%(levelname)s] %(message)s',\n",
"        handlers=[logging.FileHandler(LOG_FILE_TXT), logging.StreamHandler(sys.stdout)],\n",
"        force=True\n",
"    )\n",
"    logger = logging.getLogger(__name__)\n",
"    writer = SummaryWriter(LOG_DIR_TENSORBOARD)\n",
"\n",
"    logger.info(f\"--- LAUNCHING EXPERIMENT: {experiment_name} ---\")\n",
"\n",
"    all_configs = {k: v for k, v in globals().items() if k.isupper()}\n",
"    log_configurations(CURRENT_RUN_DIR, all_configs)\n",
"    log_environment(CURRENT_RUN_DIR)\n",
"\n",
"    logger.info(\"--- Initializing FNetHybridTransformer ---\")\n",
"    model = FNetHybridTransformer(\n",
"        num_encoder_layers=NUM_ENCODER_LAYERS,\n",
"        num_decoder_layers=NUM_DECODER_LAYERS,\n",
"        num_heads=NUM_HEADS,\n",
"        d_model=D_MODEL,\n",
"        dff=D_FF,\n",
"        vocab_size=VOCAB_SIZE,\n",
"        max_length=MAX_LENGTH,\n",
"        dropout=DROPOUT\n",
"    )\n",
"\n",
"    model.apply(init_weights_fnet)\n",
"    nn.init.normal_(model.pos_embedding.weight, mean=0.0, std=0.02)\n",
"    model.final_linear.weight = model.embedding.weight  # re-tie after init\n",
"\n",
"    model.to(device)\n",
"    count_parameters(model)\n",
"\n",
"    # 4. SETUP OPTIMIZER\n",
"    optimizer = torch.optim.AdamW(model.parameters(), lr=PEAK_LEARNING_RATE, betas=(0.9, 0.98),\n",
"                                  eps=1e-9, weight_decay=WEIGHT_DECAY)\n",
"\n",
"    # Scheduler\n",
"    scheduler = get_cosine_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=WARMUP_STEPS,\n",
"                                                num_training_steps=TARGET_TRAINING_STEPS)\n",
"    scaler = torch.cuda.amp.GradScaler()\n",
"\n",
"    # --- AUTO-RESUME LOGIC ---\n",
"    global_step = 0\n",
"    best_bleu = 0.0\n",
"    LAST_CHECKPOINT_PATH = os.path.join(SAVE_DIR, \"last.pt\")\n",
"    BEST_CHECKPOINT_PATH = os.path.join(SAVE_DIR, \"best.pt\")\n",
"\n",
"    # 1. Try to find the latest checkpoint (if it exists)\n",
"    if os.path.exists(LAST_CHECKPOINT_PATH):\n",
"        logger.info(f\"šŸ”„ Found checkpoint at {LAST_CHECKPOINT_PATH}. Resuming...\")\n",
"        checkpoint = torch.load(LAST_CHECKPOINT_PATH, map_location=device)\n",
"\n",
"        model.load_state_dict(checkpoint['model_state_dict'])\n",
"        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
"        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n",
"        scaler.load_state_dict(checkpoint['scaler_state_dict'])\n",
"\n",
"        global_step = checkpoint['global_step']\n",
"        best_bleu = checkpoint.get('best_bleu', 0.0)\n",
"        logger.info(f\"   āœ… Resumed from Step {global_step} (LAST)\")\n",
"\n",
"    # 2. 
If no LAST, try to find the BEST checkpoint (Fall back to this!)\n", " elif os.path.exists(BEST_CHECKPOINT_PATH):\n", " logger.info(f\"šŸ”™ 'last.pt' not found. Falling back to BEST checkpoint: {BEST_CHECKPOINT_PATH}\")\n", " checkpoint = torch.load(BEST_CHECKPOINT_PATH, map_location=device)\n", "\n", " model.load_state_dict(checkpoint['model_state_dict'])\n", " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", " scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n", " scaler.load_state_dict(checkpoint['scaler_state_dict'])\n", "\n", " global_step = checkpoint['global_step']\n", " best_bleu = checkpoint.get('best_bleu', 0.0)\n", " logger.info(f\" āœ… Resumed from Step {global_step} (BEST)\")\n", "\n", " # 3. Start Fresh\n", " else:\n", " logger.info(\"šŸ†• No checkpoint found. Starting fresh training.\")\n", " # 5. TRAINING LOOP\n", " model.train()\n", "\n", " # Resume progress bar from global_step\n", " progress_bar = tqdm(total=TARGET_TRAINING_STEPS, initial=global_step, desc=\"Training Steps\")\n", " training_complete = False\n", "\n", " # Initialize gradients\n", " optimizer.zero_grad(set_to_none=True)\n", "\n", " # We iterate until global_step reaches the target\n", " epoch = 0\n", " while not training_complete:\n", " train_dataloader.generator.manual_seed(SEED + epoch)\n", " epoch += 1\n", "\n", " for batch_idx, batch in enumerate(train_dataloader):\n", " if global_step >= TARGET_TRAINING_STEPS:\n", " training_complete = True\n", " break\n", "\n", " input_ids = batch['input_ids'].to(device, non_blocking=True)\n", " labels = batch['labels'].to(device, non_blocking=True)\n", "\n", " decoder_start_token = torch.full((labels.shape[0], 1), tokenizer.pad_token_id, dtype=torch.long, device=device)\n", " decoder_input_ids = torch.cat([decoder_start_token, labels[:, :-1]], dim=1)\n", " decoder_input_ids[decoder_input_ids == -100] = tokenizer.pad_token_id\n", " target_labels = labels\n", "\n", " src_padding_mask, tgt_padding_mask, mem_key_padding_mask, tgt_mask = model.create_masks(input_ids, decoder_input_ids)\n", " tgt_padding_mask[:, 0] = False\n", "\n", " with torch.autocast(device_type=\"cuda\", dtype=torch.float16):\n", " model_outputs = model(src=input_ids, tgt=decoder_input_ids, src_padding_mask=src_padding_mask,\n", " tgt_padding_mask=tgt_padding_mask, memory_key_padding_mask=mem_key_padding_mask,\n", " tgt_mask=tgt_mask)\n", " loss, loss_components = calculate_combined_loss(model_outputs, target_labels)\n", "\n", " # --- GRADIENT ACCUMULATION SCALING ---\n", " loss = loss / GRAD_ACCUMULATION_STEPS\n", "\n", " # Accumulate gradients (no optimizer step yet)\n", " scaler.scale(loss).backward()\n", "\n", " # --- OPTIMIZER STEP (Conditional) ---\n", " if (batch_idx + 1) % GRAD_ACCUMULATION_STEPS == 0:\n", " scaler.unscale_(optimizer)\n", " total_grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n", "\n", " scaler.step(optimizer)\n", " scaler.update()\n", " scheduler.step()\n", "\n", " # Reset gradients\n", " optimizer.zero_grad(set_to_none=True)\n", "\n", " global_step += 1\n", " progress_bar.update(1)\n", " lr = scheduler.get_last_lr()[0]\n", "\n", " if global_step % 20 == 0:\n", " # Scale loss back up for logging purposes\n", " logged_loss = loss.item() * GRAD_ACCUMULATION_STEPS\n", " writer.add_scalar('train/loss', logged_loss, global_step)\n", " writer.add_scalar('train/learning_rate', lr, global_step)\n", " writer.add_scalar('train/gradient_norm', total_grad_norm.item(), global_step)\n", " progress_bar.set_postfix(\n", " 
loss=f\"{logged_loss:.2f}\",\n", " lr=f\"{lr:.2e}\",\n", " grad=f\"{total_grad_norm.item():.2f}\" # Showing Gradients\n", " )\n", "\n", " # --- PERIODIC SAVING (Every 500 Steps) ---\n", " # Saves you if Colab crashes mid-epoch\n", " if global_step % 500 == 0:\n", " torch.save({\n", " 'global_step': global_step,\n", " 'model_state_dict': model.state_dict(),\n", " 'optimizer_state_dict': optimizer.state_dict(),\n", " 'scheduler_state_dict': scheduler.state_dict(),\n", " 'scaler_state_dict': scaler.state_dict(),\n", " 'best_bleu': best_bleu\n", " }, LAST_CHECKPOINT_PATH)\n", "\n", " # --- VALIDATION CHECK ---\n", " if global_step in VALIDATION_SCHEDULE:\n", " logger.info(f\"\\n--- Validation at Step {global_step} ---\")\n", " bleu_score = evaluate(model, val_dataloader, device)\n", " writer.add_scalar('validation/bleu', bleu_score, global_step)\n", " logger.info(f\"Validation BLEU: {bleu_score:.4f} (Best: {best_bleu:.4f})\")\n", " #generate_sample_translations(model, device, sample_sentences_de_for_tracking)\n", "\n", " if bleu_score > best_bleu:\n", " best_bleu = bleu_score\n", " logger.info(f\" New best BLEU! Saving best model...\")\n", " # Save EVERYTHING so you can resume even from best model\n", " torch.save({\n", " 'global_step': global_step,\n", " 'model_state_dict': model.state_dict(),\n", " 'optimizer_state_dict': optimizer.state_dict(),\n", " 'scheduler_state_dict': scheduler.state_dict(),\n", " 'scaler_state_dict': scaler.state_dict(),\n", " 'best_bleu': best_bleu\n", " }, BEST_CHECKPOINT_PATH)\n", "\n", " model.train()\n", "\n", " progress_bar.close()\n", " writer.close()\n", "\n", " # Save Final (With States)\n", " torch.save({\n", " 'global_step': global_step,\n", " 'model_state_dict': model.state_dict(),\n", " 'optimizer_state_dict': optimizer.state_dict(),\n", " 'scheduler_state_dict': scheduler.state_dict(),\n", " 'scaler_state_dict': scaler.state_dict(),\n", " 'best_bleu': best_bleu\n", " }, LAST_CHECKPOINT_PATH)\n", "\n", " print(\"\\n\" + \"*\"*80)\n", " print(\" EXPERIMENT COMPLETE \")\n", " print(\"*\"*80)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UsS6qhLtJaMF" }, "outputs": [], "source": [ "import os\n", "import sys\n", "import torch\n", "import transformers\n", "import datasets\n", "import torchmetrics\n", "import numpy\n", "import pkg_resources\n", "\n", "def log_environment_separate(log_dir):\n", " # Define the separate file path\n", " meta_file = os.path.join(log_dir, \"system_metadata.txt\")\n", "\n", " with open(meta_file, \"w\") as f:\n", " # --- PART 1: SUMMARY ---\n", " f.write(\"=\"*40 + \"\\n\")\n", " f.write(\"CORE ENVIRONMENT SUMMARY\\n\")\n", " f.write(\"=\"*40 + \"\\n\")\n", " f.write(f\"Python: {sys.version.split()[0]}\\n\")\n", " f.write(f\"PyTorch: {torch.__version__}\\n\")\n", " f.write(f\"Transformers: {transformers.__version__}\\n\")\n", " f.write(f\"Datasets: {datasets.__version__}\\n\")\n", " f.write(f\"TorchMetrics: {torchmetrics.__version__}\\n\")\n", " f.write(f\"NumPy: {numpy.__version__}\\n\")\n", "\n", " try:\n", " import sacrebleu\n", " f.write(f\"SacreBLEU: {sacrebleu.__version__}\\n\")\n", " except ImportError:\n", " f.write(\"SacreBLEU: Not Installed\\n\")\n", "\n", " if torch.cuda.is_available():\n", " f.write(f\"GPU Name: {torch.cuda.get_device_name(0)}\\n\")\n", " f.write(f\"CUDA Ver: {torch.version.cuda}\\n\")\n", " f.write(f\"Capability: {torch.cuda.get_device_capability(0)}\\n\")\n", " else:\n", " f.write(\"GPU: None (CPU Only)\\n\")\n", "\n", " # --- PART 2: FULL FREEZE ---\n", " f.write(\"\\n\" + 
\"=\"*40 + \"\\n\")\n", " f.write(\"FULL LIBRARY DEPENDENCIES (PIP FREEZE)\\n\")\n", " f.write(\"=\"*40 + \"\\n\")\n", "\n", " installed_packages = {d.project_name: d.version for d in pkg_resources.working_set}\n", " for package, version in sorted(installed_packages.items()):\n", " f.write(f\"{package}=={version}\\n\")\n", "\n", " print(f\"āœ… Environment details saved SEPARATELY to: {meta_file}\")\n", "\n", "# Execute\n", "# Assumes CURRENT_RUN_DIR is defined from your config\n", "log_environment_separate(CURRENT_RUN_DIR)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tqDiOyy18clU" }, "outputs": [], "source": [ "# TENSORBOARD VISUALIZATION\n", "\n", "%load_ext tensorboard\n", "\n", "TENSORBOARD_BASE_DIR = os.path.join(DRIVE_BASE_PATH)\n", "\n", "%tensorboard --logdir \"{TENSORBOARD_BASE_DIR}\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AmOcgwNnJqOj" }, "outputs": [], "source": [ "from google.colab import runtime\n", "runtime.unassign()" ] }, { "cell_type": "markdown", "metadata": { "id": "eI0-qVlWVVpx" }, "source": [ "## End" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "A100", "provenance": [], "machine_shape": "hm" }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }