{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true, "id": "2s48Vmoo9EB5" }, "outputs": [], "source": [ "!pip install -q torchmetrics sacrebleu" ] }, { "cell_type": "markdown", "metadata": { "id": "Lz8buKsjvA_w" }, "source": [ "## CONFIG" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "df355sdDrNSb" }, "outputs": [], "source": [ "# --- Data & Task Size ---\n", "MAX_LENGTH = 128\n", "\n", "MODEL_CHOICE = \"Baseline\" # For save path\n", "\n", "# --- Model Architecture Config (\"Transformer-Small\") ---\n", "D_MODEL = 512\n", "NUM_HEADS = 8\n", "D_FF = 2048\n", "DROPOUT = 0.1\n", "\n", "# --- Layer counts ---\n", "NUM_ENCODER_LAYERS = 6\n", "NUM_DECODER_LAYERS = 6\n", "\n", "# --- Training Config ---\n", "TARGET_TRAINING_STEPS = 50000\n", "\n", "VALIDATION_SCHEDULE = [\n", " 2000, 4000, 5000, 7500, 10000, 15000, 20000,\n", " 25000, 30000, 35000, 42500, 50000\n", "]\n", "PEAK_LEARNING_RATE = 8e-4\n", "WARMUP_STEPS = 120 # This is a flex, Kaiming + Pre-LN + AdamW is so stable that we don't even need warmups\n", "WEIGHT_DECAY = 0.01\n", "\n", "# --- Regularization Config ---\n", "LABEL_SMOOTHING_EPSILON = 0.1\n", "\n", "# --- Other Constants ---\n", "DRIVE_BASE_PATH = \"/content/drive/MyDrive/AIAYN\"\n", "PREBATCHED_REPO_ID = \"prism-lab/wmt14-de-en-prebatched-w4\"\n", "ORIGINAL_BUCKETED_REPO_ID = \"prism-lab/wmt14-de-en-bucketed-w4\"\n", "MODEL_CHECKPOINT = \"Helsinki-NLP/opus-mt-de-en\" # We only use its tokenizer\n" ] }, { "cell_type": "markdown", "source": [ "## DATALOADERS" ], "metadata": { "id": "W5l1HHRFXxPA" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FA5SqFzeMrpK" }, "outputs": [], "source": [ "\n", "import torch\n", "import torch.nn as nn\n", "from torch.utils.data import DataLoader\n", "from transformers import AutoTokenizer\n", "from datasets import load_dataset\n", "import math\n", "import os\n", "from tqdm.auto import tqdm\n", "from torchmetrics.text import BLEUScore\n", "from torch.utils.tensorboard import SummaryWriter\n", "import random\n", "import numpy as np\n", "import torch\n", "from transformers import get_cosine_schedule_with_warmup\n", "from typing import List\n", "from transformers import AutoModel\n", "\n", "\n", "def set_seed(seed_value=5):\n", " \"\"\"Sets the seed for reproducibility.\"\"\"\n", " random.seed(seed_value)\n", " np.random.seed(seed_value)\n", " torch.manual_seed(seed_value)\n", " torch.cuda.manual_seed_all(seed_value)\n", " torch.backends.cudnn.deterministic = True\n", " torch.backends.cudnn.benchmark = False\n", "\n", "SEED = 116\n", "set_seed(SEED)\n", "print(f\"Reproducibility seed set to {SEED}\")\n", "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n", "\n", "torch.use_deterministic_algorithms(True)\n", "\n", "print(\"--- Loading Modernized Configuration ---\")\n", "def seed_worker(worker_id):\n", " \"\"\"\n", " DataLoader worker'ları için seed ayarlama fonksiyonu.\n", " Her worker'ın farklı ama deterministik bir seed'e sahip olmasını sağlar.\n", " \"\"\"\n", " worker_seed = torch.initial_seed() % 2**32\n", " np.random.seed(worker_seed)\n", " random.seed(worker_seed)\n", "\n", "torch.set_float32_matmul_precision('high')\n", "print(\"✅ PyTorch matmul precision set to 'high'\")\n", "\n", "# --- Device Setup ---\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(f\"Using device: {device}\")\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)\n", "\n", "VOCAB_SIZE = len(tokenizer)\n", "print(f\"Vocab size: {VOCAB_SIZE}\")\n", "\n", "\n", "# DATA LOADING & PREPARATION\n", "from transformers import DataCollatorForSeq2Seq\n", "\n", "standard_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)\n", "\n", "class PreBatchedCollator:\n", " def __init__(self, original_dataset_split):\n", " self.original_dataset = original_dataset_split\n", "\n", " def __call__(self, features: List[dict]) -> dict:\n", " # 'features' will be a list of size 1, e.g., [{'batch_indices': [10, 5, 123]}]\n", " batch_indices = features[0]['batch_indices']\n", "\n", " # This returns a \"Dictionary of Lists\"\n", " # e.g., {'input_ids': [[...], [...]], 'labels': [[...], [...]]}\n", " dict_of_lists = self.original_dataset[batch_indices]\n", "\n", " # --- THE FIX ---\n", " # We must convert it to a \"List of Dictionaries\" for the standard collator.\n", " # e.g., [{'input_ids': [...], 'labels': [...]}, {'input_ids': [...], 'labels': [...]}]\n", " list_of_dicts = []\n", " keys = dict_of_lists.keys()\n", " num_samples = len(dict_of_lists['input_ids'])\n", "\n", " for i in range(num_samples):\n", " list_of_dicts.append({key: dict_of_lists[key][i] for key in keys})\n", " # --- END OF FIX ---\n", "\n", " # Now, pass the correctly formatted data to the standard collator\n", " return standard_collator(list_of_dicts)\n", "\n", "print(f\"Loading pre-batched dataset from: {PREBATCHED_REPO_ID}\")\n", "prebatched_datasets = load_dataset(PREBATCHED_REPO_ID)\n", "\n", "print(f\"Loading original samples from: {ORIGINAL_BUCKETED_REPO_ID}\")\n", "original_datasets = load_dataset(ORIGINAL_BUCKETED_REPO_ID)\n", "train_collator = PreBatchedCollator(original_datasets[\"train\"])\n", "\n", "# --- The New, Simple DataLoader ---\n", "# No more custom sampler!\n", "g = torch.Generator()\n", "g.manual_seed(SEED)\n", "\n", "train_dataloader = DataLoader(\n", " prebatched_datasets[\"train\"],\n", " batch_size=1, # Each row is already a batch\n", " shuffle=True, # Shuffle the pre-calculated batches every epoch\n", " num_workers=0,\n", " collate_fn=train_collator,\n", " pin_memory=True,\n", " worker_init_fn=seed_worker,\n", " generator=g,\n", ")\n", "\n", "# Validation loader remains the same, using the original data\n", "EVAL_BATCH_SIZE = 64\n", "val_dataloader = DataLoader(\n", " original_datasets[\"validation\"],\n", " batch_size=EVAL_BATCH_SIZE,\n", " collate_fn=standard_collator,\n", " num_workers=0,\n", " pin_memory=True,\n", " worker_init_fn=seed_worker,\n", " generator=g,\n", ")\n", "\n", "print(f\"Train Dataloader is now a simple iterator over pre-calculated batches.\")\n", "\n", "# --- SANITY CHECK ---\n", "print(\"\\n--- Running Sanity Check on new DataLoader ---\")\n", "train_dataloader.generator.manual_seed(SEED) # Reset generator for check\n", "temp_iterator = iter(train_dataloader)\n", "print(\"Shapes of first 5 batches:\")\n", "for i in range(5):\n", " batch = next(temp_iterator)\n", " print(f\" Batch {i+1}: input_ids shape = {batch['input_ids'].shape}\")\n", "print(\"--- Sanity Check Complete ---\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "cS4JvJGRhClv" }, "source": [ "## Models" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SMhlM0YvO1A7" }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import math\n", "\n", "class PositionalEncoding(nn.Module):\n", " \"\"\"Injects positional information into the input embeddings.\"\"\"\n", " def __init__(self, d_model: int, max_len: int = 5000):\n", " super().__init__()\n", " position = torch.arange(max_len).unsqueeze(1)\n", " div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))\n", " pe = torch.zeros(1, max_len, d_model)\n", " pe[0, :, 0::2] = torch.sin(position * div_term)\n", " pe[0, :, 1::2] = torch.cos(position * div_term)\n", " self.register_buffer('pe', pe)\n", "\n", " def forward(self, x: torch.Tensor):\n", " # x shape: [batch_size, seq_len, d_model]\n", " return x + self.pe[:, :x.size(1)]\n", "\n", "class FeedForward(nn.Module):\n", " \"\"\"A standard two-layer feed-forward network with a ReLU activation.\"\"\"\n", " def __init__(self, d_model: int, dff: int, dropout_rate: float = 0.1):\n", " super().__init__()\n", " self.ffn = nn.Sequential(\n", " nn.Linear(d_model, dff),\n", " nn.ReLU(),\n", " nn.Linear(dff, d_model),\n", " nn.Dropout(dropout_rate)\n", " )\n", " def forward(self, x: torch.Tensor):\n", " return self.ffn(x)\n", "\n", "class StandardTransformer(nn.Module):\n", " def __init__(self, num_encoder_layers, num_decoder_layers, num_heads, d_model, dff, vocab_size, max_length, dropout):\n", " super().__init__()\n", " self.d_model = d_model\n", " self.embedding = nn.Embedding(vocab_size, d_model)\n", " self.pos_encoder = PositionalEncoding(d_model, max_length)\n", " self.dropout = nn.Dropout(dropout)\n", " encoder_layer = nn.TransformerEncoderLayer(\n", " d_model, num_heads, dff, dropout, batch_first=True, norm_first=True # <-- THE FIX\n", " )\n", " self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)\n", "\n", " decoder_layer = nn.TransformerDecoderLayer(\n", " d_model, num_heads, dff, dropout, batch_first=True, norm_first=True # <-- THE FIX\n", " )\n", " self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)\n", "\n", " self.final_linear = nn.Linear(d_model, vocab_size)\n", " self.final_linear.weight = self.embedding.weight\n", "\n", " def forward(self, src, tgt, src_padding_mask, tgt_padding_mask, memory_key_padding_mask, tgt_mask):\n", "\n", " src_emb = self.embedding(src) * math.sqrt(self.d_model)\n", " tgt_emb = self.embedding(tgt) * math.sqrt(self.d_model)\n", " src_emb_pos = self.dropout(self.pos_encoder(src_emb))\n", " tgt_emb_pos = self.dropout(self.pos_encoder(tgt_emb))\n", "\n", " memory = self.encoder(src_emb_pos, src_key_padding_mask=src_padding_mask)\n", " decoder_output = self.decoder(\n", " tgt=tgt_emb_pos, memory=memory, tgt_mask=tgt_mask,\n", " tgt_key_padding_mask=tgt_padding_mask, memory_key_padding_mask=memory_key_padding_mask\n", " )\n", " return self.final_linear(decoder_output)\n", "\n", "\n", " def create_masks(self, src, tgt):\n", " src_padding_mask = (src == tokenizer.pad_token_id)\n", " tgt_padding_mask = (tgt == tokenizer.pad_token_id)\n", " # Creates a square causal mask for the decoder. This prevents any token from attending to future tokens. With this way model can not cheat.\n", " tgt_mask = nn.Transformer.generate_square_subsequent_mask(\n", " sz=tgt.size(1),\n", " device=src.device,\n", " dtype=torch.bool\n", " )\n", " return src_padding_mask, tgt_padding_mask, src_padding_mask, tgt_mask\n", "\n", " @torch.no_grad()\n", " def generate(self, src: torch.Tensor, max_length: int, num_beams: int = 5) -> torch.Tensor:\n", " self.eval()\n", " src_padding_mask = (src == tokenizer.pad_token_id)\n", "\n", " src_emb = self.embedding(src) * math.sqrt(self.d_model)\n", " src_emb_pos = self.pos_encoder(src_emb)\n", " memory = self.encoder(self.dropout(src_emb_pos), src_key_padding_mask=src_padding_mask)\n", "\n", " batch_size = src.shape[0]\n", " memory = memory.repeat_interleave(num_beams, dim=0)\n", " memory_key_padding_mask = src_padding_mask.repeat_interleave(num_beams, dim=0)\n", "\n", " initial_token = tokenizer.pad_token_id\n", " beams = torch.full((batch_size * num_beams, 1), initial_token, dtype=torch.long, device=src.device)\n", "\n", " beam_scores = torch.zeros(batch_size * num_beams, device=src.device)\n", " finished_beams = torch.zeros(batch_size * num_beams, dtype=torch.bool, device=src.device)\n", " for _ in range(max_length - 1):\n", " if finished_beams.all(): break\n", " tgt_mask = nn.Transformer.generate_square_subsequent_mask(beams.size(1)).to(src.device)\n", " tgt_emb = self.embedding(beams) * math.sqrt(self.d_model) # FIX HERE TOO\n", " tgt_emb_pos = self.pos_encoder(tgt_emb)\n", " decoder_output = self.decoder(tgt=self.dropout(tgt_emb_pos), memory=memory, tgt_mask=tgt_mask, memory_key_padding_mask=memory_key_padding_mask)\n", " logits = self.final_linear(decoder_output[:, -1, :])\n", " log_probs = F.log_softmax(logits, dim=-1)\n", " log_probs[:, tokenizer.pad_token_id] = -torch.inf\n", " if finished_beams.any(): log_probs[finished_beams, tokenizer.eos_token_id] = 0\n", " total_scores = beam_scores.unsqueeze(1) + log_probs\n", " if _ == 0:\n", " total_scores = total_scores.view(batch_size, num_beams, -1)\n", " total_scores[:, 1:, :] = -torch.inf # Sadece ilk beam'in başlamasına izin ver\n", " total_scores = total_scores.view(batch_size * num_beams, -1)\n", " else:\n", " total_scores = beam_scores.unsqueeze(1) + log_probs\n", " total_scores = total_scores.view(batch_size, -1)\n", " top_scores, top_indices = torch.topk(total_scores, k=num_beams, dim=1)\n", " beam_indices = top_indices // log_probs.shape[-1]; token_indices = top_indices % log_probs.shape[-1]\n", " batch_indices = torch.arange(batch_size, device=src.device).unsqueeze(1)\n", " effective_indices = (batch_indices * num_beams + beam_indices).view(-1)\n", " beams = beams[effective_indices]\n", " beams = torch.cat([beams, token_indices.view(-1, 1)], dim=1)\n", " beam_scores = top_scores.view(-1)\n", " finished_beams = finished_beams | (beams[:, -1] == tokenizer.eos_token_id)\n", " final_beams = beams.view(batch_size, num_beams, -1)\n", " final_scores = beam_scores.view(batch_size, num_beams)\n", " normalized_scores = final_scores / (final_beams != tokenizer.pad_token_id).sum(-1).float().clamp(min=1)\n", " best_beams = final_beams[torch.arange(batch_size), normalized_scores.argmax(1), :]\n", " self.train()\n", " return best_beams\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3QGBtTvj6Jrp" }, "outputs": [], "source": [ "# ==============================================================================\n", "# --- Model Analysis & Parameter Counting ---\n", "# ==============================================================================\n", "from collections import defaultdict\n", "\n", "def count_parameters_correctly(model):\n", " \"\"\"\n", " Counts trainable parameters, correctly handling tied weights (e.g., embeddings).\n", " \"\"\"\n", " seen_params = set()\n", " total_params = 0\n", " for param in model.parameters():\n", " if param.requires_grad:\n", " param_id = id(param)\n", " if param_id not in seen_params:\n", " seen_params.add(param_id)\n", " total_params += param.numel()\n", " return total_params\n", "\n", "# --- Instantiate the model to analyze it ---\n", "print(\"--- Analyzing Model Parameters ---\")\n", "model_to_analyze = StandardTransformer(\n", " num_encoder_layers=NUM_ENCODER_LAYERS,\n", " num_decoder_layers=NUM_DECODER_LAYERS,\n", " num_heads=NUM_HEADS,\n", " d_model=D_MODEL,\n", " dff=D_FF,\n", " vocab_size=VOCAB_SIZE,\n", " max_length=MAX_LENGTH,\n", " dropout=DROPOUT\n", ")\n", "\n", "# --- Perform the counting and display results ---\n", "correct_total = count_parameters_correctly(model_to_analyze)\n", "pytorch_naive_total = sum(p.numel() for p in model_to_analyze.parameters() if p.requires_grad)\n", "\n", "print(f\"Total Trainable Parameters (Correctly Counted): {correct_total:,}\")\n", "print(f\"PyTorch's Naive Count (sum(p.numel())): {pytorch_naive_total:,}\")\n", "if pytorch_naive_total != correct_total:\n", " print(f\"Note: The naive count is higher due to double-counting the tied embedding weights.\")\n", "\n", "del model_to_analyze # Clean up memory\n", "print(\"--- Analysis Complete ---\\n\")" ] }, { "cell_type": "markdown", "metadata": { "id": "Zd3AFTmhrCJq" }, "source": [ "## Functions (Loss, Eval etc)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Te1qTyUKrDEd" }, "outputs": [], "source": [ "\n", "translation_loss_fn = nn.CrossEntropyLoss(\n", " ignore_index=-100, # We don't calculate loss for pad tokens. Pad tokens are replaced with -100 by DataCollatorForSeq2Seq.\n", " label_smoothing=LABEL_SMOOTHING_EPSILON\n", ")\n", "def calculate_combined_loss(model_outputs, target_labels):\n", " \"\"\"Calculates the loss based on the model's output structure.\"\"\"\n", " logits = model_outputs\n", " translation_loss = translation_loss_fn(logits.reshape(-1, logits.shape[-1]), target_labels.reshape(-1))\n", " loss_dict = {'total': translation_loss.item()}\n", " return translation_loss, loss_dict\n", "\n", "def evaluate(model, dataloader, device):\n", " \"\"\"Evaluates the model using beam search decoding.\"\"\"\n", " bleu_metric = BLEUScore()\n", "\n", "\n", " orig_model = getattr(model, '_orig_mod', model)\n", " orig_model.eval()\n", "\n", " for batch in tqdm(dataloader, desc=\"Evaluating\", leave=False):\n", " input_ids = batch['input_ids'].to(device)\n", " labels = batch['labels']\n", "\n", " generated_ids = orig_model.generate(input_ids, max_length=MAX_LENGTH, num_beams=5)\n", "\n", " pred_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)\n", " labels[labels == -100] = tokenizer.pad_token_id\n", " ref_texts = tokenizer.batch_decode(labels, skip_special_tokens=True)\n", " bleu_metric.update(pred_texts, [[ref] for ref in ref_texts])\n", "\n", " orig_model.train()\n", " return bleu_metric.compute().item()\n", "\n", "def generate_sample_translations(model, device, sentences_de):\n", " \"\"\"Generates and prints sample translations using beam search.\"\"\"\n", " print(\"\\n--- Generating Sample Translations (with Beam Search) ---\")\n", " orig_model = getattr(model, '_orig_mod', model)\n", " orig_model.eval()\n", "\n", " inputs = tokenizer(sentences_de, return_tensors=\"pt\", padding=True, truncation=True, max_length=MAX_LENGTH)\n", " input_ids = inputs.input_ids.to(device)\n", " generated_ids = orig_model.generate(input_ids, max_length=MAX_LENGTH, num_beams=5)\n", "\n", " translations = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)\n", " for src, out in zip(sentences_de, translations):\n", " print(f\" DE Source: {src}\")\n", " print(f\" EN Output: {out}\")\n", " print(\"-\" * 20)\n", " orig_model.train()\n", "\n", "sample_sentences_de_for_tracking = [\n", " \"Eine Katze sitzt auf der Matte.\",\n", " \"Ein Mann in einem roten Hemd liest ein Buch.\",\n", " \"Was ist die Hauptstadt von Deutschland?\",\n", " \"Ich gehe ins Kino, weil der Film sehr gut ist.\",\n", "]\n", "\n", "def init_other_linear_weights(m):\n", " if isinstance(m, nn.Linear):\n", " # The 'is not' check correctly skips the final_linear layer,\n", " # leaving its weights tied to the correctly initialized embeddings.\n", " if m is not getattr(model, '_orig_mod', model).final_linear:\n", " nn.init.xavier_uniform_(m.weight)\n", " if m.bias is not None:\n", " nn.init.zeros_(m.bias)\n", "\n", "\n" ] }, { "cell_type": "code", "source": [ "import json\n", "import os\n", "import subprocess\n", "import torch\n", "import hashlib\n", "import sys\n", "import shutil\n", "\n", "# This logger will be configured and used in the main training script\n", "import logging\n", "logger = logging.getLogger(__name__)\n", "\n", "\n", "def log_to_run_specific_file(run_dir):\n", " run_log_path = os.path.join(run_dir, \"run_log.txt\")\n", " file_handler = logging.FileHandler(run_log_path)\n", " file_handler.setFormatter(logging.Formatter('%(asctime)s [%(levelname)s] %(message)s'))\n", " logger.addHandler(file_handler)\n", " return file_handler\n", "\n", "def log_configurations(log_dir, config_vars):\n", " # (Same as your provided function)\n", " config_path = os.path.join(log_dir, \"config.json\")\n", " try:\n", " with open(config_path, 'w') as f:\n", " serializable_configs = {k: v for k, v in config_vars.items() if isinstance(v, (int, float, str, bool, list, dict, type(None)))}\n", " json.dump(serializable_configs, f, indent=4)\n", " logger.info(f\"Configurations saved to {config_path}\")\n", " except Exception as e:\n", " logger.error(f\"Could not save configurations: {e}\")\n", "\n", "def log_environment(log_dir):\n", " # (Same as your provided function)\n", " env_path = os.path.join(log_dir, \"environment.txt\")\n", " try:\n", " with open(env_path, 'w') as f:\n", " f.write(f\"--- Timestamp (UTC): {datetime.datetime.utcnow().isoformat()} ---\\n\")\n", " f.write(f\"Python Version: {sys.version}\\n\")\n", " f.write(f\"PyTorch Version: {torch.__version__}\\n\")\n", " f.write(f\"CUDA Available: {torch.cuda.is_available()}\\n\")\n", " if torch.cuda.is_available():\n", " f.write(f\"CUDA Version: {torch.version.cuda}\\n\")\n", " f.write(f\"CuDNN Version: {torch.backends.cudnn.version()}\\n\")\n", " f.write(f\"Number of GPUs: {torch.cuda.device_count()}\\n\")\n", " f.write(f\"GPU Name: {torch.cuda.get_device_name(0)}\\n\")\n", " f.write(\"\\n--- Full pip freeze ---\\n\")\n", " result = subprocess.run([sys.executable, '-m', 'pip', 'freeze'], stdout=subprocess.PIPE, text=True, check=True)\n", " f.write(result.stdout)\n", " logger.info(f\"Environment info saved to {env_path}\")\n", " except Exception as e:\n", " logger.error(f\"Could not save environment info: {e}\")\n", "\n", "def log_code_snapshot(log_dir, script_path):\n", " # NOTE: In Colab, you must save your notebook as a .py file for this to work.\n", " # For example, file -> \"Save a copy as .py\"\n", " code_dir = os.path.join(log_dir, \"code_snapshot\")\n", " os.makedirs(code_dir, exist_ok=True)\n", " if script_path and os.path.exists(script_path):\n", " try:\n", " shutil.copy(script_path, os.path.join(code_dir, os.path.basename(script_path)))\n", " logger.info(f\"Copied script '{script_path}' to snapshot directory for verification.\")\n", " except Exception as e:\n", " logger.error(f\"Could not copy script for snapshot: {e}\")\n", " else:\n", " logger.warning(f\"Code Snapshot: Script path '{script_path}' not found. SKIPPING.\")\n", "\n", "def get_file_hash(filepath):\n", " # (Same as your provided function)\n", " sha256_hash = hashlib.sha256()\n", " try:\n", " with open(filepath, \"rb\") as f:\n", " for byte_block in iter(lambda: f.read(4096), b\"\"):\n", " sha256_hash.update(byte_block)\n", " return sha256_hash.hexdigest()\n", " except Exception as e:\n", " logger.error(f\"Could not generate hash for {filepath}: {e}\")\n", " return None\n", "\n", "def create_checksum_file(run_dir, artifacts_dict):\n", " checksum_file_path = os.path.join(run_dir, \"checksums.sha256\")\n", " logger.info(f\"--- Creating digital fingerprints for key artifacts ---\")\n", " with open(checksum_file_path, \"w\") as f:\n", " f.write(f\"SHA256 Checksums for run: {os.path.basename(run_dir)}\\n\")\n", " for name, path in artifacts_dict.items():\n", " if path and os.path.exists(path):\n", " file_hash = get_file_hash(path)\n", " if file_hash:\n", " log_message = f\" - {name} ({os.path.basename(path)}): {file_hash}\"\n", " logger.info(log_message)\n", " f.write(f\"{file_hash} {os.path.basename(path)}\\n\")\n", " else:\n", " logger.warning(f\" - Skipped hashing '{name}', file not found: {path}\")\n", " logger.info(f\"Checksums saved to {checksum_file_path}\")\n", "\n", "def init_weights_kaiming(m):\n", " \"\"\"\n", " Applies Kaiming He initialization to Linear layers.\n", " This is the standard, superior way to initialize deep Transformers.\n", " NOTE: We will handle the Embedding layer separately.\n", " \"\"\"\n", "\n", " if isinstance(m, nn.Linear):\n", " nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5)) # a=sqrt(5) mimics default PyTorch for LeakyReLU\n", " if m.bias is not None:\n", " fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)\n", " bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0\n", " nn.init.uniform_(m.bias, -bound, bound)\n" ], "metadata": { "id": "YwPXbSwR50I2" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "ijTUk5dHu494" }, "source": [ "## Training Loop" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pyHZ1moluyA2" }, "outputs": [], "source": [ "\n", "if __name__ == '__main__':\n", "\n", " experiment_name = f\"{MODEL_CHOICE}\"\n", " CURRENT_RUN_DIR = os.path.join(DRIVE_BASE_PATH, experiment_name) # Single run directory\n", " SAVE_DIR = os.path.join(CURRENT_RUN_DIR, \"models\")\n", " LOG_DIR_TENSORBOARD = os.path.join(CURRENT_RUN_DIR, \"tensorboard_logs\")\n", " LOG_FILE_TXT = os.path.join(CURRENT_RUN_DIR, \"run_log.txt\")\n", "\n", " os.makedirs(SAVE_DIR, exist_ok=True)\n", " os.makedirs(LOG_DIR_TENSORBOARD, exist_ok=True)\n", "\n", " logging.basicConfig(\n", " level=logging.INFO,\n", " format='%(asctime)s [%(levelname)s] %(message)s',\n", " handlers=[\n", " logging.FileHandler(LOG_FILE_TXT),\n", " logging.StreamHandler(sys.stdout)\n", " ],\n", " force=True\n", " )\n", " logger = logging.getLogger(__name__)\n", " writer = SummaryWriter(LOG_DIR_TENSORBOARD)\n", "\n", " logger.info(f\"--- LAUNCHING EXPERIMENT: {experiment_name} ---\")\n", "\n", " all_configs = {k: v for k, v in globals().items() if k.isupper()}\n", " all_configs['TARGET_TRAINING_STEPS'] = TARGET_TRAINING_STEPS\n", " all_configs['VALIDATION_SCHEDULE'] = VALIDATION_SCHEDULE\n", " log_configurations(CURRENT_RUN_DIR, all_configs)\n", " log_environment(CURRENT_RUN_DIR)\n", " log_code_snapshot(CURRENT_RUN_DIR, \"your_notebook_name.ipynb\") # Remember to update this filename\n", "\n", " set_seed(SEED)\n", " logger.info(f\"Reproducibility seed set to {SEED}\")\n", "\n", " logger.info(f\"--- Initializing StandardTransformer ---\")\n", " model = StandardTransformer(\n", " num_encoder_layers=NUM_ENCODER_LAYERS, num_decoder_layers=NUM_DECODER_LAYERS,\n", " num_heads=NUM_HEADS, d_model=D_MODEL, dff=D_FF, vocab_size=VOCAB_SIZE,\n", " max_length=MAX_LENGTH, dropout=DROPOUT\n", " )\n", "\n", " # 3. WEIGHT INITIALIZATION STRATEGY\n", " model.apply(init_weights_kaiming)\n", " logger.info(\" Applied Kaiming Uniform initialization to all linear layers.\")\n", "\n", " # Removed the if/else logic, only the \"from-scratch\" path remains\n", " logger.info(\"--- Initializing embedding layer from scratch ---\")\n", " nn.init.normal_(model.embedding.weight, mean=0.0, std=0.02)\n", " logger.info(\" Initialized embedding map with Normal(0, 0.02).\")\n", "\n", " # Tie weights AFTER all initialization is complete.\n", " model.final_linear.weight = model.embedding.weight\n", "\n", " model.to(device)\n", " logger.info(f\"Model is ready on {device}.\")\n", "\n", " # 4. SETUP OPTIMIZER, SCHEDULER, AND SCALER\n", " optimizer = torch.optim.AdamW(model.parameters(), lr=PEAK_LEARNING_RATE, betas=(0.9, 0.98),\n", " eps=1e-9, weight_decay=WEIGHT_DECAY)\n", " scheduler = get_cosine_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=WARMUP_STEPS,\n", " num_training_steps=TARGET_TRAINING_STEPS) # Use total steps\n", " scaler = torch.cuda.amp.GradScaler()\n", "\n", " # 5. TRAINING LOOP\n", " model.train()\n", " global_step = 0 # Renamed from global_step_this_iteration\n", " best_bleu = 0.0 # Renamed from best_bleu_this_iteration\n", " LAST_CHECKPOINT_PATH = os.path.join(SAVE_DIR, \"last.pt\")\n", " BEST_CHECKPOINT_PATH = os.path.join(SAVE_DIR, \"best.pt\")\n", "\n", " # Simplified progress bar\n", " progress_bar = tqdm(total=TARGET_TRAINING_STEPS, desc=\"Total Progress\")\n", " training_complete = False\n", "\n", " for epoch in range(200): # This can be a large number, the step check will stop it\n", " if training_complete: break\n", "\n", " # --- Simplified generator seed ---\n", " train_dataloader.generator.manual_seed(SEED + epoch)\n", "\n", " for batch in train_dataloader:\n", " if global_step >= TARGET_TRAINING_STEPS: # Check against total steps\n", " training_complete = True\n", " break\n", "\n", " optimizer.zero_grad(set_to_none=True)\n", " input_ids = batch['input_ids'].to(device, non_blocking=True)\n", " labels = batch['labels'].to(device, non_blocking=True)\n", " decoder_start_token = torch.full((labels.shape[0], 1), tokenizer.pad_token_id, dtype=torch.long, device=device)\n", " decoder_input_ids = torch.cat([decoder_start_token, labels[:, :-1]], dim=1)\n", " decoder_input_ids[decoder_input_ids == -100] = tokenizer.pad_token_id\n", " target_labels = labels\n", "\n", " src_padding_mask, tgt_padding_mask, mem_key_padding_mask, tgt_mask = model.create_masks(input_ids, decoder_input_ids)\n", " tgt_padding_mask[:, 0] = False\n", "\n", " with torch.autocast(device_type=\"cuda\", dtype=torch.float16):\n", " model_outputs = model(src=input_ids, tgt=decoder_input_ids, src_padding_mask=src_padding_mask,\n", " tgt_padding_mask=tgt_padding_mask, memory_key_padding_mask=mem_key_padding_mask,\n", " tgt_mask=tgt_mask)\n", " loss, loss_components = calculate_combined_loss(model_outputs, target_labels)\n", "\n", " scaler.scale(loss).backward()\n", " scaler.unscale_(optimizer)\n", " total_grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n", " scaler.step(optimizer)\n", " scaler.update()\n", " scheduler.step()\n", " global_step += 1 # Use main global_step\n", " progress_bar.update(1)\n", " lr = scheduler.get_last_lr()[0]\n", "\n", " if global_step % 20 == 0:\n", " writer.add_scalar('train/loss', loss.item(), global_step)\n", " writer.add_scalar('train/learning_rate', lr, global_step)\n", " writer.add_scalar('train/gradient_norm', total_grad_norm.item(), global_step)\n", " progress_bar.set_postfix(loss=loss.item(), grad_norm=f\"{total_grad_norm.item():.2f}\", lr=f\"{lr:.2e}\")\n", "\n", " if global_step in VALIDATION_SCHEDULE:\n", " # --- Simplified logging message ---\n", " logger.info(f\"\\n--- Validation at Step {global_step} ---\")\n", " bleu_score = evaluate(model, val_dataloader, device)\n", " writer.add_scalar('validation/bleu', bleu_score, global_step)\n", " logger.info(f\"Validation BLEU: {bleu_score:.4f} (Best: {best_bleu:.4f})\")\n", " generate_sample_translations(model, device, sample_sentences_de_for_tracking)\n", "\n", " if bleu_score > best_bleu:\n", " best_bleu = bleu_score\n", " logger.info(f\" New best BLEU! Saving best model...\")\n", " torch.save(model.state_dict(), BEST_CHECKPOINT_PATH)\n", "\n", " model.train()\n", "\n", " progress_bar.close()\n", " writer.close()\n", " logger.info(f\"--- Training finished after {global_step} steps ---\")\n", "\n", " # --- 6. SAVE FINAL STATE ---\n", " torch.save({'global_step': global_step, 'model_state_dict': model.state_dict()},\n", " LAST_CHECKPOINT_PATH)\n", " logger.info(f\"Saved final state to: {LAST_CHECKPOINT_PATH}\")\n", "\n", " # --- Removed the previous_iteration_checkpoint_path line ---\n", "\n", " # --- 7. CREATE DIGITAL FINGERPRINTS ---\n", " logger.info(\"--- Creating digital fingerprints for key artifacts ---\")\n", " files_to_hash = {\n", " \"Last Model\": LAST_CHECKPOINT_PATH,\n", " \"Best Model\": BEST_CHECKPOINT_PATH,\n", " \"Text Log\": LOG_FILE_TXT,\n", " }\n", "\n", " try:\n", " tb_log_file = [f for f in os.listdir(LOG_DIR_TENSORBOARD) if 'tfevents' in f][0]\n", " files_to_hash[\"TensorBoard Log\"] = os.path.join(LOG_DIR_TENSORBOARD, tb_log_file)\n", " except IndexError:\n", " logger.warning(\"Could not find TensorBoard events file to hash.\")\n", "\n", " checksum_file_path = os.path.join(CURRENT_RUN_DIR, \"checksums.sha256\")\n", " with open(checksum_file_path, \"w\") as f:\n", " # --- Simplified checksums file ---\n", " f.write(f\"SHA256 Checksums for run: {experiment_name}\\n\")\n", " f.write(\"=\"*50 + \"\\n\")\n", " for name, path in files_to_hash.items():\n", " if path and os.path.exists(path):\n", " file_hash = get_file_hash(path)\n", " if file_hash:\n", " log_message = f\" - {name} ({os.path.basename(path)}): {file_hash}\"\n", " logger.info(log_message)\n", " f.write(f\"{file_hash} {os.path.basename(path)}\\n\")\n", " else:\n", " logger.warning(f\" - Skipped hashing for '{name}', file not found: {path}\")\n", "\n", " logger.info(f\"Checksums saved to {checksum_file_path}\")\n", "\n", " print(\"\\n\\n\" + \"*\"*80)\n", " print(\" EXPERIMENT COMPLETE \")\n", " print(\"*\"*80)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tqDiOyy18clU" }, "outputs": [], "source": [ "# TENSORBOARD VISUALIZATION\n", "\n", "%load_ext tensorboard\n", "\n", "TENSORBOARD_BASE_DIR = os.path.join(DRIVE_BASE_PATH)\n", "\n", "%tensorboard --logdir \"{TENSORBOARD_BASE_DIR}\"" ] }, { "cell_type": "markdown", "metadata": { "id": "eI0-qVlWVVpx" }, "source": [ "## End" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "A100", "provenance": [], "collapsed_sections": [ "cS4JvJGRhClv" ] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }