{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 06. Code CPT (Continual Pre-Training)\n", "\n", "Injects Python code generation capability into the pretrained base model\n", "by continuing training on `bigcode/starcoderdata` (Python subset).\n", "\n", "**Strategy (following Code Llama):**\n", "- Load pretrained model weights (from `03_training`)\n", "- Mix **80% code** (StarCoder Python) + **20% general text** (FineWeb-Edu)\n", "- Lower learning rate (1e-4 vs 3e-4) to preserve existing representations\n", "- Fresh optimizer (no momentum carry-over from pretraining)\n", "\n", "**Expected outcome:**\n", "- The model learns Python syntax, indentation, and common patterns\n", "- General language ability is preserved via data mixing\n", "- Fibonacci / simple code generation becomes possible" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install wandb -q" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "\n", "try:\n", " import google.colab\n", " from google.colab import drive\n", " drive.mount('/content/drive')\n", " project_path = '/content/drive/MyDrive/Colab Notebooks/LLM-1B-Lab'\n", " sys.path.append(project_path)\n", "except ImportError:\n", " sys.path.insert(0, '..')\n", "\n", "from llm_lab.config import ModelConfig, DataConfig, TrainConfig\n", "from llm_lab.model import LLMModel\n", "from llm_lab.data import setup_cpt_data_pipeline\n", "from llm_lab.training import start_cpt\n", "from llm_lab.utils import auto_configure, get_device" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Configuration\n", "\n", "Use the CPT presets, which set appropriate LR, data mixing, and checkpoint paths." 
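, "\n", "\n", "As a rough sketch of what such a preset might look like (the field names, step count, and accumulation factor below are illustrative assumptions, not the actual `llm_lab` definitions), the CPT preset mainly pins the lower peak LR and the CPT hyperparameters in one place:\n", "\n", "```python\n", "# Hypothetical sketch of a CPT preset -- not the actual llm_lab code.\n", "from dataclasses import dataclass\n", "\n", "@dataclass\n", "class TrainConfig:\n", "    learning_rate: float = 3e-4    # pretraining default\n", "    total_steps: int = 100_000     # assumed value, for illustration\n", "    batch_size: int = 4\n", "    grad_accum_steps: int = 8      # assumed accumulation factor\n", "\n", "    @property\n", "    def effective_batch_size(self) -> int:\n", "        return self.batch_size * self.grad_accum_steps\n", "\n", "    @classmethod\n", "    def code_cpt_1b(cls) -> \"TrainConfig\":\n", "        # Lower peak LR (1e-4 vs 3e-4) preserves pretrained representations.\n", "        return cls(learning_rate=1e-4, total_steps=20_000)\n", "```\n"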
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# --- Model configuration (same architecture as pretraining) ---\n", "model_config = ModelConfig.base_1b()\n", "\n", "# --- Data configuration (80% code + 20% general) ---\n", "data_config = DataConfig.code_cpt()\n", "data_config.max_seq_len = model_config.max_seq_len\n", "data_config.batch_size = 4\n", "\n", "# --- Training configuration (lower LR, fresh optimizer) ---\n", "train_config = TrainConfig.code_cpt_1b()\n", "train_config.wandb_dir = \"/content/drive/MyDrive/wandb_logs\"\n", "\n", "# --- Path to the pretrained base checkpoint ---\n", "PRETRAINED_CKPT_DIR = \"/content/drive/MyDrive/llm-1b-lab/checkpoints\"\n", "\n", "print(f\"Effective batch size: {train_config.effective_batch_size}\")\n", "print(f\"Total CPT steps: {train_config.total_steps:,}\")\n", "print(f\"Estimated CPT tokens: {train_config.total_steps * train_config.effective_batch_size * model_config.max_seq_len / 1e9:.1f}B\")\n", "print(f\"Peak LR: {train_config.learning_rate}\")\n", "print(f\"Data mix: {data_config.mix_weights}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Model + Mixed Data Pipeline" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create model (weights will be loaded from pretrained checkpoint)\n", "model = LLMModel(model_config)\n", "print(f\"Model parameters: {model.count_parameters():,}\")\n", "\n", "# Mixed data pipeline: StarCoder Python (80%) + FineWeb-Edu (20%)\n", "tokenizer, train_dl, val_dl = setup_cpt_data_pipeline(config=data_config)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Start Code CPT\n", "\n", "This will:\n", "1. Load the pretrained model weights from the base checkpoint\n", "2. Create a **fresh optimizer** (AdamW) with lower LR\n", "3. 
Train on the mixed code + general data\n", "\n", "If a CPT checkpoint already exists (from a previous interrupted session),\n", "it will automatically resume from that checkpoint." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trainer = start_cpt(\n", " model=model,\n", " train_dataloader=train_dl,\n", " val_dataloader=val_dl,\n", " config=train_config,\n", " pretrained_checkpoint_dir=PRETRAINED_CKPT_DIR,\n", " seq_len=model_config.max_seq_len,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Quick Code Generation Test\n", "\n", "Test whether the model can now generate Python code." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "device = get_device()\n", "model = model.to(device)  # ensure weights sit on the same device as the inputs\n", "model.eval()\n", "\n", "code_prompts = [\n", " \"def fibonacci(n):\",\n", " \"def factorial(n):\",\n", " \"# Python function to sort a list\\ndef\",\n", " \"class Stack:\\n def __init__(self):\",\n", "]\n", "\n", "for prompt in code_prompts:\n", " print(f\"{'='*60}\")\n", " print(f\"PROMPT: {prompt}\")\n", " print(f\"{'-'*60}\")\n", " input_ids = tokenizer.encode(prompt, add_special_tokens=False)\n", " input_tensor = torch.tensor([input_ids], device=device)\n", " with torch.no_grad():\n", " output = model.generate(\n", " input_tensor,\n", " max_new_tokens=128,\n", " temperature=0.7,\n", " top_p=0.9,\n", " )\n", " generated = tokenizer.decode(output[0].tolist())\n", " print(generated)\n", " print()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Full Evaluation (Optional)\n", "\n", "Run the full evaluation suite to compare with the base pretrained model."
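, "\n", "\n", "One cheap sanity metric worth tracking alongside the full suite is repetition rate. A minimal hand-rolled version (an illustrative sketch, not the `llm_lab` implementation) counts duplicate n-grams among the generated token ids:\n", "\n", "```python\n", "# Hypothetical repetition-rate metric -- not the llm_lab implementation.\n", "def repetition_rate(token_ids, n=4):\n", "    \"\"\"Fraction of n-grams in the sequence that are duplicates.\"\"\"\n", "    if len(token_ids) < n:\n", "        return 0.0\n", "    ngrams = [tuple(token_ids[i:i + n]) for i in range(len(token_ids) - n + 1)]\n", "    return 1.0 - len(set(ngrams)) / len(ngrams)\n", "```\n", "\n", "A score near 0 means almost no repeated n-grams; a model stuck looping the same few tokens pushes the score toward 1.\n"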
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from llm_lab.evaluation import run_evaluation\n", "\n", "report = run_evaluation(\n", " model=model,\n", " tokenizer=tokenizer,\n", " val_dataloader=val_dl,\n", " metrics_history=trainer.metrics.history,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "**What to look for:**\n", "- Fibonacci / factorial prompts should produce syntactically valid Python\n", "- Repetition rate should drop significantly (from ~57% to <20%)\n", "- General text perplexity should not degrade much relative to the base model\n", "- If code quality is poor, consider: (1) more CPT steps, (2) a different code/general mix ratio, (3) a lower LR" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.10.0" } }, "nbformat": 4, "nbformat_minor": 4 }