mohamedahraf273 committed on
Commit
35b065e
·
1 Parent(s): e8aab00

update notebook

Browse files
Files changed (1) hide show
  1. generator.ipynb +48 -170
generator.ipynb CHANGED
@@ -645,201 +645,79 @@
645
  },
646
  {
647
  "cell_type": "code",
648
- "execution_count": 15,
649
- "id": "a49bb85f",
650
  "metadata": {},
651
  "outputs": [
652
  {
653
  "name": "stdout",
654
  "output_type": "stream",
655
  "text": [
656
- "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n",
657
- "IMPORTANT: If you haven't re-run the TRAINING loop (Cell 9)\n",
658
- "after applying the Transpose fix, the results below will likely\n",
659
- "be poor/incomplete because the model hasn't updated its weights\n",
660
- "correctly yet.\n",
661
- "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n",
662
- "\n",
663
- "Running generation tests on validation set (True Greedy Decoding):\n",
664
- "\n",
665
- "Example 0:\n",
666
- "Input: [CLS:parallel_for] for (i = 0; i < 16; ++i)\n",
667
  " ;\n",
668
  "\n",
669
- "Target: omp target parallel for simd simdlen(4 4)\n",
670
- "Prediction: omp parallel for shared(,k,,,,,,,,,,,,,pr) shared(L,,,,,,,,,,,,,,,,\n",
671
- "------------------------------------------------------------\n",
672
- "Example 10:\n",
673
- "Input: [CLS:reduction] for (i = 1; i < (500 - 1); i++)\n",
674
- "{\n",
675
- " iIndex = i * dim2;\n",
676
- " jIndex = 0;\n",
677
- " for (j = 1; j < (500 - 1); j++)\n",
678
- " {\n",
679
- " jIndex += 500;\n",
680
- " for (k = 1; k < (500 - 1); k++)\n",
681
- " {\n",
682
- " index = (iIndex + jIndex) + k;\n",
683
- " compute_it = old[index] * need;\n",
684
- " aggregate += compute_it / gimmie;\n",
685
- " accumulator = 0;\n",
686
- " long subsum1 = 0;\n",
687
- " long subsum2 = 0;\n",
688
- " long subsum3 = 0;\n",
689
- " for (z = 0; z < 27; z += 3)\n",
690
- " {\n",
691
- " subsum1 += old[index + arr[z]];\n",
692
- " subsum2 += old[index + arr[z + 1]];\n",
693
- " subsum3 += old[index + arr[z + 2]];\n",
694
- " }\n",
695
- "\n",
696
- " accumulator += (subsum1 + subsum2) + subsum3;\n",
697
- " long value = accumulator / 27;\n",
698
- " int par = value / 100;\n",
699
- " a0 += ((unsigned) par) >> 31;\n",
700
- " a0 += !(par ^ 0);\n",
701
- " a1 += !(par ^ 1);\n",
702
- " a2 += !(par ^ 2);\n",
703
- " a3 += !(par ^ 3);\n",
704
- " a4 += !(par ^ 4);\n",
705
- " a5 += !(par ^ 5);\n",
706
- " a6 += !(par ^ 6);\n",
707
- " a7 += !(par ^ 7);\n",
708
- " a8 += !(par ^ 8);\n",
709
- " int64_t tmp = ((int64_t) par) - 9;\n",
710
- " a9 += (tmp >> 63) + 1;\n",
711
- " new[index] = value;\n",
712
- " }\n",
713
- "\n",
714
- " }\n",
715
- "\n",
716
- "}\n",
717
- "\n",
718
- "Target: omp parallel for private(j, k, z, accumulator, jIndex, index, iIndex, compute_it) reduction(+: aggregate, a0,a1,a2,a3,a4,a5,a6,a7,a8,a9)\n",
719
- "Prediction: omp parallel for reduction(+:data,,,,,,,,,,,,,,\n",
720
- "------------------------------------------------------------\n",
721
- "Example 20:\n",
722
- "Input: [CLS:parallel_for] for (i = 0; i < 16; ++i)\n",
723
- " ;\n",
724
- "\n",
725
- "Target: omp parallel for simd firstprivate(, )\n",
726
- "Prediction: omp parallel for shared(,k,,,,,,,,,,,,,pr) shared(L,,,,,,,,,,,,,,,,\n",
727
- "------------------------------------------------------------\n",
728
- "Example 30:\n",
729
- "Input: [CLS:parallel_for] for (i = 0; i < n; i++)\n",
730
- "{\n",
731
- " x[i] = 1.0;\n",
732
- " y[i] = 2.0;\n",
733
- "}\n",
734
- "\n",
735
- "Target: omp parallel for private(i)\n",
736
- "Prediction: omp parallel for shared(gen,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,\n",
737
- "------------------------------------------------------------\n"
738
  ]
739
  }
740
  ],
741
  "source": [
 
 
 
 
 
 
 
 
742
  "model.eval()\n",
 
743
  "\n",
744
- "def generate_sentence(model, input_text, tokenizer, max_len=150, device='cuda'):\n",
745
- " \"\"\"\n",
746
- " Greedy decoding function that generates tokens until <EOS> or max_len.\n",
747
- " This mimics the model's forward pass but allows dynamic length generation.\n",
748
- " \"\"\"\n",
749
  " model.eval()\n",
750
- " \n",
751
- " # Tokenize input\n",
752
- " input_ids = tokenizer.encode(input_text, max_length=500, add_special_tokens=True)\n",
753
- " src_tensor = torch.LongTensor(input_ids).unsqueeze(0).to(device) # [1, src_len]\n",
754
- " src_len = torch.LongTensor([len(input_ids)]).to(device) # [1]\n",
755
- " \n",
756
  " with torch.no_grad():\n",
757
- " # Encode\n",
758
- " encoder_outputs, hidden, cell = model.encoder(src_tensor, src_len)\n",
759
- " \n",
760
- " # Create mask (same logic as in Generator.forward)\n",
761
- " max_src_len = encoder_outputs.shape[1]\n",
762
- " mask = torch.arange(max_src_len, device=device).unsqueeze(0) < src_len.unsqueeze(1)\n",
763
- " mask = mask.float()\n",
764
- " \n",
765
- " # Project hidden/cell states from Encoder to Decoder size\n",
766
- " # Reshape to [num_layers, 2, batch, hidden] to combine bidirectional states\n",
767
  " hidden = hidden.view(model.encoder.num_layers, 2, 1, model.encoder.hidden_size)\n",
768
  " hidden = torch.cat((hidden[:, 0], hidden[:, 1]), dim=2)\n",
769
  " hidden = model.hidden_projection(hidden)\n",
770
- " \n",
771
  " cell = cell.view(model.encoder.num_layers, 2, 1, model.encoder.hidden_size)\n",
772
  " cell = torch.cat((cell[:, 0], cell[:, 1]), dim=2)\n",
773
  " cell = model.cell_projection(cell)\n",
774
- " \n",
775
- " # Start with <SOS>\n",
776
- " trg_indexes = [tokenizer.char2idx['<SOS>']]\n",
777
- " \n",
778
- " for i in range(max_len):\n",
779
- " trg_tensor = torch.LongTensor([trg_indexes[-1]]).to(device) # [1]\n",
780
- " \n",
781
- " output, hidden, cell, _ = model.decoder(\n",
782
- " trg_tensor, hidden, cell, encoder_outputs, mask\n",
783
- " )\n",
784
- " \n",
785
- " # Greedy prediction: take token with highest probability\n",
786
- " pred_token = output.argmax(1).item()\n",
787
- " trg_indexes.append(pred_token)\n",
788
- " \n",
789
- " if pred_token == tokenizer.char2idx['<EOS>']:\n",
790
- " break\n",
791
- " \n",
792
- " # Decode integers back to string\n",
793
- " return tokenizer.decode(trg_indexes)\n",
794
- "\n",
795
- "# ---------------------------------------------------------\n",
796
- "print(\"!\"*60)\n",
797
- "print(\"IMPORTANT: If you haven't re-run the TRAINING loop (Cell 9)\")\n",
798
- "print(\"after applying the Transpose fix, the results below will likely\")\n",
799
- "print(\"be poor/incomplete because the model hasn't updated its weights\")\n",
800
- "print(\"correctly yet.\")\n",
801
- "print(\"!\"*60 + \"\\n\")\n",
802
  "\n",
803
- "print(\"Running generation tests on validation set (True Greedy Decoding):\\n\")\n",
804
- "test_indices = [0, 10, 20, 30]\n",
805
- "# Ensure indices are within bounds\n",
806
- "test_indices = [i for i in test_indices if i < len(val_inputs)]\n",
807
- "\n",
808
- "for i in test_indices:\n",
809
- " input_text = val_inputs[i]\n",
810
- " target_text = val_outputs[i]\n",
811
- " \n",
812
- " prediction = generate_sentence(model, input_text, tokenizer, device=device)\n",
813
- " \n",
814
- " print(f\"Example {i}:\")\n",
815
- " print(f\"Input: {input_text}\")\n",
816
- " print(f\"Target: {target_text}\")\n",
817
- " print(f\"Prediction: {prediction}\")\n",
818
- " print(\"-\" * 60)"
819
- ]
820
- },
821
- {
822
- "cell_type": "code",
823
- "execution_count": null,
824
- "id": "85bd9571",
825
- "metadata": {},
826
- "outputs": [],
827
- "source": [
828
- "# ---------------------------------------------------------\n",
829
- "# RUN THIS CELL ONLY IF YOU WANT TO RESET TRAINING\n",
830
- "# This initializes the model weights from scratch. \n",
831
- "# Run this, and then run the TRAINING LOOP (Cell 9) again.\n",
832
- "# ---------------------------------------------------------\n",
833
- "\n",
834
- "print(\"↺ RESETTING MODEL & OPTIMIZER...\")\n",
835
- "model = Generator(encoder, decoder, device).to(device)\n",
836
- "model.apply(model._init_weights)\n",
837
  "\n",
838
- "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
839
- "training_history = {'train_loss': [], 'valid_loss': []}\n",
840
- "best_valid_loss = float('inf')\n",
841
  "\n",
842
- "print(\"✓ Model reset. Now scroll up and run the TRAINING LOOP again.\")"
 
 
 
 
 
 
843
  ]
844
  }
