Spaces:
Sleeping
Sleeping
Commit ·
35b065e
1
Parent(s): e8aab00
update notebook
Browse files- generator.ipynb +48 -170
generator.ipynb
CHANGED
|
@@ -645,201 +645,79 @@
|
|
| 645 |
},
|
| 646 |
{
|
| 647 |
"cell_type": "code",
|
| 648 |
-
"execution_count":
|
| 649 |
-
"id": "
|
| 650 |
"metadata": {},
|
| 651 |
"outputs": [
|
| 652 |
{
|
| 653 |
"name": "stdout",
|
| 654 |
"output_type": "stream",
|
| 655 |
"text": [
|
| 656 |
-
"
|
| 657 |
-
"
|
| 658 |
-
"after applying the Transpose fix, the results below will likely\n",
|
| 659 |
-
"be poor/incomplete because the model hasn't updated its weights\n",
|
| 660 |
-
"correctly yet.\n",
|
| 661 |
-
"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n",
|
| 662 |
-
"\n",
|
| 663 |
-
"Running generation tests on validation set (True Greedy Decoding):\n",
|
| 664 |
-
"\n",
|
| 665 |
-
"Example 0:\n",
|
| 666 |
-
"Input: [CLS:parallel_for] for (i = 0; i < 16; ++i)\n",
|
| 667 |
" ;\n",
|
| 668 |
"\n",
|
| 669 |
-
"
|
| 670 |
-
"
|
| 671 |
-
"------------------------------------------------------------\n",
|
| 672 |
-
"Example 10:\n",
|
| 673 |
-
"Input: [CLS:reduction] for (i = 1; i < (500 - 1); i++)\n",
|
| 674 |
-
"{\n",
|
| 675 |
-
" iIndex = i * dim2;\n",
|
| 676 |
-
" jIndex = 0;\n",
|
| 677 |
-
" for (j = 1; j < (500 - 1); j++)\n",
|
| 678 |
-
" {\n",
|
| 679 |
-
" jIndex += 500;\n",
|
| 680 |
-
" for (k = 1; k < (500 - 1); k++)\n",
|
| 681 |
-
" {\n",
|
| 682 |
-
" index = (iIndex + jIndex) + k;\n",
|
| 683 |
-
" compute_it = old[index] * need;\n",
|
| 684 |
-
" aggregate += compute_it / gimmie;\n",
|
| 685 |
-
" accumulator = 0;\n",
|
| 686 |
-
" long subsum1 = 0;\n",
|
| 687 |
-
" long subsum2 = 0;\n",
|
| 688 |
-
" long subsum3 = 0;\n",
|
| 689 |
-
" for (z = 0; z < 27; z += 3)\n",
|
| 690 |
-
" {\n",
|
| 691 |
-
" subsum1 += old[index + arr[z]];\n",
|
| 692 |
-
" subsum2 += old[index + arr[z + 1]];\n",
|
| 693 |
-
" subsum3 += old[index + arr[z + 2]];\n",
|
| 694 |
-
" }\n",
|
| 695 |
-
"\n",
|
| 696 |
-
" accumulator += (subsum1 + subsum2) + subsum3;\n",
|
| 697 |
-
" long value = accumulator / 27;\n",
|
| 698 |
-
" int par = value / 100;\n",
|
| 699 |
-
" a0 += ((unsigned) par) >> 31;\n",
|
| 700 |
-
" a0 += !(par ^ 0);\n",
|
| 701 |
-
" a1 += !(par ^ 1);\n",
|
| 702 |
-
" a2 += !(par ^ 2);\n",
|
| 703 |
-
" a3 += !(par ^ 3);\n",
|
| 704 |
-
" a4 += !(par ^ 4);\n",
|
| 705 |
-
" a5 += !(par ^ 5);\n",
|
| 706 |
-
" a6 += !(par ^ 6);\n",
|
| 707 |
-
" a7 += !(par ^ 7);\n",
|
| 708 |
-
" a8 += !(par ^ 8);\n",
|
| 709 |
-
" int64_t tmp = ((int64_t) par) - 9;\n",
|
| 710 |
-
" a9 += (tmp >> 63) + 1;\n",
|
| 711 |
-
" new[index] = value;\n",
|
| 712 |
-
" }\n",
|
| 713 |
-
"\n",
|
| 714 |
-
" }\n",
|
| 715 |
-
"\n",
|
| 716 |
-
"}\n",
|
| 717 |
-
"\n",
|
| 718 |
-
"Target: omp parallel for private(j, k, z, accumulator, jIndex, index, iIndex, compute_it) reduction(+: aggregate, a0,a1,a2,a3,a4,a5,a6,a7,a8,a9)\n",
|
| 719 |
-
"Prediction: omp parallel for reduction(+:data,,,,,,,,,,,,,,\n",
|
| 720 |
-
"------------------------------------------------------------\n",
|
| 721 |
-
"Example 20:\n",
|
| 722 |
-
"Input: [CLS:parallel_for] for (i = 0; i < 16; ++i)\n",
|
| 723 |
-
" ;\n",
|
| 724 |
-
"\n",
|
| 725 |
-
"Target: omp parallel for simd firstprivate(, )\n",
|
| 726 |
-
"Prediction: omp parallel for shared(,k,,,,,,,,,,,,,pr) shared(L,,,,,,,,,,,,,,,,\n",
|
| 727 |
-
"------------------------------------------------------------\n",
|
| 728 |
-
"Example 30:\n",
|
| 729 |
-
"Input: [CLS:parallel_for] for (i = 0; i < n; i++)\n",
|
| 730 |
-
"{\n",
|
| 731 |
-
" x[i] = 1.0;\n",
|
| 732 |
-
" y[i] = 2.0;\n",
|
| 733 |
-
"}\n",
|
| 734 |
-
"\n",
|
| 735 |
-
"Target: omp parallel for private(i)\n",
|
| 736 |
-
"Prediction: omp parallel for shared(gen,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,\n",
|
| 737 |
-
"------------------------------------------------------------\n"
|
| 738 |
]
|
| 739 |
}
|
| 740 |
],
|
| 741 |
"source": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 742 |
"model.eval()\n",
|
|
|
|
| 743 |
"\n",
|
| 744 |
-
"
|
| 745 |
-
"
|
| 746 |
-
"
|
| 747 |
-
"
|
| 748 |
-
" \"\"\"\n",
|
| 749 |
" model.eval()\n",
|
| 750 |
-
" \n",
|
| 751 |
-
"
|
| 752 |
-
"
|
| 753 |
-
"
|
| 754 |
-
"
|
| 755 |
-
"
|
| 756 |
" with torch.no_grad():\n",
|
| 757 |
-
"
|
| 758 |
-
"
|
| 759 |
-
"
|
| 760 |
-
" # Create mask (same logic as in Generator.forward)\n",
|
| 761 |
-
" max_src_len = encoder_outputs.shape[1]\n",
|
| 762 |
-
" mask = torch.arange(max_src_len, device=device).unsqueeze(0) < src_len.unsqueeze(1)\n",
|
| 763 |
-
" mask = mask.float()\n",
|
| 764 |
-
" \n",
|
| 765 |
-
" # Project hidden/cell states from Encoder to Decoder size\n",
|
| 766 |
-
" # Reshape to [num_layers, 2, batch, hidden] to combine bidirectional states\n",
|
| 767 |
" hidden = hidden.view(model.encoder.num_layers, 2, 1, model.encoder.hidden_size)\n",
|
| 768 |
" hidden = torch.cat((hidden[:, 0], hidden[:, 1]), dim=2)\n",
|
| 769 |
" hidden = model.hidden_projection(hidden)\n",
|
| 770 |
-
"
|
| 771 |
" cell = cell.view(model.encoder.num_layers, 2, 1, model.encoder.hidden_size)\n",
|
| 772 |
" cell = torch.cat((cell[:, 0], cell[:, 1]), dim=2)\n",
|
| 773 |
" cell = model.cell_projection(cell)\n",
|
| 774 |
-
" \n",
|
| 775 |
-
" # Start with <SOS>\n",
|
| 776 |
-
" trg_indexes = [tokenizer.char2idx['<SOS>']]\n",
|
| 777 |
-
" \n",
|
| 778 |
-
" for i in range(max_len):\n",
|
| 779 |
-
" trg_tensor = torch.LongTensor([trg_indexes[-1]]).to(device) # [1]\n",
|
| 780 |
-
" \n",
|
| 781 |
-
" output, hidden, cell, _ = model.decoder(\n",
|
| 782 |
-
" trg_tensor, hidden, cell, encoder_outputs, mask\n",
|
| 783 |
-
" )\n",
|
| 784 |
-
" \n",
|
| 785 |
-
" # Greedy prediction: take token with highest probability\n",
|
| 786 |
-
" pred_token = output.argmax(1).item()\n",
|
| 787 |
-
" trg_indexes.append(pred_token)\n",
|
| 788 |
-
" \n",
|
| 789 |
-
" if pred_token == tokenizer.char2idx['<EOS>']:\n",
|
| 790 |
-
" break\n",
|
| 791 |
-
" \n",
|
| 792 |
-
" # Decode integers back to string\n",
|
| 793 |
-
" return tokenizer.decode(trg_indexes)\n",
|
| 794 |
-
"\n",
|
| 795 |
-
"# ---------------------------------------------------------\n",
|
| 796 |
-
"print(\"!\"*60)\n",
|
| 797 |
-
"print(\"IMPORTANT: If you haven't re-run the TRAINING loop (Cell 9)\")\n",
|
| 798 |
-
"print(\"after applying the Transpose fix, the results below will likely\")\n",
|
| 799 |
-
"print(\"be poor/incomplete because the model hasn't updated its weights\")\n",
|
| 800 |
-
"print(\"correctly yet.\")\n",
|
| 801 |
-
"print(\"!\"*60 + \"\\n\")\n",
|
| 802 |
"\n",
|
| 803 |
-
"
|
| 804 |
-
"
|
| 805 |
-
"
|
| 806 |
-
"
|
| 807 |
-
"\n",
|
| 808 |
-
"
|
| 809 |
-
"
|
| 810 |
-
"
|
| 811 |
-
"
|
| 812 |
-
"
|
| 813 |
-
" \n",
|
| 814 |
-
" print(f\"Example {i}:\")\n",
|
| 815 |
-
" print(f\"Input: {input_text}\")\n",
|
| 816 |
-
" print(f\"Target: {target_text}\")\n",
|
| 817 |
-
" print(f\"Prediction: {prediction}\")\n",
|
| 818 |
-
" print(\"-\" * 60)"
|
| 819 |
-
]
|
| 820 |
-
},
|
| 821 |
-
{
|
| 822 |
-
"cell_type": "code",
|
| 823 |
-
"execution_count": null,
|
| 824 |
-
"id": "85bd9571",
|
| 825 |
-
"metadata": {},
|
| 826 |
-
"outputs": [],
|
| 827 |
-
"source": [
|
| 828 |
-
"# ---------------------------------------------------------\n",
|
| 829 |
-
"# RUN THIS CELL ONLY IF YOU WANT TO RESET TRAINING\n",
|
| 830 |
-
"# This initializes the model weights from scratch. \n",
|
| 831 |
-
"# Run this, and then run the TRAINING LOOP (Cell 9) again.\n",
|
| 832 |
-
"# ---------------------------------------------------------\n",
|
| 833 |
-
"\n",
|
| 834 |
-
"print(\"↺ RESETTING MODEL & OPTIMIZER...\")\n",
|
| 835 |
-
"model = Generator(encoder, decoder, device).to(device)\n",
|
| 836 |
-
"model.apply(model._init_weights)\n",
|
| 837 |
"\n",
|
| 838 |
-
"
|
| 839 |
-
"training_history = {'train_loss': [], 'valid_loss': []}\n",
|
| 840 |
-
"best_valid_loss = float('inf')\n",
|
| 841 |
"\n",
|
| 842 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 843 |
]
|
| 844 |
}
|
| 845 |
],
|
|
|
|
| 645 |
},
|
| 646 |
{
|
| 647 |
"cell_type": "code",
|
| 648 |
+
"execution_count": null,
|
| 649 |
+
"id": "6d9a8e25",
|
| 650 |
"metadata": {},
|
| 651 |
"outputs": [
|
| 652 |
{
|
| 653 |
"name": "stdout",
|
| 654 |
"output_type": "stream",
|
| 655 |
"text": [
|
| 656 |
+
"Loaded checkpoint from best_model.pth (epoch 14)\n",
|
| 657 |
+
"Sample input (truncated): [CLS:parallel_for] for (i = 0; i < 16; ++i)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 658 |
" ;\n",
|
| 659 |
"\n",
|
| 660 |
+
"Reference pragma: omp target parallel for simd simdlen(4 4)\n",
|
| 661 |
+
"Model prediction: omp parallel for simd lastprivate(\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 662 |
]
|
| 663 |
}
|
| 664 |
],
|
| 665 |
"source": [
|
| 666 |
+
"import os\n",
|
| 667 |
+
"\n",
|
| 668 |
+
"checkpoint_path = \"best_model.pth\"\n",
|
| 669 |
+
"if not os.path.exists(checkpoint_path):\n",
|
| 670 |
+
" raise FileNotFoundError(\"Run training first so 'best_model.pth' exists.\")\n",
|
| 671 |
+
"\n",
|
| 672 |
+
"checkpoint = torch.load(checkpoint_path, map_location=device)\n",
|
| 673 |
+
"model.load_state_dict(checkpoint['model_state_dict'])\n",
|
| 674 |
"model.eval()\n",
|
| 675 |
+
"print(f\"Loaded checkpoint from {checkpoint_path} (epoch {checkpoint.get('epoch', '?')})\")\n",
|
| 676 |
"\n",
|
| 677 |
+
"SOS_IDX = tokenizer.char2idx['<SOS>']\n",
|
| 678 |
+
"EOS_IDX = tokenizer.char2idx['<EOS>']\n",
|
| 679 |
+
"\n",
|
| 680 |
+
"def greedy_generate(code_snippet: str, cls: str = \"parallel\", max_len: int = 80) -> str:\n",
|
| 681 |
+
" \"\"\"Greedy decode a pragma for a single code snippet.\"\"\"\n",
|
| 682 |
" model.eval()\n",
|
| 683 |
+
" text = code_snippet if code_snippet.startswith(\"[CLS:\") else f\"[CLS:{cls}] {code_snippet}\"\n",
|
| 684 |
+
" input_ids = tokenizer.encode(text, max_length=500, add_special_tokens=True)\n",
|
| 685 |
+
" input_len = next((i for i, tok in enumerate(input_ids) if tok == PAD_IDX), len(input_ids))\n",
|
| 686 |
+
" input_tensor = torch.tensor([input_ids], device=device)\n",
|
| 687 |
+
" input_len_tensor = torch.tensor([input_len], device=device)\n",
|
| 688 |
+
"\n",
|
| 689 |
" with torch.no_grad():\n",
|
| 690 |
+
" enc_outs, hidden, cell = model.encoder(input_tensor, input_len_tensor)\n",
|
| 691 |
+
" mask = (torch.arange(enc_outs.size(1), device=device).unsqueeze(0) < input_len_tensor.unsqueeze(1)).float()\n",
|
| 692 |
+
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 693 |
" hidden = hidden.view(model.encoder.num_layers, 2, 1, model.encoder.hidden_size)\n",
|
| 694 |
" hidden = torch.cat((hidden[:, 0], hidden[:, 1]), dim=2)\n",
|
| 695 |
" hidden = model.hidden_projection(hidden)\n",
|
| 696 |
+
"\n",
|
| 697 |
" cell = cell.view(model.encoder.num_layers, 2, 1, model.encoder.hidden_size)\n",
|
| 698 |
" cell = torch.cat((cell[:, 0], cell[:, 1]), dim=2)\n",
|
| 699 |
" cell = model.cell_projection(cell)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 700 |
"\n",
|
| 701 |
+
" input_token = torch.tensor([SOS_IDX], device=device)\n",
|
| 702 |
+
" generated = []\n",
|
| 703 |
+
" for _ in range(max_len):\n",
|
| 704 |
+
" output, hidden, cell, _ = model.decoder(input_token, hidden, cell, enc_outs, mask)\n",
|
| 705 |
+
" top1 = output.argmax(1)\n",
|
| 706 |
+
" token_id = top1.item()\n",
|
| 707 |
+
" if token_id == EOS_IDX:\n",
|
| 708 |
+
" break\n",
|
| 709 |
+
" generated.append(token_id)\n",
|
| 710 |
+
" input_token = top1\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 711 |
"\n",
|
| 712 |
+
" return tokenizer.decode(generated)\n",
|
|
|
|
|
|
|
| 713 |
"\n",
|
| 714 |
+
"# Quick sanity check on a validation example\n",
|
| 715 |
+
"sample_input = val_inputs[0]\n",
|
| 716 |
+
"reference = val_outputs[0]\n",
|
| 717 |
+
"prediction = greedy_generate(sample_input)\n",
|
| 718 |
+
"print(\"Sample input (truncated):\", sample_input[:140] + \"...\" if len(sample_input) > 140 else sample_input)\n",
|
| 719 |
+
"print(\"Reference pragma:\", reference)\n",
|
| 720 |
+
"print(\"Model prediction:\", prediction)\n"
|
| 721 |
]
|
| 722 |
}
|
| 723 |
],
|