mohamedahraf273 committed on
Commit
e8aab00
·
1 Parent(s): 94ee9c6

add generator

Browse files
best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c87d759052debb4e4adb62ef51c9d65671d04bfc6e1f9fd4b2130c66e69b9257
3
+ size 162038291
dataset.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import torch
from torch.utils.data import Dataset


class OpenMPDataset(Dataset):
    """Paired (code, pragma) samples tokenized to fixed-length id tensors.

    Each item bundles the padded input/output token-id sequences together
    with their unpadded lengths, where a length is the index of the first
    <PAD> token (or the configured maximum when no padding is present).
    """

    def __init__(self, inputs, outputs, tokenizer, max_input_len=500, max_output_len=100):
        self.inputs = inputs
        self.outputs = outputs
        self.tokenizer = tokenizer
        self.max_input_len = max_input_len
        self.max_output_len = max_output_len
        # Padding id comes from the tokenizer's own vocabulary.
        self.pad_idx = tokenizer.char2idx['<PAD>']

    def __len__(self):
        return len(self.inputs)

    def _unpadded_length(self, ids, fallback):
        """Index of the first pad token in *ids*, or *fallback* if none."""
        for pos, tok_id in enumerate(ids):
            if tok_id == self.pad_idx:
                return pos
        return fallback

    def __getitem__(self, idx):
        input_ids = self.tokenizer.encode(
            self.inputs[idx],
            self.max_input_len,
            add_special_tokens=True
        )
        output_ids = self.tokenizer.encode(
            self.outputs[idx],
            self.max_output_len,
            add_special_tokens=True
        )

        in_len = self._unpadded_length(input_ids, self.max_input_len)
        out_len = self._unpadded_length(output_ids, self.max_output_len)

        return {
            'input': torch.tensor(input_ids, dtype=torch.long),
            'output': torch.tensor(output_ids, dtype=torch.long),
            'input_len': torch.tensor(in_len, dtype=torch.long),
            'output_len': torch.tensor(out_len, dtype=torch.long)
        }
