{"cells":[{"cell_type":"code","execution_count":5,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"WeogYTNaDj3r","executionInfo":{"status":"ok","timestamp":1770451365426,"user_tz":-420,"elapsed":8232,"user":{"displayName":"Htut Ko Ko","userId":"13068088192988605156"}},"outputId":"31f7bff5-87fd-4226-f664-6caa3742a41c"},"outputs":[{"output_type":"stream","name":"stdout","text":["Running in Google Colab\n","Requirement already satisfied: datasets in /usr/local/lib/python3.12/dist-packages (4.0.0)\n","Requirement already satisfied: sentencepiece in /usr/local/lib/python3.12/dist-packages (0.2.1)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from datasets) (3.20.3)\n","Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.12/dist-packages (from datasets) (2.0.2)\n","Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.12/dist-packages (from datasets) (18.1.0)\n","Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.12/dist-packages (from datasets) (0.3.8)\n","Requirement already satisfied: pandas in /usr/local/lib/python3.12/dist-packages (from datasets) (2.2.2)\n","Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.12/dist-packages (from datasets) (2.32.4)\n","Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.12/dist-packages (from datasets) (4.67.2)\n","Requirement already satisfied: xxhash in /usr/local/lib/python3.12/dist-packages (from datasets) (3.6.0)\n","Requirement already satisfied: multiprocess<0.70.17 in /usr/local/lib/python3.12/dist-packages (from datasets) (0.70.16)\n","Requirement already satisfied: fsspec<=2025.3.0,>=2023.1.0 in /usr/local/lib/python3.12/dist-packages (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (2025.3.0)\n","Requirement already satisfied: huggingface-hub>=0.24.0 in /usr/local/lib/python3.12/dist-packages (from datasets) (1.3.7)\n","Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from datasets) (26.0)\n","Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from datasets) (6.0.3)\n","Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.12/dist-packages (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (3.13.3)\n","Requirement already satisfied: hf-xet<2.0.0,>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24.0->datasets) (1.2.0)\n","Requirement already satisfied: httpx<1,>=0.23.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24.0->datasets) (0.28.1)\n","Requirement already satisfied: shellingham in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24.0->datasets) (1.5.4)\n","Requirement already satisfied: typer-slim in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24.0->datasets) (0.21.1)\n","Requirement already satisfied: typing-extensions>=4.1.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24.0->datasets) (4.15.0)\n","Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (3.4.4)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (3.11)\n","Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (2.5.0)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (2026.1.4)\n","Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets) (2.9.0.post0)\n","Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets) (2025.2)\n","Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets) (2025.3)\n","Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (2.6.1)\n","Requirement already satisfied: aiosignal>=1.4.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (1.4.0)\n","Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (25.4.0)\n","Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (1.8.0)\n","Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (6.7.1)\n","Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (0.4.1)\n","Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets) (1.22.0)\n","Requirement already satisfied: anyio in /usr/local/lib/python3.12/dist-packages (from httpx<1,>=0.23.0->huggingface-hub>=0.24.0->datasets) (4.12.1)\n","Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.12/dist-packages (from httpx<1,>=0.23.0->huggingface-hub>=0.24.0->datasets) (1.0.9)\n","Requirement already satisfied: h11>=0.16 in /usr/local/lib/python3.12/dist-packages (from httpcore==1.*->httpx<1,>=0.23.0->huggingface-hub>=0.24.0->datasets) (0.16.0)\n","Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.12/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.17.0)\n","Requirement already satisfied: click>=8.0.0 in /usr/local/lib/python3.12/dist-packages (from typer-slim->huggingface-hub>=0.24.0->datasets) (8.3.1)\n","Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"]}],"source":["# Google Colab Setup\n","try:\n"," import google.colab\n"," IN_COLAB = True\n"," print(\"Running in Google Colab\")\n"," !pip install datasets sentencepiece\n"," from google.colab import drive\n"," drive.mount('/content/drive')\n"," # Optional: Change to your project directory if needed\n"," # import os\n"," # os.chdir('/content/drive/MyDrive/NLP/Project_A3/A3_Burmese_English_Puffer')\n","except ImportError:\n"," IN_COLAB = False\n"," print(\"Running Locally\")"],"id":"WeogYTNaDj3r"},{"cell_type":"markdown","metadata":{"id":"o60RyQ1GDj3t"},"source":["# German-English Machine Translation (A3 Project)\n","\n","**Student**: Htut Ko Ko \n","**Course**: Natural Language Understanding \n","**Task**: German (de) <-> English (en) Translation using Transformer\n","\n","## Project Overview\n","This notebook implements a Neural Machine Translation system using a **Transformer** architecture.\n","We use the **Opus-100** dataset for German-English parallel data.\n","We use **SentencePiece** for subword tokenization.\n"],"id":"o60RyQ1GDj3t"},{"cell_type":"code","execution_count":6,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"mWWrcN7TDj3v","executionInfo":{"status":"ok","timestamp":1770451365477,"user_tz":-420,"elapsed":34,"user":{"displayName":"Htut Ko Ko","userId":"13068088192988605156"}},"outputId":"505529e2-c5bc-4618-d60e-ebf3205e02a6"},"outputs":[{"output_type":"stream","name":"stdout","text":["Using device: cuda\n"]}],"source":["import os\n","import math\n","import time\n","import random\n","import numpy as np\n","import pandas as pd\n","import torch\n","import torch.nn as nn\n","import torch.optim as optim\n","from torch.utils.data import Dataset, DataLoader\n","from torch.nn.utils.rnn import pad_sequence\n","from datasets import load_dataset\n","import sentencepiece as spm\n","\n","# Check for GPU\n","device = torch.device('mps' if torch.backends.mps.is_available() else ('cuda' if torch.cuda.is_available() else 'cpu'))\n","print(f\"Using device: {device}\")\n","\n","SEED = 1234\n","random.seed(SEED)\n","np.random.seed(SEED)\n","torch.manual_seed(SEED)\n","torch.cuda.manual_seed(SEED)\n","torch.backends.cudnn.deterministic = True"],"id":"mWWrcN7TDj3v"},{"cell_type":"markdown","metadata":{"id":"u5XpwPylDj3w"},"source":["## 2. Data Loading (Opus-100)\n","Loading German-English pairs from Opus-100."],"id":"u5XpwPylDj3w"},{"cell_type":"code","execution_count":7,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"sBw80f5rDj3w","executionInfo":{"status":"ok","timestamp":1770451394573,"user_tz":-420,"elapsed":29093,"user":{"displayName":"Htut Ko Ko","userId":"13068088192988605156"}},"outputId":"6f313d68-5f9c-4c07-82d1-6fc2b4726e37"},"outputs":[{"output_type":"stream","name":"stdout","text":["Loading Opus-100 Dataset (German-English)...\n","Loaded 1004000 sentences from Opus-100 dataset.\n","Subsampled to 50,000 examples for efficiency.\n","Extracted 50000 German-English pairs.\n"]}],"source":["print(\"Loading Opus-100 Dataset (German-English)...\")\n","\n","data = []\n","try:\n"," # Opus-100 has 'de-en' or 'en-de'\n"," dataset = load_dataset(\"opus100\", \"de-en\", split=\"train+validation+test\")\n"," print(f\"Loaded {len(dataset)} sentences from Opus-100 dataset.\")\n","\n"," for item in dataset:\n"," if 'translation' in item:\n"," # 'de' is the language code for German\n"," if 'de' in item['translation'] and 'en' in item['translation']:\n"," data.append({\n"," 'de': item['translation']['de'],\n"," 'en': item['translation']['en']\n"," })\n","\n"," # Limit to manageable size for this project\n"," if len(data) > 50000:\n"," import random\n"," random.shuffle(data)\n"," data = data[:50000]\n"," print(\"Subsampled to 50,000 examples for efficiency.\")\n","\n"," print(f\"Extracted {len(data)} German-English pairs.\")\n","except Exception as e:\n"," print(f\"Error loading from HF: {e}\")"],"id":"sBw80f5rDj3w"},{"cell_type":"code","execution_count":8,"metadata":{"id":"ETHqEVLgDj3w","executionInfo":{"status":"ok","timestamp":1770451394738,"user_tz":-420,"elapsed":172,"user":{"displayName":"Htut Ko Ko","userId":"13068088192988605156"}},"colab":{"base_uri":"https://localhost:8080/"},"outputId":"502491ee-6707-4b4b-ab74-fefadc5823b6"},"outputs":[{"output_type":"stream","name":"stdout","text":[" de \\\n","0 Offenbar werde ich verdächtigt. \n","1 Tielt +17°C \n","2 Wie geht's dir? \n","3 Zu ihm verhalten sich die Farben (guasch, temp... \n","4 -Was? \n","\n"," en \n","0 Apparently, I'm a suspect. \n","1 Tucupido +28°C \n","2 How are you? \n","3 Paints concern them (gouache, distemper, poliv... \n","4 You can't mean it! \n"]}],"source":["df = pd.DataFrame(data)\n","print(df.head())\n","\n","df = df.dropna(subset=['de', 'en'])\n","df['de'] = df['de'].astype(str)\n","df['en'] = df['en'].astype(str)\n","df = df[df['de'].str.strip() != '']\n","df = df[df['en'].str.strip() != '']"],"id":"ETHqEVLgDj3w"},{"cell_type":"markdown","metadata":{"id":"2cztqOQUDj3x"},"source":["## 3. Tokenization"],"id":"2cztqOQUDj3x"},{"cell_type":"code","execution_count":9,"metadata":{"id":"4YQnkzD_Dj3x","executionInfo":{"status":"ok","timestamp":1770451406984,"user_tz":-420,"elapsed":12244,"user":{"displayName":"Htut Ko Ko","userId":"13068088192988605156"}},"colab":{"base_uri":"https://localhost:8080/"},"outputId":"e444ef57-9d4f-43e6-b6df-cbb535d8401e"},"outputs":[{"output_type":"stream","name":"stdout","text":["Training German Tokenizer...\n","Training English Tokenizer (for German pair)...\n"]}],"source":["# Save texts to files\n","with open('train_de.txt', 'w', encoding='utf-8') as f:\n"," for line in df['de']: f.write(line + '\\n')\n","\n","with open('train_en_de.txt', 'w', encoding='utf-8') as f:\n"," for line in df['en']: f.write(line + '\\n')\n","\n","# Train SentencePiece models\n","vocab_size = 8000\n","model_type = 'bpe'\n","\n","print(\"Training German Tokenizer...\")\n","spm.SentencePieceTrainer.train(\n"," input='train_de.txt',\n"," model_prefix='spm_de',\n"," vocab_size=vocab_size,\n"," model_type=model_type,\n"," pad_id=0, bos_id=1, eos_id=2, unk_id=3\n",")\n","\n","print(\"Training English Tokenizer (for German pair)...\")\n","spm.SentencePieceTrainer.train(\n"," input='train_en_de.txt',\n"," model_prefix='spm_en_de',\n"," vocab_size=vocab_size,\n"," model_type=model_type,\n"," pad_id=0, bos_id=1, eos_id=2, unk_id=3\n",")\n","\n","sp_src = spm.SentencePieceProcessor(model_file='spm_de.model')\n","sp_trg = spm.SentencePieceProcessor(model_file='spm_en_de.model')"],"id":"4YQnkzD_Dj3x"},{"cell_type":"markdown","metadata":{"id":"VxbMUHVeDj3x"},"source":["## 4. Dataset & Model"],"id":"VxbMUHVeDj3x"},{"cell_type":"code","execution_count":10,"metadata":{"id":"hVhDYMytDj3x","executionInfo":{"status":"ok","timestamp":1770451406999,"user_tz":-420,"elapsed":5,"user":{"displayName":"Htut Ko Ko","userId":"13068088192988605156"}}},"outputs":[],"source":["class TranslationDataset(Dataset):\n"," def __init__(self, df, sp_src, sp_trg):\n"," self.data = df\n"," self.sp_src = sp_src\n"," self.sp_trg = sp_trg\n","\n"," def __len__(self):\n"," return len(self.data)\n","\n"," def __getitem__(self, idx):\n"," src_text = self.data.iloc[idx]['de']\n"," trg_text = self.data.iloc[idx]['en']\n"," src_ids = [self.sp_src.bos_id()] + self.sp_src.encode(src_text, out_type=int) + [self.sp_src.eos_id()]\n"," trg_ids = [self.sp_trg.bos_id()] + self.sp_trg.encode(trg_text, out_type=int) + [self.sp_trg.eos_id()]\n"," return torch.tensor(src_ids), torch.tensor(trg_ids)\n","\n","def collate_fn(batch):\n"," src_batch, trg_batch = [], []\n"," for src, trg in batch:\n"," src_batch.append(src)\n"," trg_batch.append(trg)\n"," src_pad = pad_sequence(src_batch, batch_first=True, padding_value=0)\n"," trg_pad = pad_sequence(trg_batch, batch_first=True, padding_value=0)\n"," return src_pad, trg_pad\n","\n","train_dataset = TranslationDataset(df, sp_src, sp_trg)\n","train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)"],"id":"hVhDYMytDj3x"},{"cell_type":"code","execution_count":11,"metadata":{"id":"f5fLfvEWDj3y","executionInfo":{"status":"ok","timestamp":1770451407006,"user_tz":-420,"elapsed":4,"user":{"displayName":"Htut Ko Ko","userId":"13068088192988605156"}}},"outputs":[],"source":["class PositionalEncoding(nn.Module):\n"," def __init__(self, d_model, dropout=0.1, max_len=5000):\n"," super(PositionalEncoding, self).__init__()\n"," self.dropout = nn.Dropout(p=dropout)\n"," pe = torch.zeros(max_len, d_model)\n"," position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n"," div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))\n"," pe[:, 0::2] = torch.sin(position * div_term)\n"," pe[:, 1::2] = torch.cos(position * div_term)\n"," self.register_buffer('pe', pe)\n","\n"," def forward(self, x):\n"," x = x + self.pe[:x.size(1), :]\n"," return self.dropout(x)\n","\n","class TransformerModel(nn.Module):\n"," def __init__(self, src_vocab_size, trg_vocab_size,\n"," d_model=256, nhead=4, num_encoder_layers=2,\n"," num_decoder_layers=2, dim_feedforward=512, dropout=0.1, pad_idx=0):\n"," super(TransformerModel, self).__init__()\n"," self.d_model = d_model\n"," self.pad_idx = pad_idx\n"," self.src_embedding = nn.Embedding(src_vocab_size, d_model)\n"," self.trg_embedding = nn.Embedding(trg_vocab_size, d_model)\n"," self.pos_encoder = PositionalEncoding(d_model, dropout)\n"," self.transformer = nn.Transformer(d_model=d_model, nhead=nhead, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers, dim_feedforward=dim_feedforward, dropout=dropout, batch_first=True)\n"," self.fc_out = nn.Linear(d_model, trg_vocab_size)\n","\n"," def forward(self, src, trg):\n"," src_key_padding_mask = (src == self.pad_idx)\n"," trg_mask = self.transformer.generate_square_subsequent_mask(trg.size(1)).to(src.device)\n"," src_emb = self.pos_encoder(self.src_embedding(src) * math.sqrt(self.d_model))\n"," trg_emb = self.pos_encoder(self.trg_embedding(trg) * math.sqrt(self.d_model))\n"," output = self.transformer(src=src_emb, tgt=trg_emb, tgt_mask=trg_mask, src_key_padding_mask=src_key_padding_mask)\n"," return self.fc_out(output)"],"id":"f5fLfvEWDj3y"},{"cell_type":"code","execution_count":12,"metadata":{"id":"01why68ZDj3z","executionInfo":{"status":"ok","timestamp":1770451876745,"user_tz":-420,"elapsed":469736,"user":{"displayName":"Htut Ko Ko","userId":"13068088192988605156"}},"colab":{"base_uri":"https://localhost:8080/"},"outputId":"ecf12f45-ce36-4f55-9ef0-e7f42d6326f1"},"outputs":[{"output_type":"stream","name":"stdout","text":["Starting Training...\n","Step 0, Loss: 9.122\n","Step 100, Loss: 6.802\n","Step 200, Loss: 6.478\n","Step 300, Loss: 6.420\n","Step 400, Loss: 6.104\n","Step 500, Loss: 6.265\n","Step 600, Loss: 5.877\n","Step 700, Loss: 5.790\n","Epoch 1 Loss: 6.320\n","Step 0, Loss: 5.578\n","Step 100, Loss: 5.876\n","Step 200, Loss: 5.782\n","Step 300, Loss: 5.453\n","Step 400, Loss: 5.472\n","Step 500, Loss: 5.311\n","Step 600, Loss: 5.294\n","Step 700, Loss: 5.511\n","Epoch 2 Loss: 5.540\n","Step 0, Loss: 5.304\n","Step 100, Loss: 4.828\n","Step 200, Loss: 5.449\n","Step 300, Loss: 5.142\n","Step 400, Loss: 4.986\n","Step 500, Loss: 5.251\n","Step 600, Loss: 5.048\n","Step 700, Loss: 5.164\n","Epoch 3 Loss: 5.111\n","Step 0, Loss: 4.924\n","Step 100, Loss: 4.869\n","Step 200, Loss: 4.970\n","Step 300, Loss: 4.884\n","Step 400, Loss: 4.627\n","Step 500, Loss: 4.850\n","Step 600, Loss: 4.678\n","Step 700, Loss: 4.876\n","Epoch 4 Loss: 4.832\n","Step 0, Loss: 4.758\n","Step 100, Loss: 4.387\n","Step 200, Loss: 4.616\n","Step 300, Loss: 4.687\n","Step 400, Loss: 4.621\n","Step 500, Loss: 4.487\n","Step 600, Loss: 4.673\n","Step 700, Loss: 4.743\n","Epoch 5 Loss: 4.632\n","Step 0, Loss: 4.118\n","Step 100, Loss: 4.295\n","Step 200, Loss: 4.074\n","Step 300, Loss: 4.624\n","Step 400, Loss: 4.367\n","Step 500, Loss: 4.572\n","Step 600, Loss: 4.676\n","Step 700, Loss: 4.437\n","Epoch 6 Loss: 4.476\n","Step 0, Loss: 4.247\n","Step 100, Loss: 4.121\n","Step 200, Loss: 4.197\n","Step 300, Loss: 4.304\n","Step 400, Loss: 4.441\n","Step 500, Loss: 4.371\n","Step 600, Loss: 4.300\n","Step 700, Loss: 4.265\n","Epoch 7 Loss: 4.346\n","Step 0, Loss: 4.091\n","Step 100, Loss: 4.079\n","Step 200, Loss: 4.234\n","Step 300, Loss: 4.174\n","Step 400, Loss: 4.122\n","Step 500, Loss: 4.436\n","Step 600, Loss: 4.196\n","Step 700, Loss: 4.381\n","Epoch 8 Loss: 4.236\n","Step 0, Loss: 4.214\n","Step 100, Loss: 4.318\n","Step 200, Loss: 4.281\n","Step 300, Loss: 4.474\n","Step 400, Loss: 4.199\n","Step 500, Loss: 4.254\n","Step 600, Loss: 4.127\n","Step 700, Loss: 4.140\n","Epoch 9 Loss: 4.137\n","Step 0, Loss: 3.667\n","Step 100, Loss: 4.102\n","Step 200, Loss: 3.962\n","Step 300, Loss: 4.091\n","Step 400, Loss: 3.765\n","Step 500, Loss: 4.123\n","Step 600, Loss: 4.305\n","Step 700, Loss: 4.151\n","Epoch 10 Loss: 4.051\n"]}],"source":["model = TransformerModel(vocab_size, vocab_size).to(device)\n","optimizer = optim.Adam(model.parameters(), lr=0.0005)\n","criterion = nn.CrossEntropyLoss(ignore_index=0)\n","\n","print(\"Starting Training...\")\n","for epoch in range(10): # 10 Epochs for demo (Opus-100 is large)\n"," model.train()\n"," epoch_loss = 0\n"," for i, (src, trg) in enumerate(train_loader):\n"," src, trg = src.to(device), trg.to(device)\n"," optimizer.zero_grad()\n"," output = model(src, trg[:, :-1])\n"," output = output.contiguous().view(-1, output.shape[-1])\n"," trg = trg[:, 1:].contiguous().view(-1)\n"," loss = criterion(output, trg)\n"," loss.backward()\n"," optimizer.step()\n"," epoch_loss += loss.item()\n"," if i % 100 == 0: print(f\"Step {i}, Loss: {loss.item():.3f}\")\n"," print(f\"Epoch {epoch+1} Loss: {epoch_loss/len(train_loader):.3f}\")\n","\n"," # Save\n"," torch.save(model.state_dict(), 'transformer_model_de.pt')"],"id":"01why68ZDj3z"},{"cell_type":"code","execution_count":13,"metadata":{"id":"NRZITD1eDj3z","executionInfo":{"status":"ok","timestamp":1770451876783,"user_tz":-420,"elapsed":27,"user":{"displayName":"Htut Ko Ko","userId":"13068088192988605156"}},"colab":{"base_uri":"https://localhost:8080/"},"outputId":"ca82b3a4-6b08-4e16-a6c7-e755022ff738"},"outputs":[{"output_type":"stream","name":"stdout","text":["Models copied to app/models/\n"]}],"source":["# Copy to app\n","import shutil\n","os.makedirs('app/models', exist_ok=True)\n","shutil.copy('transformer_model_de.pt', 'app/models/transformer_model_de.pt')\n","shutil.copy('spm_de.model', 'app/models/spm_de.model')\n","shutil.copy('spm_en_de.model', 'app/models/spm_en_de.model')\n","print(\"Models copied to app/models/\")"],"id":"NRZITD1eDj3z"}],"metadata":{"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.8.10"},"colab":{"provenance":[],"gpuType":"L4"},"accelerator":"GPU"},"nbformat":4,"nbformat_minor":5}