845
  ],
 
645
  },
646
  {
647
  "cell_type": "code",
648
+ "execution_count": null,
649
+ "id": "6d9a8e25",
650
  "metadata": {},
651
  "outputs": [
652
  {
653
  "name": "stdout",
654
  "output_type": "stream",
655
  "text": [
656
+ "Loaded checkpoint from best_model.pth (epoch 14)\n",
657
+ "Sample input (truncated): [CLS:parallel_for] for (i = 0; i < 16; ++i)\n",
 
 
 
 
 
 
 
 
 
658
  " ;\n",
659
  "\n",
660
+ "Reference pragma: omp target parallel for simd simdlen(4 4)\n",
661
+ "Model prediction: omp parallel for simd lastprivate(\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
662
  ]
663
  }
664
  ],
665
  "source": [
666
+ "import os\n",
667
+ "\n",
668
+ "checkpoint_path = \"best_model.pth\"\n",
669
+ "if not os.path.exists(checkpoint_path):\n",
670
+ " raise FileNotFoundError(\"Run training first so 'best_model.pth' exists.\")\n",
671
+ "\n",
672
+ "checkpoint = torch.load(checkpoint_path, map_location=device)\n",
673
+ "model.load_state_dict(checkpoint['model_state_dict'])\n",
674
  "model.eval()\n",
675
+ "print(f\"Loaded checkpoint from {checkpoint_path} (epoch {checkpoint.get('epoch', '?')})\")\n",
676
  "\n",
677
+ "SOS_IDX = tokenizer.char2idx['<SOS>']\n",
678
+ "EOS_IDX = tokenizer.char2idx['<EOS>']\n",
679
+ "\n",
680
+ "def greedy_generate(code_snippet: str, cls: str = \"parallel\", max_len: int = 80) -> str:\n",
681
+ " \"\"\"Greedy decode a pragma for a single code snippet.\"\"\"\n",
682
  " model.eval()\n",
683
+ " text = code_snippet if code_snippet.startswith(\"[CLS:\") else f\"[CLS:{cls}] {code_snippet}\"\n",
684
+ " input_ids = tokenizer.encode(text, max_length=500, add_special_tokens=True)\n",
685
+ " input_len = next((i for i, tok in enumerate(input_ids) if tok == PAD_IDX), len(input_ids))\n",
686
+ " input_tensor = torch.tensor([input_ids], device=device)\n",
687
+ " input_len_tensor = torch.tensor([input_len], device=device)\n",
688
+ "\n",
689
  " with torch.no_grad():\n",
690
+ " enc_outs, hidden, cell = model.encoder(input_tensor, input_len_tensor)\n",
691
+ " mask = (torch.arange(enc_outs.size(1), device=device).unsqueeze(0) < input_len_tensor.unsqueeze(1)).float()\n",
692
+ "\n",
 
 
 
 
 
 
 
693
  " hidden = hidden.view(model.encoder.num_layers, 2, 1, model.encoder.hidden_size)\n",
694
  " hidden = torch.cat((hidden[:, 0], hidden[:, 1]), dim=2)\n",
695
  " hidden = model.hidden_projection(hidden)\n",
696
+ "\n",
697
  " cell = cell.view(model.encoder.num_layers, 2, 1, model.encoder.hidden_size)\n",
698
  " cell = torch.cat((cell[:, 0], cell[:, 1]), dim=2)\n",
699
  " cell = model.cell_projection(cell)\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
700
  "\n",
701
+ " input_token = torch.tensor([SOS_IDX], device=device)\n",
702
+ " generated = []\n",
703
+ " for _ in range(max_len):\n",
704
+ " output, hidden, cell, _ = model.decoder(input_token, hidden, cell, enc_outs, mask)\n",
705
+ " top1 = output.argmax(1)\n",
706
+ " token_id = top1.item()\n",
707
+ " if token_id == EOS_IDX:\n",
708
+ " break\n",
709
+ " generated.append(token_id)\n",
710
+ " input_token = top1\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
711
  "\n",
712
+ " return tokenizer.decode(generated)\n",
 
 
713
  "\n",
714
+ "# Quick sanity check on a validation example\n",
715
+ "sample_input = val_inputs[0]\n",
716
+ "reference = val_outputs[0]\n",
717
+ "prediction = greedy_generate(sample_input)\n",
718
+ "print(\"Sample input (truncated):\", sample_input[:140] + \"...\" if len(sample_input) > 140 else sample_input)\n",
719
+ "print(\"Reference pragma:\", reference)\n",
720
+ "print(\"Model prediction:\", prediction)\n"
721
  ]
722
  }
723
  ],