generator.ipynb ADDED
@@ -0,0 +1,867 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "id": "bae751d8",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import json\n",
11
+ "import torch\n",
12
+ "import torch.nn as nn\n",
13
+ "import torch.optim as optim\n",
14
+ "import time\n",
15
+ "from tqdm import tqdm\n",
16
+ "\n",
17
+ "from torch.utils.data import DataLoader\n",
18
+ "from models.open_mp_gen.tokenizer import Tokenizer\n",
19
+ "from models.open_mp_gen.model.generator import Generator\n",
20
+ "from models.open_mp_gen.model.encoder import Encoder\n",
21
+ "from models.open_mp_gen.model.decoder import Decoder\n",
22
+ "from models.open_mp_gen.model.attn import BahdanauAttention\n",
23
+ "from models.open_mp_gen.dataset import OpenMPDataset\n",
24
+ "from accelera.src.utils.code_utils import pragma_to_class"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": 3,
30
+ "id": "c0e30f61",
31
+ "metadata": {},
32
+ "outputs": [
33
+ {
34
+ "name": "stdout",
35
+ "output_type": "stream",
36
+ "text": [
37
+ "BPE Tokenizer loaded from tokenizer.json\n",
38
+ " - Vocab size: 8002\n",
39
+ " - BPE merges: 7888\n"
40
+ ]
41
+ },
42
+ {
43
+ "data": {
44
+ "text/plain": [
45
+ "<models.open_mp_gen.tokenizer.Tokenizer at 0x7a60237a60c0>"
46
+ ]
47
+ },
48
+ "execution_count": 3,
49
+ "metadata": {},
50
+ "output_type": "execute_result"
51
+ }
52
+ ],
53
+ "source": [
54
+ "tokenizer = Tokenizer(vocab_size=8000)\n",
55
+ "tokenizer.load(\"tokenizer.json\")"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": 4,
61
+ "id": "db130c45",
62
+ "metadata": {},
63
+ "outputs": [
64
+ {
65
+ "name": "stdout",
66
+ "output_type": "stream",
67
+ "text": [
68
+ "Training samples: 15671\n",
69
+ "Validation samples: 1684\n",
70
+ "\n",
71
+ "Sample input (first 70 chars):\n",
72
+ "[CLS:parallel_for] for (int ix = 1; ix < (N + 1); ix++)\n",
73
+ "{\n",
74
+ " forces[ix] = forces[ix] * force_retention;\n",
75
+ "}\n",
76
+ "\n",
77
+ "Sample output:\n",
78
+ "omp parallel for\n"
79
+ ]
80
+ }
81
+ ],
82
+ "source": [
83
+ "train_inputs, train_outputs = [], []\n",
84
+ "val_inputs, val_outputs = [], []\n",
85
+ "\n",
86
+ "with open('../data/data.json', 'r') as f:\n",
87
+ " lines = f.readlines()\n",
88
+ " \n",
89
+ " split_idx = int(0.9 * len(lines))\n",
90
+ " train_lines = lines[:split_idx]\n",
91
+ " val_lines = lines[split_idx:]\n",
92
+ "\n",
93
+ "for line in train_lines:\n",
94
+ " item = json.loads(line.strip())\n",
95
+ " \n",
96
+ " if item['label'] == 'False':\n",
97
+ " continue\n",
98
+ " \n",
99
+ " cls = pragma_to_class(item['label'], item['pragma'])\n",
100
+ " if cls == 'none':\n",
101
+ " continue\n",
102
+ " \n",
103
+ " input_str = f\"[CLS:{cls}] {item['code']}\"\n",
104
+ " output_str = item['pragma'].strip()\n",
105
+ " \n",
106
+ " if not output_str:\n",
107
+ " continue\n",
108
+ " \n",
109
+ " train_inputs.append(input_str)\n",
110
+ " train_outputs.append(output_str)\n",
111
+ "\n",
112
+ "for line in val_lines:\n",
113
+ " item = json.loads(line.strip())\n",
114
+ " if item['label'] == 'False':\n",
115
+ " continue\n",
116
+ " \n",
117
+ " cls = pragma_to_class(item['label'], item['pragma'])\n",
118
+ " if cls == 'none':\n",
119
+ " continue\n",
120
+ " \n",
121
+ " input_str = f\"[CLS:{cls}] {item['code']}\"\n",
122
+ " output_str = item['pragma'].strip()\n",
123
+ " if not output_str:\n",
124
+ " continue\n",
125
+ " \n",
126
+ " val_inputs.append(input_str)\n",
127
+ " val_outputs.append(output_str)\n",
128
+ "\n",
129
+ "print(f\"Training samples: {len(train_inputs)}\")\n",
130
+ "print(f\"Validation samples: {len(val_inputs)}\")\n",
131
+ "print(f\"\\nSample input (first 70 chars):\\n{train_inputs[0]}\")\n",
132
+ "print(f\"Sample output:\\n{train_outputs[0]}\")"
133
+ ]
134
+ },
135
+ {
136
+ "cell_type": "code",
137
+ "execution_count": 5,
138
+ "id": "d5747915",
139
+ "metadata": {},
140
+ "outputs": [
141
+ {
142
+ "name": "stdout",
143
+ "output_type": "stream",
144
+ "text": [
145
+ "\n",
146
+ "Dataset shapes:\n",
147
+ " Train: 15671 samples\n",
148
+ " Val: 1684 samples\n",
149
+ " Sample input tensor shape: torch.Size([500])\n",
150
+ " Sample output tensor shape: torch.Size([100])\n"
151
+ ]
152
+ }
153
+ ],
154
+ "source": [
155
+ "train_dataset = OpenMPDataset(\n",
156
+ " train_inputs, train_outputs, tokenizer,\n",
157
+ " max_input_len=500,\n",
158
+ " max_output_len=100\n",
159
+ ")\n",
160
+ "\n",
161
+ "val_dataset = OpenMPDataset(\n",
162
+ " val_inputs, val_outputs, tokenizer,\n",
163
+ " max_input_len=500,\n",
164
+ " max_output_len=100\n",
165
+ ")\n",
166
+ "\n",
167
+ "print(f\"\\nDataset shapes:\")\n",
168
+ "print(f\" Train: {len(train_dataset)} samples\")\n",
169
+ "print(f\" Val: {len(val_dataset)} samples\")\n",
170
+ "print(f\" Sample input tensor shape: {train_dataset[0]['input'].shape}\")\n",
171
+ "print(f\" Sample output tensor shape: {train_dataset[0]['output'].shape}\")"
172
+ ]
173
+ },
174
+ {
175
+ "cell_type": "code",
176
+ "execution_count": 6,
177
+ "id": "5252d457",
178
+ "metadata": {},
179
+ "outputs": [
180
+ {
181
+ "name": "stdout",
182
+ "output_type": "stream",
183
+ "text": [
184
+ "\n",
185
+ "✓ Dataloaders ready!\n",
186
+ " Train batches: 490\n",
187
+ " Val batches: 53\n",
188
+ "\n",
189
+ "Sample batch structure:\n",
190
+ " input shape: torch.Size([32, 500])\n",
191
+ " output shape: torch.Size([32, 100])\n",
192
+ " input_len shape: torch.Size([32])\n",
193
+ " First sample input_len: 16\n",
194
+ "\n",
195
+ "Sample batch structure:\n",
196
+ " input shape: torch.Size([32, 500])\n",
197
+ " output shape: torch.Size([32, 100])\n",
198
+ " input_len shape: torch.Size([32])\n",
199
+ " First sample input_len: 16\n"
200
+ ]
201
+ }
202
+ ],
203
+ "source": [
204
+ "train_loader = DataLoader(\n",
205
+ " train_dataset,\n",
206
+ " batch_size=32,\n",
207
+ " shuffle=True,\n",
208
+ " pin_memory=True\n",
209
+ ")\n",
210
+ "\n",
211
+ "val_loader = DataLoader(\n",
212
+ " val_dataset,\n",
213
+ " batch_size=32,\n",
214
+ " shuffle=False,\n",
215
+ " pin_memory=True\n",
216
+ ")\n",
217
+ "\n",
218
+ "print(f\"\\n✓ Dataloaders ready!\")\n",
219
+ "print(f\" Train batches: {len(train_loader)}\")\n",
220
+ "print(f\" Val batches: {len(val_loader)}\")\n",
221
+ "\n",
222
+ "sample_batch = next(iter(train_loader))\n",
223
+ "print(f\"\\nSample batch structure:\")\n",
224
+ "print(f\" input shape: {sample_batch['input'].shape}\")\n",
225
+ "print(f\" output shape: {sample_batch['output'].shape}\")\n",
226
+ "print(f\" input_len shape: {sample_batch['input_len'].shape}\")\n",
227
+ "print(f\" First sample input_len: {sample_batch['input_len'][0]}\")"
228
+ ]
229
+ },
230
+ {
231
+ "cell_type": "code",
232
+ "execution_count": 7,
233
+ "id": "11631bed",
234
+ "metadata": {},
235
+ "outputs": [
236
+ {
237
+ "name": "stdout",
238
+ "output_type": "stream",
239
+ "text": [
240
+ "Model architecture:\n",
241
+ "Generator(\n",
242
+ " (encoder): Encoder(\n",
243
+ " (embedding): Embedding(8002, 128, padding_idx=0)\n",
244
+ " (lstm): LSTM(128, 256, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)\n",
245
+ " (dropout): Dropout(p=0.3, inplace=False)\n",
246
+ " )\n",
247
+ " (decoder): Decoder(\n",
248
+ " (attention): BahdanauAttention(\n",
249
+ " (W1): Linear(in_features=512, out_features=256, bias=True)\n",
250
+ " (W2): Linear(in_features=256, out_features=256, bias=True)\n",
251
+ " (V): Linear(in_features=256, out_features=1, bias=True)\n",
252
+ " )\n",
253
+ " (embedding): Embedding(8002, 128, padding_idx=0)\n",
254
+ " (lstm): LSTM(640, 256, num_layers=2, batch_first=True, dropout=0.3)\n",
255
+ " (fc_out): Linear(in_features=896, out_features=8002, bias=True)\n",
256
+ " (dropout): Dropout(p=0.3, inplace=False)\n",
257
+ " )\n",
258
+ " (hidden_projection): Linear(in_features=512, out_features=256, bias=True)\n",
259
+ " (cell_projection): Linear(in_features=512, out_features=256, bias=True)\n",
260
+ ")\n",
261
+ "\n",
262
+ "Total parameters: 13,499,715\n"
263
+ ]
264
+ }
265
+ ],
266
+ "source": [
267
+ "\n",
268
+ "VOCAB_SIZE = tokenizer.vocab_size\n",
269
+ "EMBED_SIZE = 128\n",
270
+ "HIDDEN_SIZE = 256\n",
271
+ "NUM_LAYERS = 2\n",
272
+ "DROPOUT = 0.3\n",
273
+ "\n",
274
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
275
+ "\n",
276
+ "encoder = Encoder(VOCAB_SIZE, EMBED_SIZE, HIDDEN_SIZE, NUM_LAYERS, DROPOUT)\n",
277
+ "attention = BahdanauAttention(HIDDEN_SIZE)\n",
278
+ "decoder = Decoder(VOCAB_SIZE, EMBED_SIZE, HIDDEN_SIZE, attention, NUM_LAYERS, DROPOUT)\n",
279
+ "model = Generator(encoder, decoder, device).to(device)\n",
280
+ "model.apply(model._init_weights)\n",
281
+ "\n",
282
+ "print(\"Model architecture:\")\n",
283
+ "print(model)\n",
284
+ "print(f\"\\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}\")"
285
+ ]
286
+ },
287
+ {
288
+ "cell_type": "code",
289
+ "execution_count": 9,
290
+ "id": "2d3125a6",
291
+ "metadata": {},
292
+ "outputs": [],
293
+ "source": [
294
+ "PAD_IDX = tokenizer.char2idx['<PAD>']\n",
295
+ "criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)\n",
296
+ "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
297
+ "scheduler = optim.lr_scheduler.ReduceLROnPlateau(\n",
298
+ " optimizer, \n",
299
+ " mode='min', \n",
300
+ " factor=0.5, \n",
301
+ " patience=2, \n",
302
+ ")"
303
+ ]
304
+ },
305
+ {
306
+ "cell_type": "code",
307
+ "execution_count": null,
308
+ "id": "794c40e7",
309
+ "metadata": {},
310
+ "outputs": [],
311
+ "source": [
312
+ "def train(model, iterator, optimizer, criterion, clip=1.0, teacher_forcing_ratio=0.5):\n",
313
+ " model.train()\n",
314
+ " epoch_loss = 0\n",
315
+ " \n",
316
+ " for batch in tqdm(iterator, desc=\"Training\", leave=False):\n",
317
+ " src = batch['input'].to(device)\n",
318
+ " trg = batch['output'].to(device)\n",
319
+ " src_len = batch['input_len'].to(device)\n",
320
+ " optimizer.zero_grad()\n",
321
+ " output = model(src, src_len, trg, teacher_forcing_ratio)\n",
322
+ " output_dim = output.shape[-1]\n",
323
+ " output = output[1:].view(-1, output_dim)\n",
324
+ " trg = trg.transpose(0, 1) \n",
325
+ " trg = trg[1:].reshape(-1)\n",
326
+ " \n",
327
+ " loss = criterion(output, trg)\n",
328
+ " loss.backward()\n",
329
+ " \n",
330
+ " torch.nn.utils.clip_grad_norm_(model.parameters(), clip)\n",
331
+ " \n",
332
+ " optimizer.step()\n",
333
+ " epoch_loss += loss.item()\n",
334
+ " \n",
335
+ " return epoch_loss / len(iterator)\n",
336
+ "\n",
337
+ "\n",
338
+ "def evaluate(model, iterator, criterion):\n",
339
+ " model.eval()\n",
340
+ " epoch_loss = 0\n",
341
+ " \n",
342
+ " with torch.no_grad():\n",
343
+ " for batch in tqdm(iterator, desc=\"Evaluating\", leave=False):\n",
344
+ " src = batch['input'].to(device)\n",
345
+ " trg = batch['output'].to(device)\n",
346
+ " src_len = batch['input_len'].to(device)\n",
347
+ " \n",
348
+ " output = model(src, src_len, trg, 0)\n",
349
+ " \n",
350
+ " output_dim = output.shape[-1]\n",
351
+ " output = output[1:].view(-1, output_dim)\n",
352
+ " \n",
353
+ " trg = trg.transpose(0, 1)\n",
354
+ " trg = trg[1:].reshape(-1)\n",
355
+ " \n",
356
+ " loss = criterion(output, trg)\n",
357
+ " epoch_loss += loss.item()\n",
358
+ " \n",
359
+ " return epoch_loss / len(iterator)"
360
+ ]
361
+ },
362
+ {
363
+ "cell_type": "code",
364
+ "execution_count": 11,
365
+ "id": "d4bb0e92",
366
+ "metadata": {},
367
+ "outputs": [
368
+ {
369
+ "name": "stderr",
370
+ "output_type": "stream",
371
+ "text": [
372
+ " "
373
+ ]
374
+ },
375
+ {
376
+ "name": "stdout",
377
+ "output_type": "stream",
378
+ "text": [
379
+ "Epoch: 01/15 | Time: 7m 39s | TF Ratio: 0.50\n",
380
+ "\tTrain Loss: 4.5316 | Val Loss: 4.2697 | Best Val: 4.2697 ✓ SAVED\n"
381
+ ]
382
+ },
383
+ {
384
+ "name": "stderr",
385
+ "output_type": "stream",
386
+ "text": [
387
+ " "
388
+ ]
389
+ },
390
+ {
391
+ "name": "stdout",
392
+ "output_type": "stream",
393
+ "text": [
394
+ "Epoch: 02/15 | Time: 7m 33s | TF Ratio: 0.45\n",
395
+ "\tTrain Loss: 3.6810 | Val Loss: 4.0286 | Best Val: 4.0286 ✓ SAVED\n"
396
+ ]
397
+ },
398
+ {
399
+ "name": "stderr",
400
+ "output_type": "stream",
401
+ "text": [
402
+ " "
403
+ ]
404
+ },
405
+ {
406
+ "name": "stdout",
407
+ "output_type": "stream",
408
+ "text": [
409
+ "Epoch: 03/15 | Time: 7m 40s | TF Ratio: 0.41\n",
410
+ "\tTrain Loss: 3.4275 | Val Loss: 3.8817 | Best Val: 3.8817 ✓ SAVED\n"
411
+ ]
412
+ },
413
+ {
414
+ "name": "stderr",
415
+ "output_type": "stream",
416
+ "text": [
417
+ " "
418
+ ]
419
+ },
420
+ {
421
+ "name": "stdout",
422
+ "output_type": "stream",
423
+ "text": [
424
+ "Epoch: 04/15 | Time: 7m 40s | TF Ratio: 0.36\n",
425
+ "\tTrain Loss: 3.2257 | Val Loss: 3.7254 | Best Val: 3.7254 ✓ SAVED\n"
426
+ ]
427
+ },
428
+ {
429
+ "name": "stderr",
430
+ "output_type": "stream",
431
+ "text": [
432
+ " "
433
+ ]
434
+ },
435
+ {
436
+ "name": "stdout",
437
+ "output_type": "stream",
438
+ "text": [
439
+ "Epoch: 05/15 | Time: 7m 38s | TF Ratio: 0.33\n",
440
+ "\tTrain Loss: 3.0585 | Val Loss: 3.6210 | Best Val: 3.6210 ✓ SAVED\n"
441
+ ]
442
+ },
443
+ {
444
+ "name": "stderr",
445
+ "output_type": "stream",
446
+ "text": [
447
+ " "
448
+ ]
449
+ },
450
+ {
451
+ "name": "stdout",
452
+ "output_type": "stream",
453
+ "text": [
454
+ "Epoch: 06/15 | Time: 7m 37s | TF Ratio: 0.30\n",
455
+ "\tTrain Loss: 2.9102 | Val Loss: 3.4103 | Best Val: 3.4103 ✓ SAVED\n"
456
+ ]
457
+ },
458
+ {
459
+ "name": "stderr",
460
+ "output_type": "stream",
461
+ "text": [
462
+ " "
463
+ ]
464
+ },
465
+ {
466
+ "name": "stdout",
467
+ "output_type": "stream",
468
+ "text": [
469
+ "Epoch: 07/15 | Time: 7m 39s | TF Ratio: 0.27\n",
470
+ "\tTrain Loss: 2.7814 | Val Loss: 3.3304 | Best Val: 3.3304 ✓ SAVED\n"
471
+ ]
472
+ },
473
+ {
474
+ "name": "stderr",
475
+ "output_type": "stream",
476
+ "text": [
477
+ " "
478
+ ]
479
+ },
480
+ {
481
+ "name": "stdout",
482
+ "output_type": "stream",
483
+ "text": [
484
+ "Epoch: 08/15 | Time: 7m 38s | TF Ratio: 0.24\n",
485
+ "\tTrain Loss: 2.6669 | Val Loss: 3.2644 | Best Val: 3.2644 ✓ SAVED\n"
486
+ ]
487
+ },
488
+ {
489
+ "name": "stderr",
490
+ "output_type": "stream",
491
+ "text": [
492
+ " "
493
+ ]
494
+ },
495
+ {
496
+ "name": "stdout",
497
+ "output_type": "stream",
498
+ "text": [
499
+ "Epoch: 09/15 | Time: 7m 38s | TF Ratio: 0.22\n",
500
+ "\tTrain Loss: 2.5686 | Val Loss: 3.2038 | Best Val: 3.2038 ✓ SAVED\n"
501
+ ]
502
+ },
503
+ {
504
+ "name": "stderr",
505
+ "output_type": "stream",
506
+ "text": [
507
+ " "
508
+ ]
509
+ },
510
+ {
511
+ "name": "stdout",
512
+ "output_type": "stream",
513
+ "text": [
514
+ "Epoch: 10/15 | Time: 7m 38s | TF Ratio: 0.19\n",
515
+ "\tTrain Loss: 2.4794 | Val Loss: 3.0976 | Best Val: 3.0976 ✓ SAVED\n"
516
+ ]
517
+ },
518
+ {
519
+ "name": "stderr",
520
+ "output_type": "stream",
521
+ "text": [
522
+ " "
523
+ ]
524
+ },
525
+ {
526
+ "name": "stdout",
527
+ "output_type": "stream",
528
+ "text": [
529
+ "Epoch: 11/15 | Time: 7m 37s | TF Ratio: 0.17\n",
530
+ "\tTrain Loss: 2.4153 | Val Loss: 3.0713 | Best Val: 3.0713 ✓ SAVED\n"
531
+ ]
532
+ },
533
+ {
534
+ "name": "stderr",
535
+ "output_type": "stream",
536
+ "text": [
537
+ " "
538
+ ]
539
+ },
540
+ {
541
+ "name": "stdout",
542
+ "output_type": "stream",
543
+ "text": [
544
+ "Epoch: 12/15 | Time: 7m 35s | TF Ratio: 0.16\n",
545
+ "\tTrain Loss: 2.3247 | Val Loss: 2.9971 | Best Val: 2.9971 ✓ SAVED\n"
546
+ ]
547
+ },
548
+ {
549
+ "name": "stderr",
550
+ "output_type": "stream",
551
+ "text": [
552
+ " "
553
+ ]
554
+ },
555
+ {
556
+ "name": "stdout",
557
+ "output_type": "stream",
558
+ "text": [
559
+ "Epoch: 13/15 | Time: 7m 38s | TF Ratio: 0.14\n",
560
+ "\tTrain Loss: 2.2682 | Val Loss: 2.9529 | Best Val: 2.9529 ✓ SAVED\n"
561
+ ]
562
+ },
563
+ {
564
+ "name": "stderr",
565
+ "output_type": "stream",
566
+ "text": [
567
+ " "
568
+ ]
569
+ },
570
+ {
571
+ "name": "stdout",
572
+ "output_type": "stream",
573
+ "text": [
574
+ "Epoch: 14/15 | Time: 7m 38s | TF Ratio: 0.13\n",
575
+ "\tTrain Loss: 2.2045 | Val Loss: 2.9489 | Best Val: 2.9489 ✓ SAVED\n"
576
+ ]
577
+ },
578
+ {
579
+ "name": "stderr",
580
+ "output_type": "stream",
581
+ "text": [
582
+ " "
583
+ ]
584
+ },
585
+ {
586
+ "name": "stdout",
587
+ "output_type": "stream",
588
+ "text": [
589
+ "Epoch: 15/15 | Time: 7m 39s | TF Ratio: 0.11\n",
590
+ "\tTrain Loss: 2.1487 | Val Loss: 2.9050 | Best Val: 2.9050 ✓ SAVED\n",
591
+ "\n",
592
+ "======================================================================\n",
593
+ "✓ TRAINING COMPLETE!\n",
594
+ "Best validation loss: 2.9050\n",
595
+ "Model saved to 'best_model.pth'\n",
596
+ "======================================================================\n"
597
+ ]
598
+ }
599
+ ],
600
+ "source": [
601
+ "EPOCHS = 15\n",
602
+ "CLIP = 1.0\n",
603
+ "best_valid_loss = float('inf')\n",
604
+ "training_history = {'train_loss': [], 'valid_loss': []}\n",
605
+ "\n",
606
+ "for epoch in range(EPOCHS):\n",
607
+ " start_time = time.time()\n",
608
+ " \n",
609
+ " tf_ratio = max(0.1, 0.5 * (0.9 ** epoch))\n",
610
+ " train_loss = train(model, train_loader, optimizer, criterion, CLIP, tf_ratio)\n",
611
+ " valid_loss = evaluate(model, val_loader, criterion)\n",
612
+ " scheduler.step(valid_loss)\n",
613
+ " if valid_loss < best_valid_loss:\n",
614
+ " best_valid_loss = valid_loss\n",
615
+ " torch.save({\n",
616
+ " 'epoch': epoch,\n",
617
+ " 'model_state_dict': model.state_dict(),\n",
618
+ " 'optimizer_state_dict': optimizer.state_dict(),\n",
619
+ " 'valid_loss': valid_loss,\n",
620
+ " 'vocab_size': VOCAB_SIZE,\n",
621
+ " 'embed_size': EMBED_SIZE,\n",
622
+ " 'hidden_size': HIDDEN_SIZE,\n",
623
+ " 'num_layers': NUM_LAYERS\n",
624
+ " }, 'best_model.pth')\n",
625
+ " save_status = \"✓ SAVED\"\n",
626
+ " else:\n",
627
+ " save_status = \" \"\n",
628
+ " \n",
629
+ " training_history['train_loss'].append(train_loss)\n",
630
+ " training_history['valid_loss'].append(valid_loss)\n",
631
+ " \n",
632
+ " end_time = time.time()\n",
633
+ " epoch_mins = int((end_time - start_time) / 60)\n",
634
+ " epoch_secs = int((end_time - start_time) % 60)\n",
635
+ " \n",
636
+ " print(f'Epoch: {epoch+1:02}/{EPOCHS} | Time: {epoch_mins}m {epoch_secs}s | TF Ratio: {tf_ratio:.2f}')\n",
637
+ " print(f'\\tTrain Loss: {train_loss:.4f} | Val Loss: {valid_loss:.4f} | Best Val: {best_valid_loss:.4f} {save_status}')\n",
638
+ "\n",
639
+ "print(\"\\n\" + \"=\"*70)\n",
640
+ "print(f\"✓ TRAINING COMPLETE!\")\n",
641
+ "print(f\"Best validation loss: {best_valid_loss:.4f}\")\n",
642
+ "print(f\"Model saved to 'best_model.pth'\")\n",
643
+ "print(\"=\"*70)"
644
+ ]
645
+ },
646
+ {
647
+ "cell_type": "code",
648
+ "execution_count": 15,
649
+ "id": "a49bb85f",
650
+ "metadata": {},
651
+ "outputs": [
652
+ {
653
+ "name": "stdout",
654
+ "output_type": "stream",
655
+ "text": [
656
+ "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n",
657
+ "IMPORTANT: If you haven't re-run the TRAINING loop (Cell 9)\n",
658
+ "after applying the Transpose fix, the results below will likely\n",
659
+ "be poor/incomplete because the model hasn't updated its weights\n",
660
+ "correctly yet.\n",
661
+ "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n",
662
+ "\n",
663
+ "Running generation tests on validation set (True Greedy Decoding):\n",
664
+ "\n",
665
+ "Example 0:\n",
666
+ "Input: [CLS:parallel_for] for (i = 0; i < 16; ++i)\n",
667
+ " ;\n",
668
+ "\n",
669
+ "Target: omp target parallel for simd simdlen(4 4)\n",
670
+ "Prediction: omp parallel for shared(,k,,,,,,,,,,,,,pr) shared(L,,,,,,,,,,,,,,,,\n",
671
+ "------------------------------------------------------------\n",
672
+ "Example 10:\n",
673
+ "Input: [CLS:reduction] for (i = 1; i < (500 - 1); i++)\n",
674
+ "{\n",
675
+ " iIndex = i * dim2;\n",
676
+ " jIndex = 0;\n",
677
+ " for (j = 1; j < (500 - 1); j++)\n",
678
+ " {\n",
679
+ " jIndex += 500;\n",
680
+ " for (k = 1; k < (500 - 1); k++)\n",
681
+ " {\n",
682
+ " index = (iIndex + jIndex) + k;\n",
683
+ " compute_it = old[index] * need;\n",
684
+ " aggregate += compute_it / gimmie;\n",
685
+ " accumulator = 0;\n",
686
+ " long subsum1 = 0;\n",
687
+ " long subsum2 = 0;\n",
688
+ " long subsum3 = 0;\n",
689
+ " for (z = 0; z < 27; z += 3)\n",
690
+ " {\n",
691
+ " subsum1 += old[index + arr[z]];\n",
692
+ " subsum2 += old[index + arr[z + 1]];\n",
693
+ " subsum3 += old[index + arr[z + 2]];\n",
694
+ " }\n",
695
+ "\n",
696
+ " accumulator += (subsum1 + subsum2) + subsum3;\n",
697
+ " long value = accumulator / 27;\n",
698
+ " int par = value / 100;\n",
699
+ " a0 += ((unsigned) par) >> 31;\n",
700
+ " a0 += !(par ^ 0);\n",
701
+ " a1 += !(par ^ 1);\n",
702
+ " a2 += !(par ^ 2);\n",
703
+ " a3 += !(par ^ 3);\n",
704
+ " a4 += !(par ^ 4);\n",
705
+ " a5 += !(par ^ 5);\n",
706
+ " a6 += !(par ^ 6);\n",
707
+ " a7 += !(par ^ 7);\n",
708
+ " a8 += !(par ^ 8);\n",
709
+ " int64_t tmp = ((int64_t) par) - 9;\n",
710
+ " a9 += (tmp >> 63) + 1;\n",
711
+ " new[index] = value;\n",
712
+ " }\n",
713
+ "\n",
714
+ " }\n",
715
+ "\n",
716
+ "}\n",
717
+ "\n",
718
+ "Target: omp parallel for private(j, k, z, accumulator, jIndex, index, iIndex, compute_it) reduction(+: aggregate, a0,a1,a2,a3,a4,a5,a6,a7,a8,a9)\n",
719
+ "Prediction: omp parallel for reduction(+:data,,,,,,,,,,,,,,\n",
720
+ "------------------------------------------------------------\n",
721
+ "Example 20:\n",
722
+ "Input: [CLS:parallel_for] for (i = 0; i < 16; ++i)\n",
723
+ " ;\n",
724
+ "\n",
725
+ "Target: omp parallel for simd firstprivate(, )\n",
726
+ "Prediction: omp parallel for shared(,k,,,,,,,,,,,,,pr) shared(L,,,,,,,,,,,,,,,,\n",
727
+ "------------------------------------------------------------\n",
728
+ "Example 30:\n",
729
+ "Input: [CLS:parallel_for] for (i = 0; i < n; i++)\n",
730
+ "{\n",
731
+ " x[i] = 1.0;\n",
732
+ " y[i] = 2.0;\n",
733
+ "}\n",
734
+ "\n",
735
+ "Target: omp parallel for private(i)\n",
736
+ "Prediction: omp parallel for shared(gen,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,\n",
737
+ "------------------------------------------------------------\n"
738
+ ]
739
+ }
740
+ ],
741
+ "source": [
742
+ "model.eval()\n",
743
+ "\n",
744
+ "def generate_sentence(model, input_text, tokenizer, max_len=150, device='cuda'):\n",
745
+ " \"\"\"\n",
746
+ " Greedy decoding function that generates tokens until <EOS> or max_len.\n",
747
+ " This mimics the model's forward pass but allows dynamic length generation.\n",
748
+ " \"\"\"\n",
749
+ " model.eval()\n",
750
+ " \n",
751
+ " # Tokenize input\n",
752
+ " input_ids = tokenizer.encode(input_text, max_length=500, add_special_tokens=True)\n",
753
+ " src_tensor = torch.LongTensor(input_ids).unsqueeze(0).to(device) # [1, src_len]\n",
754
+ " src_len = torch.LongTensor([len(input_ids)]).to(device) # [1]\n",
755
+ " \n",
756
+ " with torch.no_grad():\n",
757
+ " # Encode\n",
758
+ " encoder_outputs, hidden, cell = model.encoder(src_tensor, src_len)\n",
759
+ " \n",
760
+ " # Create mask (same logic as in Generator.forward)\n",
761
+ " max_src_len = encoder_outputs.shape[1]\n",
762
+ " mask = torch.arange(max_src_len, device=device).unsqueeze(0) < src_len.unsqueeze(1)\n",
763
+ " mask = mask.float()\n",
764
+ " \n",
765
+ " # Project hidden/cell states from Encoder to Decoder size\n",
766
+ " # Reshape to [num_layers, 2, batch, hidden] to combine bidirectional states\n",
767
+ " hidden = hidden.view(model.encoder.num_layers, 2, 1, model.encoder.hidden_size)\n",
768
+ " hidden = torch.cat((hidden[:, 0], hidden[:, 1]), dim=2)\n",
769
+ " hidden = model.hidden_projection(hidden)\n",
770
+ " \n",
771
+ " cell = cell.view(model.encoder.num_layers, 2, 1, model.encoder.hidden_size)\n",
772
+ " cell = torch.cat((cell[:, 0], cell[:, 1]), dim=2)\n",
773
+ " cell = model.cell_projection(cell)\n",
774
+ " \n",
775
+ " # Start with <SOS>\n",
776
+ " trg_indexes = [tokenizer.char2idx['<SOS>']]\n",
777
+ " \n",
778
+ " for i in range(max_len):\n",
779
+ " trg_tensor = torch.LongTensor([trg_indexes[-1]]).to(device) # [1]\n",
780
+ " \n",
781
+ " output, hidden, cell, _ = model.decoder(\n",
782
+ " trg_tensor, hidden, cell, encoder_outputs, mask\n",
783
+ " )\n",
784
+ " \n",
785
+ " # Greedy prediction: take token with highest probability\n",
786
+ " pred_token = output.argmax(1).item()\n",
787
+ " trg_indexes.append(pred_token)\n",
788
+ " \n",
789
+ " if pred_token == tokenizer.char2idx['<EOS>']:\n",
790
+ " break\n",
791
+ " \n",
792
+ " # Decode integers back to string\n",
793
+ " return tokenizer.decode(trg_indexes)\n",
794
+ "\n",
795
+ "# ---------------------------------------------------------\n",
796
+ "print(\"!\"*60)\n",
797
+ "print(\"IMPORTANT: If you haven't re-run the TRAINING loop (Cell 9)\")\n",
798
+ "print(\"after applying the Transpose fix, the results below will likely\")\n",
799
+ "print(\"be poor/incomplete because the model hasn't updated its weights\")\n",
800
+ "print(\"correctly yet.\")\n",
801
+ "print(\"!\"*60 + \"\\n\")\n",
802
+ "\n",
803
+ "print(\"Running generation tests on validation set (True Greedy Decoding):\\n\")\n",
804
+ "test_indices = [0, 10, 20, 30]\n",
805
+ "# Ensure indices are within bounds\n",
806
+ "test_indices = [i for i in test_indices if i < len(val_inputs)]\n",
807
+ "\n",
808
+ "for i in test_indices:\n",
809
+ " input_text = val_inputs[i]\n",
810
+ " target_text = val_outputs[i]\n",
811
+ " \n",
812
+ " prediction = generate_sentence(model, input_text, tokenizer, device=device)\n",
813
+ " \n",
814
+ " print(f\"Example {i}:\")\n",
815
+ " print(f\"Input: {input_text}\")\n",
816
+ " print(f\"Target: {target_text}\")\n",
817
+ " print(f\"Prediction: {prediction}\")\n",
818
+ " print(\"-\" * 60)"
819
+ ]
820
+ },
821
+ {
822
+ "cell_type": "code",
823
+ "execution_count": null,
824
+ "id": "85bd9571",
825
+ "metadata": {},
826
+ "outputs": [],
827
+ "source": [
828
+ "# ---------------------------------------------------------\n",
829
+ "# RUN THIS CELL ONLY IF YOU WANT TO RESET TRAINING\n",
830
+ "# This initializes the model weights from scratch. \n",
831
+ "# Run this, and then run the TRAINING LOOP (Cell 9) again.\n",
832
+ "# ---------------------------------------------------------\n",
833
+ "\n",
834
+ "print(\"↺ RESETTING MODEL & OPTIMIZER...\")\n",
835
+ "model = Generator(encoder, decoder, device).to(device)\n",
836
+ "model.apply(model._init_weights)\n",
837
+ "\n",
838
+ "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
839
+ "training_history = {'train_loss': [], 'valid_loss': []}\n",
840
+ "best_valid_loss = float('inf')\n",
841
+ "\n",
842
+ "print(\"✓ Model reset. Now scroll up and run the TRAINING LOOP again.\")"
843
+ ]
844
+ }
845
+ ],
846
+ "metadata": {
847
+ "kernelspec": {
848
+ "display_name": "env",
849
+ "language": "python",
850
+ "name": "python3"
851
+ },
852
+ "language_info": {
853
+ "codemirror_mode": {
854
+ "name": "ipython",
855
+ "version": 3
856
+ },
857
+ "file_extension": ".py",
858
+ "mimetype": "text/x-python",
859
+ "name": "python",
860
+ "nbconvert_exporter": "python",
861
+ "pygments_lexer": "ipython3",
862
+ "version": "3.12.3"
863
+ }
864
+ },
865
+ "nbformat": 4,
866
+ "nbformat_minor": 5
867
+ }
model/__pycache__/attn.cpython-312.pyc ADDED
Binary file (2.35 kB). View file
 
