Upload main.ipynb with huggingface_hub
main.ipynb  +46 -14
CHANGED
@@ -521,7 +521,13 @@
  " return self.final_norm(x)\n",
  "\n",
  " def forward(self, z_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n",
- " \"\"\"
+ " \"\"\"Forward pass returning hidden states [B, L, D].\n",
+ " Used by DataParallel \u2014 logit projection done outside for memory efficiency.\n",
+ " For full logits (sampling), use forward_full().\"\"\"\n",
+ " return self.forward_hidden(z_t, t)\n",
+ "\n",
+ " def forward_full(self, z_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n",
+ " \"\"\"Full forward pass returning logits [B, L, V]. Used for sampling.\"\"\"\n",
  " hidden = self.forward_hidden(z_t, t)\n",
  " logits = self.output_proj(hidden)\n",
  " logits[:, :, self.config.mask_token_id] = -1e9\n",
@@ -600,7 +606,7 @@
  " alpha_next = self.noise_schedule.alpha(t_next)\n",
  "\n",
  " t_batch = torch.full((batch_size,), t_now.item(), device=device)\n",
- " logits = self.
+ " logits = self.forward_full(x, t_batch)\n",
  " probs = F.softmax(logits / temperature, dim=-1)\n",
  "\n",
  " unmask_prob = ((alpha_next - alpha_now) / (1.0 - alpha_now + 1e-8)).clamp(0, 1)\n",
@@ -616,7 +622,7 @@
  " is_masked = (x == self.config.mask_token_id)\n",
  " if is_masked.any():\n",
  " t_batch = torch.full((batch_size,), 1e-5, device=device)\n",
- " logits = self.
+ " logits = self.forward_full(x, t_batch)\n",
  " probs = F.softmax(logits / temperature, dim=-1)\n",
  " flat_probs = probs.reshape(-1, self.config.vocab_size)\n",
  " sampled = torch.multinomial(flat_probs, 1).reshape(batch_size, seq_len)\n",
@@ -633,12 +639,12 @@
  "print(f\"Unique parameters (weight tying): {unique_params / 1e6:.1f}M\")\n",
  "\n",
  "# Multi-GPU support (Kaggle T4 x2)\n",
+ "model_unwrapped = model\n",
  "if torch.cuda.device_count() > 1:\n",
  " print(f\"\\nUsing {torch.cuda.device_count()} GPUs with DataParallel!\")\n",
- "
- " model_unwrapped = model.module\n",
+ " model_dp = nn.DataParallel(model)\n",
  "else:\n",
- "
+ " model_dp = model\n",
  "\n",
  "# Quick memory test\n",
  "with torch.no_grad():\n",
@@ -952,17 +958,17 @@
  "\n",
  "\n",
  "@torch.no_grad()\n",
- "def generate_samples(
+ "def generate_samples(mdl, tokenizer, num_samples=4, seq_len=128, temperature=0.8):\n",
  " \"\"\"Generate and print text samples.\"\"\"\n",
- "
- " tokens =
+ " mdl.eval()\n",
+ " tokens = mdl.sample(num_samples, seq_len, temperature=temperature)\n",
  " texts = []\n",
  " for i in range(num_samples):\n",
  " text = tokenizer.decode(tokens[i].cpu().tolist(), skip_special_tokens=True)\n",
  " texts.append(text)\n",
  " print(f\"\\n--- Sample {i+1} ---\")\n",
  " print(text[:500])\n",
- "
+ " mdl.train()\n",
  " return texts\n",
  "\n",
  "\n",
@@ -1867,12 +1873,38 @@
  " tokens_processed += batch.numel()\n",
  "\n",
  " with autocast('cuda', dtype=torch.float16):\n",
- "
- "
+ " # Noise + mask on this batch\n",
+ " B, L = batch.shape\n",
+ " t = model_unwrapped.noise_schedule.sample_t(B, batch.device)\n",
+ " z_t, mask = model_unwrapped.noise_schedule.forward_process(batch, t, config.mask_token_id)\n",
+ "\n",
+ " # Forward pass through DataParallel (this splits across GPUs)\n",
+ " hidden = model_dp(z_t, t) # [B, L, D] \u2014 uses forward_hidden via DataParallel\n",
+ "\n",
+ " # Loss computation (cheap, single GPU is fine)\n",
+ " masked_hidden = hidden[mask]\n",
+ " masked_targets = batch[mask]\n",
+ "\n",
+ " if masked_hidden.shape[0] > 0:\n",
+ " masked_logits = F.linear(masked_hidden, model_unwrapped.output_proj.weight)\n",
+ " masked_logits[:, config.mask_token_id] = -1e9\n",
+ " ce_loss = F.cross_entropy(masked_logits, masked_targets, reduction='none')\n",
+ " weight = model_unwrapped.noise_schedule.loss_weight(t)\n",
+ " weight_expanded = weight[:, None].expand(B, L)[mask]\n",
+ " result_loss = (ce_loss * weight_expanded).mean()\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " preds = masked_logits.argmax(dim=-1)\n",
+ " result_acc = (preds == masked_targets).float().mean().item()\n",
+ " else:\n",
+ " result_loss = torch.tensor(0.0, device=batch.device)\n",
+ " result_acc = 1.0\n",
+ "\n",
+ " loss = result_loss / config.grad_accum_steps\n",
  "\n",
  " scaler.scale(loss).backward()\n",
- " step_loss +=
- " step_acc +=
+ " step_loss += result_loss.item() / config.grad_accum_steps\n",
+ " step_acc += result_acc / config.grad_accum_steps\n",
  "\n",
  " # Gradient clipping and optimizer step\n",
  " scaler.unscale_(optimizer)\n",
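For orientation only, and not part of the committed notebook: a minimal sketch of the pattern this change adopts, using a toy TinyDenoiser stand-in for the real model. forward returns hidden states so nn.DataParallel gathers only [B, L, D] tensors, while the vocabulary projection and masked cross-entropy run once outside the replicated forward; forward_full keeps the logits path for sampling. All names below are placeholders, not the notebook's actual classes.

# Hypothetical sketch; runs on CPU or GPU, no dependency on the notebook.
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyDenoiser(nn.Module):
    def __init__(self, vocab_size=100, d_model=32, mask_token_id=0):
        super().__init__()
        self.mask_token_id = mask_token_id
        self.embed = nn.Embedding(vocab_size, d_model)
        self.output_proj = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, z_t, t):
        # Hidden states only [B, L, D]: cheap for DataParallel to gather.
        return self.embed(z_t) + t[:, None, None]

    def forward_full(self, z_t, t):
        # Full logits [B, L, V] for sampling; mask token suppressed.
        logits = self.output_proj(self.forward(z_t, t))
        logits[:, :, self.mask_token_id] = -1e9
        return logits

model = TinyDenoiser()
model_dp = nn.DataParallel(model) if torch.cuda.device_count() > 1 else model

z_t = torch.randint(1, 100, (4, 16))          # noised tokens
t = torch.rand(4)                             # diffusion timesteps
targets = torch.randint(1, 100, (4, 16))      # clean tokens
mask = torch.rand(4, 16) < 0.5                # positions replaced by [MASK]

hidden = model_dp(z_t, t)                     # replicated across GPUs if available
masked_logits = F.linear(hidden[mask], model.output_proj.weight)
loss = F.cross_entropy(masked_logits, targets[mask])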