mohamedahraf273 commited on
Commit
3e4a1d2
·
1 Parent(s): 27130aa

update gen

Browse files
Files changed (1) hide show
  1. generator.ipynb +34 -331
generator.ipynb CHANGED
@@ -2,7 +2,7 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 6,
6
  "id": "bae751d8",
7
  "metadata": {},
8
  "outputs": [],
@@ -26,30 +26,10 @@
26
  },
27
  {
28
  "cell_type": "code",
29
- "execution_count": 7,
30
  "id": "c0e30f61",
31
  "metadata": {},
32
- "outputs": [
33
- {
34
- "name": "stdout",
35
- "output_type": "stream",
36
- "text": [
37
- "BPE Tokenizer loaded from tokenizer.json\n",
38
- " - Vocab size: 8002\n",
39
- " - BPE merges: 7888\n"
40
- ]
41
- },
42
- {
43
- "data": {
44
- "text/plain": [
45
- "<tokenizer.Tokenizer at 0x7d2bbbafcb90>"
46
- ]
47
- },
48
- "execution_count": 7,
49
- "metadata": {},
50
- "output_type": "execute_result"
51
- }
52
- ],
53
  "source": [
54
  "tokenizer = Tokenizer(vocab_size=8000)\n",
55
  "tokenizer.load(\"tokenizer.json\")"
@@ -57,28 +37,10 @@
57
  },
58
  {
59
  "cell_type": "code",
60
- "execution_count": 8,
61
  "id": "db130c45",
62
  "metadata": {},
63
- "outputs": [
64
- {
65
- "name": "stdout",
66
- "output_type": "stream",
67
- "text": [
68
- "Training samples: 15671\n",
69
- "Validation samples: 1684\n",
70
- "\n",
71
- "Sample input (first 70 chars):\n",
72
- "[CLS:parallel_for] for (int ix = 1; ix < (N + 1); ix++)\n",
73
- "{\n",
74
- " forces[ix] = forces[ix] * force_retention;\n",
75
- "}\n",
76
- "\n",
77
- "Sample output:\n",
78
- "omp parallel for\n"
79
- ]
80
- }
81
- ],
82
  "source": [
83
  "train_inputs, train_outputs = [], []\n",
84
  "val_inputs, val_outputs = [], []\n",
@@ -134,34 +96,21 @@
134
  },
135
  {
136
  "cell_type": "code",
137
- "execution_count": 9,
138
  "id": "d5747915",
139
  "metadata": {},
140
- "outputs": [
141
- {
142
- "name": "stdout",
143
- "output_type": "stream",
144
- "text": [
145
- "\n",
146
- "Dataset shapes:\n",
147
- " Train: 15671 samples\n",
148
- " Val: 1684 samples\n",
149
- " Sample input tensor shape: torch.Size([500])\n",
150
- " Sample output tensor shape: torch.Size([100])\n"
151
- ]
152
- }
153
- ],
154
  "source": [
155
  "train_dataset = OpenMPDataset(\n",
156
  " train_inputs, train_outputs, tokenizer,\n",
157
- " max_input_len=500,\n",
158
- " max_output_len=100\n",
159
  ")\n",
160
  "\n",
161
  "val_dataset = OpenMPDataset(\n",
162
  " val_inputs, val_outputs, tokenizer,\n",
163
- " max_input_len=500,\n",
164
- " max_output_len=100\n",
165
  ")\n",
166
  "\n",
167
  "print(f\"\\nDataset shapes:\")\n",
@@ -173,38 +122,21 @@
173
  },
174
  {
175
  "cell_type": "code",
176
- "execution_count": 10,
177
  "id": "5252d457",
178
  "metadata": {},
179
- "outputs": [
180
- {
181
- "name": "stdout",
182
- "output_type": "stream",
183
- "text": [
184
- "\n",
185
- "✓ Dataloaders ready!\n",
186
- " Train batches: 490\n",
187
- " Val batches: 53\n",
188
- "\n",
189
- "Sample batch structure:\n",
190
- " input shape: torch.Size([32, 500])\n",
191
- " output shape: torch.Size([32, 100])\n",
192
- " input_len shape: torch.Size([32])\n",
193
- " First sample input_len: 12\n"
194
- ]
195
- }
196
- ],
197
  "source": [
198
  "train_loader = DataLoader(\n",
199
  " train_dataset,\n",
200
- " batch_size=32,\n",
201
  " shuffle=True,\n",
202
  " pin_memory=True\n",
203
  ")\n",
204
  "\n",
205
  "val_loader = DataLoader(\n",
206
  " val_dataset,\n",
207
- " batch_size=32,\n",
208
  " shuffle=False,\n",
209
  " pin_memory=True\n",
210
  ")\n",
@@ -223,49 +155,17 @@
223
  },
224
  {
225
  "cell_type": "code",
226
- "execution_count": 11,
227
  "id": "11631bed",
228
  "metadata": {},
229
- "outputs": [
230
- {
231
- "name": "stdout",
232
- "output_type": "stream",
233
- "text": [
234
- "Model architecture:\n",
235
- "Generator(\n",
236
- " (encoder): Encoder(\n",
237
- " (embedding): Embedding(8002, 128, padding_idx=0)\n",
238
- " (lstm): LSTM(128, 256, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)\n",
239
- " (dropout): Dropout(p=0.3, inplace=False)\n",
240
- " (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
241
- " )\n",
242
- " (decoder): Decoder(\n",
243
- " (attention): BahdanauAttention(\n",
244
- " (W1): Linear(in_features=512, out_features=256, bias=True)\n",
245
- " (W2): Linear(in_features=256, out_features=256, bias=True)\n",
246
- " (V): Linear(in_features=256, out_features=1, bias=True)\n",
247
- " )\n",
248
- " (embedding): Embedding(8002, 128, padding_idx=0)\n",
249
- " (lstm): LSTM(640, 256, num_layers=2, batch_first=True, dropout=0.3)\n",
250
- " (fc_out): Linear(in_features=896, out_features=8002, bias=True)\n",
251
- " (dropout): Dropout(p=0.3, inplace=False)\n",
252
- " (layer_norm): LayerNorm((896,), eps=1e-05, elementwise_affine=True)\n",
253
- " )\n",
254
- " (hidden_projection): Linear(in_features=512, out_features=256, bias=True)\n",
255
- " (cell_projection): Linear(in_features=512, out_features=256, bias=True)\n",
256
- ")\n",
257
- "\n",
258
- "Total parameters: 13,502,531\n"
259
- ]
260
- }
261
- ],
262
  "source": [
263
  "\n",
264
  "VOCAB_SIZE = tokenizer.vocab_size\n",
265
  "EMBED_SIZE = 128\n",
266
  "HIDDEN_SIZE = 256\n",
267
- "NUM_LAYERS = 2\n",
268
- "DROPOUT = 0.3\n",
269
  "\n",
270
  "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
271
  "\n",
@@ -282,7 +182,7 @@
282
  },
283
  {
284
  "cell_type": "code",
285
- "execution_count": 12,
286
  "id": "2d3125a6",
287
  "metadata": {},
288
  "outputs": [],
@@ -300,7 +200,7 @@
300
  },
301
  {
302
  "cell_type": "code",
303
- "execution_count": 13,
304
  "id": "794c40e7",
305
  "metadata": {},
306
  "outputs": [],
@@ -357,198 +257,10 @@
357
  },
358
  {
359
  "cell_type": "code",
360
- "execution_count": 14,
361
  "id": "d4bb0e92",
362
  "metadata": {},
363
- "outputs": [
364
- {
365
- "name": "stderr",
366
- "output_type": "stream",
367
- "text": [
368
- " \r"
369
- ]
370
- },
371
- {
372
- "name": "stdout",
373
- "output_type": "stream",
374
- "text": [
375
- "Epoch: 01/25 | Time: 8m 1s | TF Ratio: 0.50\n",
376
- "\tTrain Loss: 4.1408 | Val Loss: 3.8033 | Best Val: 3.8033 ✓ SAVED\n"
377
- ]
378
- },
379
- {
380
- "name": "stderr",
381
- "output_type": "stream",
382
- "text": [
383
- " \r"
384
- ]
385
- },
386
- {
387
- "name": "stdout",
388
- "output_type": "stream",
389
- "text": [
390
- "Epoch: 02/25 | Time: 7m 41s | TF Ratio: 0.45\n",
391
- "\tTrain Loss: 3.0543 | Val Loss: 3.5220 | Best Val: 3.5220 ✓ SAVED\n"
392
- ]
393
- },
394
- {
395
- "name": "stderr",
396
- "output_type": "stream",
397
- "text": [
398
- " \r"
399
- ]
400
- },
401
- {
402
- "name": "stdout",
403
- "output_type": "stream",
404
- "text": [
405
- "Epoch: 03/25 | Time: 7m 40s | TF Ratio: 0.41\n",
406
- "\tTrain Loss: 2.6443 | Val Loss: 3.2353 | Best Val: 3.2353 ✓ SAVED\n"
407
- ]
408
- },
409
- {
410
- "name": "stderr",
411
- "output_type": "stream",
412
- "text": [
413
- " \r"
414
- ]
415
- },
416
- {
417
- "name": "stdout",
418
- "output_type": "stream",
419
- "text": [
420
- "Epoch: 04/25 | Time: 7m 44s | TF Ratio: 0.36\n",
421
- "\tTrain Loss: 2.3818 | Val Loss: 3.1132 | Best Val: 3.1132 ✓ SAVED\n"
422
- ]
423
- },
424
- {
425
- "name": "stderr",
426
- "output_type": "stream",
427
- "text": [
428
- " \r"
429
- ]
430
- },
431
- {
432
- "name": "stdout",
433
- "output_type": "stream",
434
- "text": [
435
- "Epoch: 05/25 | Time: 7m 42s | TF Ratio: 0.33\n",
436
- "\tTrain Loss: 2.2041 | Val Loss: 2.9274 | Best Val: 2.9274 ✓ SAVED\n"
437
- ]
438
- },
439
- {
440
- "name": "stderr",
441
- "output_type": "stream",
442
- "text": [
443
- " \r"
444
- ]
445
- },
446
- {
447
- "name": "stdout",
448
- "output_type": "stream",
449
- "text": [
450
- "Epoch: 06/25 | Time: 7m 36s | TF Ratio: 0.30\n",
451
- "\tTrain Loss: 2.0576 | Val Loss: 2.8356 | Best Val: 2.8356 ✓ SAVED\n"
452
- ]
453
- },
454
- {
455
- "name": "stderr",
456
- "output_type": "stream",
457
- "text": [
458
- " \r"
459
- ]
460
- },
461
- {
462
- "name": "stdout",
463
- "output_type": "stream",
464
- "text": [
465
- "Epoch: 07/25 | Time: 7m 41s | TF Ratio: 0.27\n",
466
- "\tTrain Loss: 1.9377 | Val Loss: 2.8092 | Best Val: 2.8092 ✓ SAVED\n"
467
- ]
468
- },
469
- {
470
- "name": "stderr",
471
- "output_type": "stream",
472
- "text": [
473
- " \r"
474
- ]
475
- },
476
- {
477
- "name": "stdout",
478
- "output_type": "stream",
479
- "text": [
480
- "Epoch: 08/25 | Time: 7m 39s | TF Ratio: 0.24\n",
481
- "\tTrain Loss: 1.8034 | Val Loss: 2.8102 | Best Val: 2.8092 \n"
482
- ]
483
- },
484
- {
485
- "name": "stderr",
486
- "output_type": "stream",
487
- "text": [
488
- " \r"
489
- ]
490
- },
491
- {
492
- "name": "stdout",
493
- "output_type": "stream",
494
- "text": [
495
- "Epoch: 09/25 | Time: 7m 39s | TF Ratio: 0.22\n",
496
- "\tTrain Loss: 1.7125 | Val Loss: 2.7772 | Best Val: 2.7772 ✓ SAVED\n"
497
- ]
498
- },
499
- {
500
- "name": "stderr",
501
- "output_type": "stream",
502
- "text": [
503
- " \r"
504
- ]
505
- },
506
- {
507
- "name": "stdout",
508
- "output_type": "stream",
509
- "text": [
510
- "Epoch: 10/25 | Time: 7m 38s | TF Ratio: 0.19\n",
511
- "\tTrain Loss: 1.6454 | Val Loss: 2.8247 | Best Val: 2.7772 \n"
512
- ]
513
- },
514
- {
515
- "name": "stderr",
516
- "output_type": "stream",
517
- "text": [
518
- " \r"
519
- ]
520
- },
521
- {
522
- "name": "stdout",
523
- "output_type": "stream",
524
- "text": [
525
- "Epoch: 11/25 | Time: 7m 42s | TF Ratio: 0.17\n",
526
- "\tTrain Loss: 1.5686 | Val Loss: 2.8969 | Best Val: 2.7772 \n"
527
- ]
528
- },
529
- {
530
- "name": "stderr",
531
- "output_type": "stream",
532
- "text": [
533
- " \r"
534
- ]
535
- },
536
- {
537
- "ename": "KeyboardInterrupt",
538
- "evalue": "",
539
- "output_type": "error",
540
- "traceback": [
541
- "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
542
- "\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)",
543
- "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[14]\u001b[39m\u001b[32m, line 10\u001b[39m\n\u001b[32m 7\u001b[39m start_time = time.time()\n\u001b[32m 9\u001b[39m tf_ratio = \u001b[38;5;28mmax\u001b[39m(\u001b[32m0.1\u001b[39m, \u001b[32m0.5\u001b[39m * (\u001b[32m0.9\u001b[39m ** epoch))\n\u001b[32m---> \u001b[39m\u001b[32m10\u001b[39m train_loss = \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_loader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcriterion\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mCLIP\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtf_ratio\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 11\u001b[39m valid_loss = evaluate(model, val_loader, criterion)\n\u001b[32m 12\u001b[39m scheduler.step(valid_loss)\n",
544
- "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[13]\u001b[39m\u001b[32m, line 17\u001b[39m, in \u001b[36mtrain\u001b[39m\u001b[34m(model, iterator, optimizer, criterion, clip, teacher_forcing_ratio)\u001b[39m\n\u001b[32m 14\u001b[39m trg = trg[\u001b[32m1\u001b[39m:].reshape(-\u001b[32m1\u001b[39m)\n\u001b[32m 16\u001b[39m loss = criterion(output, trg)\n\u001b[32m---> \u001b[39m\u001b[32m17\u001b[39m \u001b[43mloss\u001b[49m\u001b[43m.\u001b[49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 19\u001b[39m torch.nn.utils.clip_grad_norm_(model.parameters(), clip)\n\u001b[32m 21\u001b[39m optimizer.step()\n",
545
- "\u001b[36mFile \u001b[39m\u001b[32m~/Desktop/projects/env/lib/python3.12/site-packages/torch/_tensor.py:630\u001b[39m, in \u001b[36mTensor.backward\u001b[39m\u001b[34m(self, gradient, retain_graph, create_graph, inputs)\u001b[39m\n\u001b[32m 620\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[32m 621\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[32m 622\u001b[39m Tensor.backward,\n\u001b[32m 623\u001b[39m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[32m (...)\u001b[39m\u001b[32m 628\u001b[39m inputs=inputs,\n\u001b[32m 629\u001b[39m )\n\u001b[32m--> \u001b[39m\u001b[32m630\u001b[39m \u001b[43mtorch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mautograd\u001b[49m\u001b[43m.\u001b[49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 631\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m=\u001b[49m\u001b[43minputs\u001b[49m\n\u001b[32m 632\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
546
- "\u001b[36mFile \u001b[39m\u001b[32m~/Desktop/projects/env/lib/python3.12/site-packages/torch/autograd/__init__.py:364\u001b[39m, in \u001b[36mbackward\u001b[39m\u001b[34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[39m\n\u001b[32m 359\u001b[39m retain_graph = create_graph\n\u001b[32m 361\u001b[39m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[32m 362\u001b[39m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[32m 363\u001b[39m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m364\u001b[39m \u001b[43m_engine_run_backward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 365\u001b[39m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 366\u001b[39m \u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 367\u001b[39m \u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 368\u001b[39m \u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 369\u001b[39m \u001b[43m \u001b[49m\u001b[43minputs_tuple\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 370\u001b[39m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[32m 371\u001b[39m \u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[32m 372\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
547
- "\u001b[36mFile \u001b[39m\u001b[32m~/Desktop/projects/env/lib/python3.12/site-packages/torch/autograd/graph.py:865\u001b[39m, in \u001b[36m_engine_run_backward\u001b[39m\u001b[34m(t_outputs, *args, **kwargs)\u001b[39m\n\u001b[32m 863\u001b[39m unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)\n\u001b[32m 864\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m865\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mVariable\u001b[49m\u001b[43m.\u001b[49m\u001b[43m_execution_engine\u001b[49m\u001b[43m.\u001b[49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[32m 866\u001b[39m \u001b[43m \u001b[49m\u001b[43mt_outputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\n\u001b[32m 867\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[32m 868\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[32m 869\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m attach_logging_hooks:\n",
548
- "\u001b[31mKeyboardInterrupt\u001b[39m: "
549
- ]
550
- }
551
- ],
552
  "source": [
553
  "EPOCHS = 25\n",
554
  "CLIP = 1.0\n",
@@ -597,26 +309,15 @@
597
  },
598
  {
599
  "cell_type": "code",
600
- "execution_count": 32,
601
  "id": "6d9a8e25",
602
  "metadata": {},
603
- "outputs": [
604
- {
605
- "name": "stdout",
606
- "output_type": "stream",
607
- "text": [
608
- "Loaded checkpoint from best_model.pth (epoch 8)\n",
609
- "Sample input (truncated): [CLS:reduction] for (i = 0; i < 1000; ++i)\n",
610
- "{\n",
611
- " logic_and = logic_and && logics[i];\n",
612
- "}\n",
613
- "\n",
614
- "Reference pragma: omp parallel for schedule(dynamic,1) private(i) reduction(&&:logic_and)\n",
615
- "Model prediction: omp parallel for schedule(dynamic,1) private(i) reduction(&&:logic_and)\n"
616
- ]
617
- }
618
- ],
619
  "source": [
 
 
 
 
620
  "import os\n",
621
  "\n",
622
  "checkpoint_path = \"best_model.pth\"\n",
@@ -631,8 +332,8 @@
631
  "SOS_IDX = tokenizer.char2idx['<SOS>']\n",
632
  "EOS_IDX = tokenizer.char2idx['<EOS>']\n",
633
  "\n",
 
634
  "def greedy_generate(code_snippet: str, cls: str = \"parallel\", max_len: int = 80) -> str:\n",
635
- " \"\"\"Greedy decode a pragma for a single code snippet.\"\"\"\n",
636
  " model.eval()\n",
637
  " text = code_snippet if code_snippet.startswith(\"[CLS:\") else f\"[CLS:{cls}] {code_snippet}\"\n",
638
  " input_ids = tokenizer.encode(text, max_length=500, add_special_tokens=True)\n",
@@ -665,13 +366,15 @@
665
  "\n",
666
  " return tokenizer.decode(generated)\n",
667
  "\n",
 
 
668
  "# Quick sanity check on a validation example\n",
669
  "sample_input = val_inputs[18]\n",
670
  "reference = val_outputs[18]\n",
671
- "prediction = greedy_generate(sample_input)\n",
672
  "print(\"Sample input (truncated):\", sample_input[:140] + \"...\" if len(sample_input) > 140 else sample_input)\n",
673
  "print(\"Reference pragma:\", reference)\n",
674
- "print(\"Model prediction:\", prediction)\n"
675
  ]
676
  }
677
  ],
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": null,
6
  "id": "bae751d8",
7
  "metadata": {},
8
  "outputs": [],
 
26
  },
27
  {
28
  "cell_type": "code",
29
+ "execution_count": null,
30
  "id": "c0e30f61",
31
  "metadata": {},
32
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  "source": [
34
  "tokenizer = Tokenizer(vocab_size=8000)\n",
35
  "tokenizer.load(\"tokenizer.json\")"
 
37
  },
38
  {
39
  "cell_type": "code",
40
+ "execution_count": null,
41
  "id": "db130c45",
42
  "metadata": {},
43
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  "source": [
45
  "train_inputs, train_outputs = [], []\n",
46
  "val_inputs, val_outputs = [], []\n",
 
96
  },
97
  {
98
  "cell_type": "code",
99
+ "execution_count": null,
100
  "id": "d5747915",
101
  "metadata": {},
102
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  "source": [
104
  "train_dataset = OpenMPDataset(\n",
105
  " train_inputs, train_outputs, tokenizer,\n",
106
+ " max_input_len=1500,\n",
107
+ " max_output_len=300\n",
108
  ")\n",
109
  "\n",
110
  "val_dataset = OpenMPDataset(\n",
111
  " val_inputs, val_outputs, tokenizer,\n",
112
+ " max_input_len=1500,\n",
113
+ " max_output_len=300\n",
114
  ")\n",
115
  "\n",
116
  "print(f\"\\nDataset shapes:\")\n",
 
122
  },
123
  {
124
  "cell_type": "code",
125
+ "execution_count": null,
126
  "id": "5252d457",
127
  "metadata": {},
128
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  "source": [
130
  "train_loader = DataLoader(\n",
131
  " train_dataset,\n",
132
+ " batch_size=8,\n",
133
  " shuffle=True,\n",
134
  " pin_memory=True\n",
135
  ")\n",
136
  "\n",
137
  "val_loader = DataLoader(\n",
138
  " val_dataset,\n",
139
+ " batch_size=8,\n",
140
  " shuffle=False,\n",
141
  " pin_memory=True\n",
142
  ")\n",
 
155
  },
156
  {
157
  "cell_type": "code",
158
+ "execution_count": null,
159
  "id": "11631bed",
160
  "metadata": {},
161
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  "source": [
163
  "\n",
164
  "VOCAB_SIZE = tokenizer.vocab_size\n",
165
  "EMBED_SIZE = 128\n",
166
  "HIDDEN_SIZE = 256\n",
167
+ "NUM_LAYERS = 3\n",
168
+ "DROPOUT = 0.2\n",
169
  "\n",
170
  "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
171
  "\n",
 
182
  },
183
  {
184
  "cell_type": "code",
185
+ "execution_count": null,
186
  "id": "2d3125a6",
187
  "metadata": {},
188
  "outputs": [],
 
200
  },
201
  {
202
  "cell_type": "code",
203
+ "execution_count": null,
204
  "id": "794c40e7",
205
  "metadata": {},
206
  "outputs": [],
 
257
  },
258
  {
259
  "cell_type": "code",
260
+ "execution_count": null,
261
  "id": "d4bb0e92",
262
  "metadata": {},
263
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  "source": [
265
  "EPOCHS = 25\n",
266
  "CLIP = 1.0\n",
 
309
  },
310
  {
311
  "cell_type": "code",
312
+ "execution_count": null,
313
  "id": "6d9a8e25",
314
  "metadata": {},
315
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
  "source": [
317
+ "\n",
318
+ "import sys\n",
319
+ "import pathlib\n",
320
+ "sys.path.append(str(pathlib.Path().resolve())) # ensure local modules are importable\n",
321
  "import os\n",
322
  "\n",
323
  "checkpoint_path = \"best_model.pth\"\n",
 
332
  "SOS_IDX = tokenizer.char2idx['<SOS>']\n",
333
  "EOS_IDX = tokenizer.char2idx['<EOS>']\n",
334
  "\n",
335
+ "# Greedy baseline (kept for comparison)\n",
336
  "def greedy_generate(code_snippet: str, cls: str = \"parallel\", max_len: int = 80) -> str:\n",
 
337
  " model.eval()\n",
338
  " text = code_snippet if code_snippet.startswith(\"[CLS:\") else f\"[CLS:{cls}] {code_snippet}\"\n",
339
  " input_ids = tokenizer.encode(text, max_length=500, add_special_tokens=True)\n",
 
366
  "\n",
367
  " return tokenizer.decode(generated)\n",
368
  "\n",
369
+ "\n",
370
+ "\n",
371
  "# Quick sanity check on a validation example\n",
372
  "sample_input = val_inputs[18]\n",
373
  "reference = val_outputs[18]\n",
374
+ "prediction_greedy = greedy_generate(sample_input)\n",
375
  "print(\"Sample input (truncated):\", sample_input[:140] + \"...\" if len(sample_input) > 140 else sample_input)\n",
376
  "print(\"Reference pragma:\", reference)\n",
377
+ "print(\"Greedy prediction:\", prediction_greedy)"
378
  ]
379
  }
380
  ],