{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 06. Code CPT (Continual Pre-Training)\n",
"\n",
"Injects Python code generation capability into the pretrained base model\n",
"by continuing training on `bigcode/starcoderdata` (Python subset).\n",
"\n",
"**Strategy (following Code Llama):**\n",
"- Load pretrained model weights (from `03_training`)\n",
"- Mix **80% code** (StarCoder Python) + **20% general text** (FineWeb-Edu)\n",
"- Lower peak learning rate (1e-4 vs. 3e-4 in pretraining) to preserve existing representations\n",
"- Fresh optimizer (no momentum carry-over from pretraining)\n",
"\n",
"**Expected outcome:**\n",
"- The model learns Python syntax, indentation, and common patterns\n",
"- General language ability is preserved via data mixing\n",
"- Simple code generation (e.g., a Fibonacci function) becomes possible"
]
},
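{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the 80/20 mix described above, assuming each training sample is drawn independently from one of the two sources by weighted random choice. `MIX_WEIGHTS` and `sample_sources` are hypothetical names that are not part of `llm_lab`, and the actual pipeline may interleave the datasets differently."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"from collections import Counter\n",
"\n",
"# Hypothetical weights mirroring the 80% code / 20% general split.\n",
"MIX_WEIGHTS = {'code': 0.8, 'general': 0.2}\n",
"\n",
"def sample_sources(n, weights, seed=0):\n",
"    \"\"\"Pick a source name for each of n samples by weighted random choice.\"\"\"\n",
"    rng = random.Random(seed)\n",
"    names = list(weights)\n",
"    probs = [weights[name] for name in names]\n",
"    return rng.choices(names, weights=probs, k=n)\n",
"\n",
"# Over many draws the observed shares approach the configured weights.\n",
"counts = Counter(sample_sources(10_000, MIX_WEIGHTS))\n",
"for name, count in sorted(counts.items()):\n",
"    print(f'{name}: {count / 10_000:.1%}')"
]
},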
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install wandb -q"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"try:\n",
"    import google.colab\n",
"    from google.colab import drive\n",
"    drive.mount('/content/drive')\n",
"    project_path = '/content/drive/MyDrive/Colab Notebooks/LLM-1B-Lab'\n",
"    sys.path.append(project_path)\n",
"except ImportError:\n",
"    sys.path.insert(0, '..')\n",
"\n",
"from llm_lab.config import ModelConfig, DataConfig, TrainConfig\n",
"from llm_lab.model import LLMModel\n",
"from llm_lab.data import setup_cpt_data_pipeline\n",
"from llm_lab.training import start_cpt\n",
"from llm_lab.utils import auto_configure, get_device"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Configuration\n",
"\n",
"Use the CPT presets, which set appropriate LR, data mixing, and checkpoint paths."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# --- Model configuration (same architecture as pretraining) ---\n",
"model_config = ModelConfig.base_1b()\n",
"\n",
"# --- Data configuration (80% code + 20% general) ---\n",
"data_config = DataConfig.code_cpt()\n",
"data_config.max_seq_len = model_config.max_seq_len\n",
"data_config.batch_size = 4\n",
"\n",
"# --- Training configuration (lower LR, fresh optimizer) ---\n",
"train_config = TrainConfig.code_cpt_1b()\n",
"train_config.wandb_dir = \"/content/drive/MyDrive/wandb_logs\"\n",
"\n",
"# --- Path to the pretrained base checkpoint ---\n",
"PRETRAINED_CKPT_DIR = \"/content/drive/MyDrive/llm-1b-lab/checkpoints\"\n",
"\n",
"print(f\"Effective batch size: {train_config.effective_batch_size}\")\n",
"print(f\"Total CPT steps: {train_config.total_steps:,}\")\n",
"est_tokens = train_config.total_steps * train_config.effective_batch_size * model_config.max_seq_len\n",
"print(f\"Estimated CPT tokens: {est_tokens / 1e9:.1f}B\")\n",
"print(f\"Peak LR: {train_config.learning_rate}\")\n",
"print(f\"Data mix: {data_config.mix_weights}\")"
]
},
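{
"cell_type": "markdown",
"metadata": {},
"source": [
"The token estimate printed above is the product of optimizer steps, effective batch size (sequences per step), and sequence length. A worked example with illustrative numbers follows; all three values are assumptions chosen for the arithmetic, not the actual settings of `TrainConfig.code_cpt_1b()`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative values only; the real ones come from the config presets.\n",
"example_steps = 10_000\n",
"example_batch_size = 64   # sequences per optimizer step\n",
"example_seq_len = 2048    # tokens per sequence\n",
"\n",
"def estimated_tokens(steps, batch_size, seq_len):\n",
"    \"\"\"Tokens seen = steps * sequences per step * tokens per sequence.\"\"\"\n",
"    return steps * batch_size * seq_len\n",
"\n",
"example_tokens = estimated_tokens(example_steps, example_batch_size, example_seq_len)\n",
"print(f'Estimated tokens: {example_tokens / 1e9:.2f}B')"
]
},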
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Model + Mixed Data Pipeline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create model (weights will be loaded from pretrained checkpoint)\n",
"model = LLMModel(model_config)\n",
"print(f\"Model parameters: {model.count_parameters():,}\")\n",
"\n",
"# Mixed data pipeline: StarCoder Python (80%) + FineWeb-Edu (20%)\n",
"tokenizer, train_dl, val_dl = setup_cpt_data_pipeline(config=data_config)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Start Code CPT\n",
"\n",
"This will:\n",
"1. Load the pretrained model weights from the base checkpoint\n",
"2. Create a **fresh optimizer** (AdamW) with a lower LR\n",
"3. Train on the mixed code + general data\n",
"\n",
"If a CPT checkpoint already exists (from a previous interrupted session),\n",
"it will automatically resume from that checkpoint."
]
},
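{
"cell_type": "markdown",
"metadata": {},
"source": [
"The resume behavior described above can be pictured as a simple decision: prefer an existing CPT checkpoint, otherwise fall back to the pretrained base weights with a fresh optimizer. The sketch below is an illustration only. `choose_start_point`, the `*.pt` file naming, and the directory layout are hypothetical and may not match what `start_cpt` does internally."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tempfile\n",
"from pathlib import Path\n",
"\n",
"def choose_start_point(cpt_dir, pretrained_dir):\n",
"    \"\"\"Return (mode, path): 'resume' from a CPT checkpoint, else 'fresh' from base weights.\"\"\"\n",
"    # Assumes zero-padded step numbers, so lexicographic order matches training order.\n",
"    cpt_ckpts = sorted(Path(cpt_dir).glob('*.pt'))\n",
"    if cpt_ckpts:\n",
"        return 'resume', cpt_ckpts[-1]\n",
"    base_ckpts = sorted(Path(pretrained_dir).glob('*.pt'))\n",
"    if not base_ckpts:\n",
"        raise FileNotFoundError(f'No pretrained checkpoint in {pretrained_dir}')\n",
"    return 'fresh', base_ckpts[-1]\n",
"\n",
"# Demo on empty placeholder files in a temporary directory.\n",
"with tempfile.TemporaryDirectory() as tmp:\n",
"    demo_cpt_dir = Path(tmp) / 'cpt'\n",
"    demo_base_dir = Path(tmp) / 'base'\n",
"    demo_cpt_dir.mkdir()\n",
"    demo_base_dir.mkdir()\n",
"    (demo_base_dir / 'step_020000.pt').touch()\n",
"    first = choose_start_point(demo_cpt_dir, demo_base_dir)   # no CPT checkpoint yet\n",
"    (demo_cpt_dir / 'step_000500.pt').touch()\n",
"    second = choose_start_point(demo_cpt_dir, demo_base_dir)  # CPT checkpoint now exists\n",
"\n",
"print(first[0], first[1].name)\n",
"print(second[0], second[1].name)"
]
},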
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer = start_cpt(\n",
"    model=model,\n",
"    train_dataloader=train_dl,\n",
"    val_dataloader=val_dl,\n",
"    config=train_config,\n",
"    pretrained_checkpoint_dir=PRETRAINED_CKPT_DIR,\n",
"    seq_len=model_config.max_seq_len,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Quick Code Generation Test\n",
"\n",
"Test whether the model can now generate Python code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"device = get_device()\n",
"model.eval()\n",
"\n",
"code_prompts = [\n",
"    \"def fibonacci(n):\",\n",
"    \"def factorial(n):\",\n",
"    \"# Python function to sort a list\\ndef\",\n",
"    \"class Stack:\\n    def __init__(self):\",\n",
"]\n",
"\n",
"for prompt in code_prompts:\n",
"    print(\"=\" * 60)\n",
"    print(f\"PROMPT: {prompt}\")\n",
"    print(\"-\" * 60)\n",
"    # Encode the prompt as a batch of one sequence on the model's device\n",
"    input_ids = tokenizer.encode(prompt, add_special_tokens=False)\n",
"    input_tensor = torch.tensor([input_ids], device=device)\n",
"    # Sample a continuation without tracking gradients\n",
"    with torch.no_grad():\n",
"        output = model.generate(\n",
"            input_tensor,\n",
"            max_new_tokens=128,\n",
"            temperature=0.7,\n",
"            top_p=0.9,\n",
"        )\n",
"    generated = tokenizer.decode(output[0].tolist())\n",
"    print(generated)\n",
"    print()"
]
},
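{
"cell_type": "markdown",
"metadata": {},
"source": [
"One quick, automatic check on generated samples is whether they parse as Python. Below is a minimal sketch using the standard-library `ast` module. `is_valid_python` is a hypothetical helper, and the two sample strings are hand-written stand-ins rather than model output. Note that a sample cut off by `max_new_tokens` may fail to parse even when the text before the cut is reasonable."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ast\n",
"\n",
"def is_valid_python(source):\n",
"    \"\"\"Return True if source parses as Python. Parsing does not execute the code.\"\"\"\n",
"    try:\n",
"        ast.parse(source)\n",
"        return True\n",
"    except (SyntaxError, ValueError):\n",
"        return False\n",
"\n",
"# Hand-written stand-ins for generated text (not real model output).\n",
"samples = [\n",
"    'def fibonacci(n): return n if n < 2 else fibonacci(n - 1) + fibonacci(n - 2)',\n",
"    'def factorial(n): return n * factorial(n - 1',  # unbalanced parenthesis\n",
"]\n",
"for sample in samples:\n",
"    print(is_valid_python(sample), '|', sample)"
]
},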
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Full Evaluation (Optional)\n",
"\n",
"Run the full evaluation suite to compare with the base pretrained model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from llm_lab.evaluation import run_evaluation\n",
"\n",
"report = run_evaluation(\n",
"    model=model,\n",
"    tokenizer=tokenizer,\n",
"    val_dataloader=val_dl,\n",
"    metrics_history=trainer.metrics.history,\n",
")"
]
},
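{
"cell_type": "markdown",
"metadata": {},
"source": [
"Degenerate loops are a common failure mode in generated text, and this notebook does not show how a repetition rate is computed. As a rough stand-in, the sketch below measures the fraction of n-grams that repeat an earlier n-gram in a token sequence. `ngram_repetition_rate` and the choice of n = 3 are assumptions, so its numbers may not be comparable to those reported by the evaluation suite."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"\n",
"def ngram_repetition_rate(tokens, n=3):\n",
"    \"\"\"Fraction of n-grams that duplicate an earlier n-gram in the sequence.\"\"\"\n",
"    ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]\n",
"    if not ngrams:\n",
"        return 0.0\n",
"    counts = Counter(ngrams)\n",
"    # Every occurrence beyond the first counts as a repeat.\n",
"    repeated = sum(count - 1 for count in counts.values())\n",
"    return repeated / len(ngrams)\n",
"\n",
"# Whitespace-split toy sequences; a real check would use tokenizer IDs.\n",
"looping = 'return n return n return n return n'.split()\n",
"varied = 'def add ( a , b ) : return a + b'.split()\n",
"looping_rate = ngram_repetition_rate(looping)\n",
"varied_rate = ngram_repetition_rate(varied)\n",
"print(f'looping: {looping_rate:.2f}, varied: {varied_rate:.2f}')"
]
},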
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"**What to look for:**\n",
"- Fibonacci / factorial prompts should produce syntactically valid Python\n",
"- Repetition rate should drop significantly (from ~57% to <20%)\n",
"- General text perplexity should not degrade too much vs. the base model\n",
"- If code quality is poor, consider (1) running more CPT steps, (2) adjusting the mix ratio, or (3) lowering the LR"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}