model/__pycache__/decoder.cpython-312.pyc ADDED
Binary file (3.12 kB). View file
 
model/__pycache__/encoder.cpython-312.pyc ADDED
Binary file (2.48 kB). View file
 
model/__pycache__/generator.cpython-312.pyc ADDED
Binary file (5.35 kB). View file
 
model/attn.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional

class BahdanauAttention(nn.Module):
    """Additive (Bahdanau-style) attention over encoder states.

    Scores each source position as V · tanh(W1·enc_t + W2·dec), then
    softmaxes over the time axis; an optional 0/1 mask forces padded
    positions toward zero probability.
    """

    def __init__(self, hidden_size: int):
        super(BahdanauAttention, self).__init__()
        # Encoder outputs are bidirectional, hence the doubled input width.
        self.W1 = nn.Linear(hidden_size * 2, hidden_size)
        self.W2 = nn.Linear(hidden_size, hidden_size)
        self.V = nn.Linear(hidden_size, 1)

    def forward(
        self,
        decoder_hidden: torch.Tensor,
        encoder_outputs: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (context, attention_weights) for one decoder step.

        decoder_hidden:  [batch, hidden] query from the decoder's top layer.
        encoder_outputs: [batch, src_len, hidden*2] keys/values.
        mask:            optional [batch, src_len] 0/1 validity mask.
        """
        query = decoder_hidden.unsqueeze(1)                      # [B, 1, H]
        energy = torch.tanh(self.W1(encoder_outputs) + self.W2(query))
        logits = self.V(energy)                                  # [B, T, 1]

        if mask is not None:
            # Large negative logit => ~zero weight after softmax.
            logits = logits.masked_fill(mask.unsqueeze(-1) == 0, -1e9)

        weights = F.softmax(logits, dim=1).squeeze(2)            # [B, T]
        context = torch.bmm(
            weights.unsqueeze(1),                                # [B, 1, T]
            encoder_outputs                                      # [B, T, 2H]
        ).squeeze(1)                                             # [B, 2H]

        return context, weights
model/decoder.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Tuple, Optional
4
+
5
class Decoder(nn.Module):
    """Single-step attentional LSTM decoder.

    Each call consumes one target token, attends over the encoder
    outputs, and produces vocabulary logits plus the updated LSTM state.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_size: int,
        hidden_size: int,
        attention: nn.Module,
        num_layers: int = 2,
        dropout: float = 0.3
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.attention = attention

        # padding_idx=0 pins the <PAD> embedding at zero.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_size,
            padding_idx=0
        )

        # LSTM input = token embedding concatenated with the
        # (2 * hidden) context from the bidirectional encoder.
        self.lstm = nn.LSTM(
            input_size=embed_size + hidden_size * 2,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0
        )

        # Prediction head sees LSTM output, context and embedding.
        self.fc_out = nn.Linear(
            hidden_size + hidden_size * 2 + embed_size,
            vocab_size
        )

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        input_token: torch.Tensor,
        decoder_hidden: torch.Tensor,
        decoder_cell: torch.Tensor,
        encoder_outputs: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run one decoding step.

        Returns ``(logits, hidden, cell, attention_weights)`` where
        ``logits`` are unnormalized scores over the vocabulary.
        """
        # (batch,) -> (batch, 1, embed_size)
        token_emb = self.dropout(self.embedding(input_token.unsqueeze(1)))

        # The top LSTM layer's hidden state is the attention query.
        context, attn_weights = self.attention(
            decoder_hidden[-1], encoder_outputs, mask
        )

        step_input = torch.cat((token_emb, context.unsqueeze(1)), dim=2)
        step_output, (decoder_hidden, decoder_cell) = self.lstm(
            step_input, (decoder_hidden, decoder_cell)
        )

        # Drop the singleton time dimension before the output head.
        features = torch.cat(
            (step_output.squeeze(1), context, token_emb.squeeze(1)),
            dim=1
        )
        logits = self.fc_out(features)

        return logits, decoder_hidden, decoder_cell, attn_weights
model/encoder.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Tuple
4
+
5
class Encoder(nn.Module):
    """Bidirectional LSTM encoder over padded token sequences.

    Packs the batch by true lengths so the LSTM skips padding, then
    re-pads the outputs for downstream attention.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_size: int,
        hidden_size: int,
        num_layers: int = 2,
        dropout: float = 0.3
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # padding_idx=0 pins the <PAD> embedding at zero.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_size,
            padding_idx=0
        )

        self.lstm = nn.LSTM(
            input_size=embed_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True
        )

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        input_seq: torch.Tensor,
        input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode a padded batch.

        Returns ``(outputs, hidden, cell)`` where ``outputs`` has
        feature size ``2 * hidden_size`` (forward + backward) and the
        states stack both directions along dim 0.
        """
        embedded = self.dropout(self.embedding(input_seq))

        # pack_padded_sequence requires the lengths on CPU.
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded,
            input_lengths.cpu(),
            batch_first=True,
            enforce_sorted=False
        )
        packed_out, (hidden, cell) = self.lstm(packed)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            packed_out,
            batch_first=True
        )

        return outputs, hidden, cell
model/generator.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
class Generator(nn.Module):
    """Seq2seq wrapper: bidirectional encoder + attentional decoder.

    Bridges the encoder's bidirectional LSTM state into the decoder's
    unidirectional state via linear projections, then decodes the target
    sequence step by step with optional teacher forcing.
    """

    def __init__(self, encoder: nn.Module, decoder: nn.Module, device: torch.device):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.device = device

        assert encoder.hidden_size == decoder.hidden_size, \
            "Encoder and decoder hidden sizes must match!"

        # Map concatenated fwd/bwd encoder states (2H) down to H.
        self.hidden_projection = nn.Linear(
            encoder.hidden_size * 2, decoder.hidden_size
        )
        self.cell_projection = nn.Linear(
            encoder.hidden_size * 2, decoder.hidden_size
        )

    def _init_weights(self, module):
        """Weight-init hook intended for ``self.apply``.

        NOTE(review): never invoked inside this class — presumably the
        training script calls ``model.apply(model._init_weights)``;
        confirm, otherwise default PyTorch init is used.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight.data, mean=0, std=0.01)
            if module.bias is not None:
                nn.init.constant_(module.bias.data, 0)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight.data, mean=0, std=0.01)
        elif isinstance(module, nn.LSTM):
            for name, param in module.named_parameters():
                if 'weight' in name:
                    nn.init.orthogonal_(param.data)
                elif 'bias' in name:
                    nn.init.constant_(param.data, 0)

    def create_mask(self, input_seq: torch.Tensor) -> torch.Tensor:
        """Return 1.0 where the token is not <PAD> (index 0), else 0.0."""
        return (input_seq != 0).float()

    def forward(
        self,
        input_seq: torch.Tensor,
        input_lengths: torch.Tensor,
        target_seq: torch.Tensor,
        teacher_forcing_ratio: float = 0.5
    ) -> torch.Tensor:
        """Decode ``target_seq`` conditioned on ``input_seq``.

        Returns logits shaped (target_len, batch, vocab); slot 0 stays
        all-zero because position 0 (<SOS>) is never predicted.
        """
        batch_size = input_seq.shape[0]
        target_len = target_seq.shape[1]
        vocab_size = self.decoder.vocab_size

        outputs = torch.zeros(
            target_len, batch_size, vocab_size, device=self.device
        )

        encoder_outputs, enc_hidden, enc_cell = self.encoder(input_seq, input_lengths)

        # Length-based attention mask over the padded encoder width.
        time_steps = encoder_outputs.shape[1]
        mask = (
            torch.arange(time_steps, device=self.device).unsqueeze(0)
            < input_lengths.unsqueeze(1)
        ).float()

        def _bridge(state: torch.Tensor, projection: nn.Linear) -> torch.Tensor:
            # (layers * 2, B, H) -> (layers, B, 2H) -> (layers, B, H)
            state = state.view(
                self.encoder.num_layers, 2, batch_size, self.encoder.hidden_size
            )
            return projection(torch.cat((state[:, 0], state[:, 1]), dim=2))

        hidden = _bridge(enc_hidden, self.hidden_projection)
        cell = _bridge(enc_cell, self.cell_projection)

        # First decoder input is the <SOS> column of the target.
        input_token = target_seq[:, 0]

        for step in range(1, target_len):
            logits, hidden, cell, _ = self.decoder(
                input_token, hidden, cell, encoder_outputs, mask
            )
            outputs[step] = logits

            # Coin flip: feed ground truth or the model's own best guess.
            use_teacher = torch.rand(1).item() < teacher_forcing_ratio
            input_token = target_seq[:, step] if use_teacher else logits.argmax(1)

        return outputs
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from collections import Counter
3
+ from collections import defaultdict
4
+ from typing import Dict
5
+ from typing import List
6
+ from typing import Tuple
7
+
8
+
9
class Tokenizer:
    """Byte-pair-encoding tokenizer built up from raw characters.

    Starts from the character inventory of the training texts plus four
    special tokens (``<PAD>``, ``<UNK>``, ``<SOS>``, ``<EOS>``), then
    greedily learns merge rules until the vocabulary reaches
    ``vocab_size``.
    """

    def __init__(self, vocab_size: int = 1000):
        self.special_tokens = ['<PAD>', '<UNK>', '<SOS>', '<EOS>']
        self.char2idx: Dict[str, int] = {}
        self.idx2char: Dict[int, str] = {}
        self.vocab_size: int = 0  # current size; grows as tokens are added
        self.target_vocab_size: int = vocab_size
        # (left, right) merge pair -> rank (lower rank = learned earlier)
        self.bpe_ranks: Dict[Tuple[str, str], int] = {}

        for idx, token in enumerate(self.special_tokens):
            self.char2idx[token] = idx
            self.idx2char[idx] = token
        self.vocab_size = len(self.special_tokens)

    def _get_stats(self, words: Dict[Tuple[str, ...], int]) -> Counter:
        """Count frequencies of adjacent symbol pairs across the corpus."""
        pairs = Counter()
        for word, freq in words.items():
            for i in range(len(word) - 1):
                pairs[(word[i], word[i + 1])] += freq
        return pairs

    def _merge_vocab(
        self, pair: Tuple[str, str], words: Dict[Tuple[str, ...], int]
    ) -> Dict[Tuple[str, ...], int]:
        """Apply one merge rule: fuse every occurrence of ``pair``."""
        new_words = {}
        replacement = "".join(pair)

        for word in words:
            new_word = []
            i = 0
            while i < len(word):
                if (
                    i < len(word) - 1
                    and word[i] == pair[0]
                    and word[i + 1] == pair[1]
                ):
                    new_word.append(replacement)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_words[tuple(new_word)] = words[word]
        return new_words

    def build_vocab(self, texts: List[str]) -> None:
        """Learn the character vocabulary and BPE merges from ``texts``."""
        print(f"Building BPE vocabulary from {len(texts)} texts...")

        vocab = set()
        for text in texts:
            vocab.update(text)

        for char in sorted(vocab):
            if char not in self.char2idx:
                self.char2idx[char] = self.vocab_size
                self.idx2char[self.vocab_size] = char
                self.vocab_size += 1

        print(
            f"Initial character vocabulary: "
            f"{self.vocab_size - len(self.special_tokens)} characters"
        )

        # Each full text is treated as one "word" for merge counting.
        words = defaultdict(int)
        for text in texts:
            word = tuple(text)
            words[word] += 1

        num_merges = self.target_vocab_size - self.vocab_size
        print(f"Learning {num_merges} BPE merges...")

        for i in range(num_merges):
            pairs = self._get_stats(words)
            if not pairs:
                break

            best_pair = max(pairs, key=pairs.get)
            words = self._merge_vocab(best_pair, words)

            new_token = ''.join(best_pair)
            if new_token not in self.char2idx:
                self.char2idx[new_token] = self.vocab_size
                self.idx2char[self.vocab_size] = new_token
                self.vocab_size += 1

            self.bpe_ranks[best_pair] = i

            if (i + 1) % 100 == 0:
                print(
                    f"  Learned {i + 1} merges, "
                    f"vocab size: {self.vocab_size}"
                )

        print(f"BPE Vocabulary built! Total tokens: {self.vocab_size}")
        print(f"  - Special tokens: {len(self.special_tokens)}")
        print(f"  - Base characters: {len(vocab)}")
        print(f"  - BPE subwords: {len(self.bpe_ranks)}")
        print(f"  - Sample subwords: {list(self.bpe_ranks.keys())[:5]}")

    def _tokenize(self, text: str) -> List[str]:
        """Split ``text`` into subword tokens by replaying learned merges,
        always applying the lowest-ranked (earliest-learned) pair first."""
        if not text:
            return []

        word = tuple(text)

        while len(word) > 1:
            pairs = [(word[i], word[i + 1]) for i in range(len(word) - 1)]
            valid_pairs = [p for p in pairs if p in self.bpe_ranks]

            if not valid_pairs:
                break

            bigram = min(valid_pairs, key=lambda p: self.bpe_ranks[p])

            new_word = []
            i = 0
            while i < len(word):
                if (
                    i < len(word) - 1
                    and word[i] == bigram[0]
                    and word[i + 1] == bigram[1]
                ):
                    new_word.append("".join(bigram))
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)

        return list(word)

    def add_token(self, token: str) -> None:
        """Register ``token`` as a new vocabulary entry if unseen."""
        if token not in self.char2idx:
            idx = self.vocab_size
            self.char2idx[token] = idx
            self.idx2char[idx] = token
            self.vocab_size += 1

    def encode(
        self, text: str, max_length: int, add_special_tokens: bool = True
    ) -> List[int]:
        """Tokenize ``text`` and return exactly ``max_length`` indices,
        truncating (leaving room for <SOS>/<EOS> when requested) and
        padding with <PAD> as needed. Unknown tokens map to <UNK>."""
        tokens = self._tokenize(text)

        indices = []

        if add_special_tokens:
            indices.append(self.char2idx['<SOS>'])

        for token in tokens[:max_length - (2 if add_special_tokens else 0)]:
            indices.append(self.char2idx.get(token, self.char2idx['<UNK>']))

        if add_special_tokens:
            indices.append(self.char2idx['<EOS>'])

        while len(indices) < max_length:
            indices.append(self.char2idx['<PAD>'])

        return indices

    def decode(self, indices: List[int]) -> str:
        """Map indices back to text, stopping at <EOS> and dropping
        <PAD>/<SOS>/<UNK>."""
        chars = []
        for idx in indices:
            token = self.idx2char.get(idx, '<UNK>')
            if token == '<EOS>':
                break
            if token not in ['<PAD>', '<SOS>', '<UNK>']:
                chars.append(token)
        return ''.join(chars)

    def save(self, filepath: str) -> None:
        """Persist the tokenizer state as JSON.

        BUG FIX: merge pairs used to be stored only as "left_right"
        strings, which is ambiguous once a merged left token itself
        contains '_' (load split on the first '_' and corrupted the
        pair). Merges are now additionally stored as an unambiguous
        ``bpe_merges`` list of [left, right, rank] triples; the legacy
        ``bpe_ranks`` key is kept so older loaders still work.
        """
        state = {
            "char2idx": self.char2idx,
            "special_tokens": self.special_tokens,
            "vocab_size": self.vocab_size,
            "target_vocab_size": self.target_vocab_size,
            "bpe_ranks": {
                f"{k[0]}_{k[1]}": v for k, v in self.bpe_ranks.items()
            },
            "bpe_merges": [
                [left, right, rank]
                for (left, right), rank in self.bpe_ranks.items()
            ],
        }
        with open(filepath, "w") as f:
            json.dump(state, f, indent=2)
        print(f"BPE Tokenizer saved to {filepath}")

    def load(self, filepath: str) -> "Tokenizer":
        """Restore tokenizer state from JSON written by :meth:`save`.

        Prefers the unambiguous ``bpe_merges`` list; falls back to the
        legacy underscore-joined ``bpe_ranks`` keys for old files.
        Returns ``self`` for chaining.
        """
        with open(filepath, "r") as f:
            state = json.load(f)

        self.char2idx = state["char2idx"]
        self.special_tokens = state["special_tokens"]
        self.vocab_size = state["vocab_size"]
        self.target_vocab_size = state.get("target_vocab_size", 1000)
        self.idx2char = {v: k for k, v in self.char2idx.items()}

        self.bpe_ranks = {}
        if "bpe_merges" in state:
            for left, right, rank in state["bpe_merges"]:
                self.bpe_ranks[(left, right)] = rank
        elif "bpe_ranks" in state:
            # Legacy format: splitting on the first '_' is ambiguous when
            # the left token contains '_'; kept only for old files.
            for key, value in state["bpe_ranks"].items():
                parts = key.split("_", 1)
                if len(parts) == 2:
                    self.bpe_ranks[(parts[0], parts[1])] = value

        print(f"BPE Tokenizer loaded from {filepath}")
        print(f"  - Vocab size: {self.vocab_size}")
        print(f"  - BPE merges: {len(self.bpe_ranks)}")

        return self