chipling committed
Commit 1c89b07 · verified · 1 Parent(s): 5f202d0

Upload main.ipynb with huggingface_hub

Files changed (1)
  1. main.ipynb +46 -14
main.ipynb CHANGED
@@ -521,7 +521,13 @@
  " return self.final_norm(x)\n",
  "\n",
  " def forward(self, z_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n",
- " \"\"\"Full forward pass returning logits [B, L, V].\"\"\"\n",
+ " \"\"\"Forward pass returning hidden states [B, L, D].\n",
+ " Used by DataParallel \u2014 logit projection done outside for memory efficiency.\n",
+ " For full logits (sampling), use forward_full().\"\"\"\n",
+ " return self.forward_hidden(z_t, t)\n",
+ "\n",
+ " def forward_full(self, z_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n",
+ " \"\"\"Full forward pass returning logits [B, L, V]. Used for sampling.\"\"\"\n",
  " hidden = self.forward_hidden(z_t, t)\n",
  " logits = self.output_proj(hidden)\n",
  " logits[:, :, self.config.mask_token_id] = -1e9\n",
@@ -600,7 +606,7 @@
  " alpha_next = self.noise_schedule.alpha(t_next)\n",
  "\n",
  " t_batch = torch.full((batch_size,), t_now.item(), device=device)\n",
- " logits = self.forward(x, t_batch)\n",
+ " logits = self.forward_full(x, t_batch)\n",
  " probs = F.softmax(logits / temperature, dim=-1)\n",
  "\n",
  " unmask_prob = ((alpha_next - alpha_now) / (1.0 - alpha_now + 1e-8)).clamp(0, 1)\n",
@@ -616,7 +622,7 @@
  " is_masked = (x == self.config.mask_token_id)\n",
  " if is_masked.any():\n",
  " t_batch = torch.full((batch_size,), 1e-5, device=device)\n",
- " logits = self.forward(x, t_batch)\n",
+ " logits = self.forward_full(x, t_batch)\n",
  " probs = F.softmax(logits / temperature, dim=-1)\n",
  " flat_probs = probs.reshape(-1, self.config.vocab_size)\n",
  " sampled = torch.multinomial(flat_probs, 1).reshape(batch_size, seq_len)\n",
@@ -633,12 +639,12 @@
  "print(f\"Unique parameters (weight tying): {unique_params / 1e6:.1f}M\")\n",
  "\n",
  "# Multi-GPU support (Kaggle T4 x2)\n",
+ "model_unwrapped = model\n",
  "if torch.cuda.device_count() > 1:\n",
  " print(f\"\\nUsing {torch.cuda.device_count()} GPUs with DataParallel!\")\n",
- " model = nn.DataParallel(model)\n",
- " model_unwrapped = model.module\n",
+ " model_dp = nn.DataParallel(model)\n",
  "else:\n",
- " model_unwrapped = model\n",
+ " model_dp = model\n",
  "\n",
  "# Quick memory test\n",
  "with torch.no_grad():\n",
@@ -952,17 +958,17 @@
  "\n",
  "\n",
  "@torch.no_grad()\n",
- "def generate_samples(model_unwrapped, tokenizer, num_samples=4, seq_len=128, temperature=0.8):\n",
+ "def generate_samples(mdl, tokenizer, num_samples=4, seq_len=128, temperature=0.8):\n",
  " \"\"\"Generate and print text samples.\"\"\"\n",
- " model.eval()\n",
- " tokens = model.sample(num_samples, seq_len, temperature=temperature)\n",
+ " mdl.eval()\n",
+ " tokens = mdl.sample(num_samples, seq_len, temperature=temperature)\n",
  " texts = []\n",
  " for i in range(num_samples):\n",
  " text = tokenizer.decode(tokens[i].cpu().tolist(), skip_special_tokens=True)\n",
  " texts.append(text)\n",
  " print(f\"\\n--- Sample {i+1} ---\")\n",
  " print(text[:500])\n",
- " model.train()\n",
+ " mdl.train()\n",
  " return texts\n",
  "\n",
  "\n",
@@ -1867,12 +1873,38 @@
  " tokens_processed += batch.numel()\n",
  "\n",
  " with autocast('cuda', dtype=torch.float16):\n",
- " result = model_unwrapped.compute_loss(batch)\n",
- " loss = result['loss'] / config.grad_accum_steps\n",
+ " # Noise + mask on this batch\n",
+ " B, L = batch.shape\n",
+ " t = model_unwrapped.noise_schedule.sample_t(B, batch.device)\n",
+ " z_t, mask = model_unwrapped.noise_schedule.forward_process(batch, t, config.mask_token_id)\n",
+ "\n",
+ " # Forward pass through DataParallel (this splits across GPUs)\n",
+ " hidden = model_dp(z_t, t) # [B, L, D] \u2014 uses forward_hidden via DataParallel\n",
+ "\n",
+ " # Loss computation (cheap, single GPU is fine)\n",
+ " masked_hidden = hidden[mask]\n",
+ " masked_targets = batch[mask]\n",
+ "\n",
+ " if masked_hidden.shape[0] > 0:\n",
+ " masked_logits = F.linear(masked_hidden, model_unwrapped.output_proj.weight)\n",
+ " masked_logits[:, config.mask_token_id] = -1e9\n",
+ " ce_loss = F.cross_entropy(masked_logits, masked_targets, reduction='none')\n",
+ " weight = model_unwrapped.noise_schedule.loss_weight(t)\n",
+ " weight_expanded = weight[:, None].expand(B, L)[mask]\n",
+ " result_loss = (ce_loss * weight_expanded).mean()\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " preds = masked_logits.argmax(dim=-1)\n",
+ " result_acc = (preds == masked_targets).float().mean().item()\n",
+ " else:\n",
+ " result_loss = torch.tensor(0.0, device=batch.device)\n",
+ " result_acc = 1.0\n",
+ "\n",
+ " loss = result_loss / config.grad_accum_steps\n",
  "\n",
  " scaler.scale(loss).backward()\n",
- " step_loss += result['loss'].item() / config.grad_accum_steps\n",
- " step_acc += result['accuracy'].item() / config.grad_accum_steps\n",
+ " step_loss += result_loss.item() / config.grad_accum_steps\n",
+ " step_acc += result_acc / config.grad_accum_steps\n",
  "\n",
  " # Gradient clipping and optimizer step\n",
  " scaler.unscale_(optimizer)\n",