shadowsilence committed on
Commit
2ba8ae0
·
verified ·
1 Parent(s): cad2609

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ demo.gif filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ .ipynb_checkpoints/
4
+ .DS_Store
5
+ *.env
6
+ venv/
7
+ env/
8
+ # models/*.pt
9
+ # If using large models, add them to LFS instead
A4_BERT.ipynb ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "TB4CfcNaQZFN"
7
+ },
8
+ "source": [
9
+ "# A4: BERT Pre-training from Scratch\n",
10
+ "## Student Information\n",
11
+ "**Name:** HTUT KO KO \n",
12
+ "**ID:** st126010 \n",
13
+ "\n",
14
+ "## Task 1: BERT implementation\n",
15
+ "In this notebook, I implement BERT from scratch and pre-train it on the WikiText-103 dataset.\n",
16
+ "**Optimization Note:** To achieve lower loss on this small-scale demonstration, I use a smaller subset of data, a smaller vocabulary, and run for an adequate number of epochs."
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 1,
22
+ "metadata": {},
23
+ "outputs": [
24
+ {
25
+ "name": "stderr",
26
+ "output_type": "stream",
27
+ "text": [
28
+ "/opt/homebrew/Caskroom/miniforge/base/envs/ai_env/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
29
+ " from .autonotebook import tqdm as notebook_tqdm\n"
30
+ ]
31
+ },
32
+ {
33
+ "name": "stdout",
34
+ "output_type": "stream",
35
+ "text": [
36
+ "Using device: mps\n"
37
+ ]
38
+ }
39
+ ],
40
+ "source": [
41
+ "import torch\n",
42
+ "import torch.nn as nn\n",
43
+ "import torch.optim as optim\n",
44
+ "import numpy as np\n",
45
+ "import random\n",
46
+ "from random import randrange, shuffle, randint\n",
47
+ "from datasets import load_dataset\n",
48
+ "from transformers import BertTokenizer\n",
49
+ "from torch.utils.data import DataLoader, Dataset\n",
50
+ "import re\n",
51
+ "from collections import Counter\n",
52
+ "\n",
53
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else (\"mps\" if torch.backends.mps.is_available() else \"cpu\"))\n",
54
+ "print(f\"Using device: {device}\")"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "markdown",
59
+ "metadata": {},
60
+ "source": [
61
+ "### 1. Data Loading & Preprocessing\n",
62
+ "I will use a smaller vocabulary size to make the model convergence easier for the assignment."
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": 2,
68
+ "metadata": {},
69
+ "outputs": [
70
+ {
71
+ "name": "stdout",
72
+ "output_type": "stream",
73
+ "text": [
74
+ "Vocab Size: 5004\n"
75
+ ]
76
+ }
77
+ ],
78
+ "source": [
79
+ "# 1. Load Data\n",
80
+ "dataset = load_dataset('wikitext', 'wikitext-103-raw-v1', split='train')\n",
81
+ "subset_size = 5000 # Reduce size for faster iteration and better convergence on small model\n",
82
+ "dataset = dataset.select(range(subset_size))\n",
83
+ "raw_text_data = [line for line in dataset['text'] if len(line) > 20]\n",
84
+ "\n",
85
+ "# 2. Build Custom Vocabulary (Crucial for small data)\n",
86
+ "# Instead of using 30k BERT vocab, we build one from our data\n",
87
+ "tokens = [word.lower() for sent in raw_text_data for word in sent.split()]\n",
88
+ "vocab_counter = Counter(tokens)\n",
89
+ "vocab = sorted(vocab_counter, key=vocab_counter.get, reverse=True)[:5000] # Top 5k words\n",
90
+ "word2id = {w: i+4 for i, w in enumerate(vocab)}\n",
91
+ "word2id['[PAD]'] = 0\n",
92
+ "word2id['[CLS]'] = 1\n",
93
+ "word2id['[SEP]'] = 2\n",
94
+ "word2id['[MASK]'] = 3\n",
95
+ "id2word = {i: w for w, i in word2id.items()}\n",
96
+ "vocab_size = len(word2id)\n",
97
+ "print(f\"Vocab Size: {vocab_size}\")\n",
98
+ "\n",
99
+ "token_list = []\n",
100
+ "for sentence in raw_text_data:\n",
101
+ " # Simple whitespace tokenization for this demo\n",
102
+ " seq = [word2id.get(w.lower(), 0) for w in sentence.split()] \n",
103
+ " if len(seq) > 0:\n",
104
+ " token_list.append(seq)"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "markdown",
109
+ "metadata": {},
110
+ "source": [
111
+ "### 2. BERT Hyperparameters & Data Loader"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": 3,
117
+ "metadata": {},
118
+ "outputs": [],
119
+ "source": [
120
+ "max_len = 128\n",
121
+ "batch_size = 16 # Small batch size\n",
122
+ "max_mask = 20\n",
123
+ "n_layers = 2 # Shallower model for easier training\n",
124
+ "n_heads = 4\n",
125
+ "d_model = 256\n",
126
+ "d_ff = 256 * 4\n",
127
+ "d_k = d_v = 64\n",
128
+ "n_segments = 2\n",
129
+ "\n",
130
+ "def make_batch():\n",
131
+ " batch = []\n",
132
+ " positive = negative = 0\n",
133
+ " while positive != batch_size / 2 or negative != batch_size / 2:\n",
134
+ " tokens_a_index, tokens_b_index = randrange(len(token_list)), randrange(len(token_list))\n",
135
+ " tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]\n",
136
+ "\n",
137
+ " input_ids = [word2id['[CLS]']] + tokens_a + [word2id['[SEP]']] + tokens_b + [word2id['[SEP]']]\n",
138
+ " segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)\n",
139
+ "\n",
140
+ " input_ids = input_ids[:max_len]\n",
141
+ " segment_ids = segment_ids[:max_len]\n",
142
+ "\n",
143
+ " n_pred = min(max_mask, max(1, int(round(len(input_ids) * 0.15))))\n",
144
+ " candidates_masked_pos = [i for i, token in enumerate(input_ids) if token != word2id['[CLS]'] and token != word2id['[SEP]']]\n",
145
+ " shuffle(candidates_masked_pos)\n",
146
+ " masked_tokens, masked_pos = [], []\n",
147
+ " for pos in candidates_masked_pos[:n_pred]:\n",
148
+ " masked_pos.append(pos)\n",
149
+ " masked_tokens.append(input_ids[pos])\n",
150
+ " if random.random() < 0.1:\n",
151
+ " input_ids[pos] = randint(0, vocab_size - 1)\n",
152
+ " elif random.random() < 0.8:\n",
153
+ " input_ids[pos] = word2id['[MASK]']\n",
154
+ "\n",
155
+ " n_pad = max_len - len(input_ids)\n",
156
+ " input_ids.extend([0] * n_pad)\n",
157
+ " segment_ids.extend([0] * n_pad)\n",
158
+ "\n",
159
+ " if max_mask > n_pred:\n",
160
+ " n_pad = max_mask - n_pred\n",
161
+ " masked_tokens.extend([0] * n_pad)\n",
162
+ " masked_pos.extend([0] * n_pad)\n",
163
+ "\n",
164
+ " if tokens_a_index + 1 == tokens_b_index and positive < batch_size / 2:\n",
165
+ " batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True])\n",
166
+ " positive += 1\n",
167
+ " elif tokens_a_index + 1 != tokens_b_index and negative < batch_size / 2:\n",
168
+ " batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False])\n",
169
+ " negative += 1\n",
170
+ " return batch"
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "markdown",
175
+ "metadata": {},
176
+ "source": [
177
+ "### 3. BERT Model Architecture"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "execution_count": 4,
183
+ "metadata": {},
184
+ "outputs": [],
185
+ "source": [
186
+ "class Embedding(nn.Module):\n",
187
+ " def __init__(self):\n",
188
+ " super(Embedding, self).__init__()\n",
189
+ " self.tok_embed = nn.Embedding(vocab_size, d_model)\n",
190
+ " self.pos_embed = nn.Embedding(max_len, d_model)\n",
191
+ " self.seg_embed = nn.Embedding(n_segments, d_model)\n",
192
+ " self.norm = nn.LayerNorm(d_model)\n",
193
+ "\n",
194
+ " def forward(self, x, seg):\n",
195
+ " seq_len = x.size(1)\n",
196
+ " pos = torch.arange(seq_len, dtype=torch.long, device=x.device)\n",
197
+ " pos = pos.unsqueeze(0).expand_as(x)\n",
198
+ " embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)\n",
199
+ " return self.norm(embedding)\n",
200
+ "\n",
201
+ "class MultiHeadAttention(nn.Module):\n",
202
+ " def __init__(self):\n",
203
+ " super(MultiHeadAttention, self).__init__()\n",
204
+ " self.W_Q = nn.Linear(d_model, d_k * n_heads)\n",
205
+ " self.W_K = nn.Linear(d_model, d_k * n_heads)\n",
206
+ " self.W_V = nn.Linear(d_model, d_v * n_heads)\n",
207
+ " self.linear = nn.Linear(n_heads * d_v, d_model)\n",
208
+ " self.layer_norm = nn.LayerNorm(d_model)\n",
209
+ "\n",
210
+ " def forward(self, Q, K, V, attn_mask):\n",
211
+ " batch_size = Q.size(0)\n",
212
+ " q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)\n",
213
+ " k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)\n",
214
+ " v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)\n",
215
+ " \n",
216
+ " attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)\n",
217
+ " \n",
218
+ " scores = torch.matmul(q_s, k_s.transpose(-1, -2)) / np.sqrt(d_k)\n",
219
+ " scores.masked_fill_(attn_mask, -1e9)\n",
220
+ " attn = nn.Softmax(dim=-1)(scores)\n",
221
+ " context = torch.matmul(attn, v_s)\n",
222
+ " context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)\n",
223
+ " output = self.linear(context)\n",
224
+ " return self.layer_norm(output + Q), attn\n",
225
+ "\n",
226
+ "class PoswiseFeedForwardNet(nn.Module):\n",
227
+ " def __init__(self):\n",
228
+ " super(PoswiseFeedForwardNet, self).__init__()\n",
229
+ " self.fc1 = nn.Linear(d_model, d_ff)\n",
230
+ " self.fc2 = nn.Linear(d_ff, d_model)\n",
231
+ "\n",
232
+ " def forward(self, x):\n",
233
+ " return self.fc2(torch.nn.functional.gelu(self.fc1(x)))\n",
234
+ "\n",
235
+ "class EncoderLayer(nn.Module):\n",
236
+ " def __init__(self):\n",
237
+ " super(EncoderLayer, self).__init__()\n",
238
+ " self.enc_self_attn = MultiHeadAttention()\n",
239
+ " self.pos_ffn = PoswiseFeedForwardNet()\n",
240
+ "\n",
241
+ " def forward(self, enc_inputs, enc_self_attn_mask):\n",
242
+ " enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)\n",
243
+ " enc_outputs = self.pos_ffn(enc_outputs)\n",
244
+ " return enc_outputs, attn\n",
245
+ "\n",
246
+ "def get_attn_pad_mask(seq_q, seq_k):\n",
247
+ " batch_size, len_q = seq_q.size()\n",
248
+ " batch_size, len_k = seq_k.size()\n",
249
+ " pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)\n",
250
+ " return pad_attn_mask.expand(batch_size, len_q, len_k)\n",
251
+ "\n",
252
+ "class BERT(nn.Module):\n",
253
+ " def __init__(self):\n",
254
+ " super(BERT, self).__init__()\n",
255
+ " self.embedding = Embedding()\n",
256
+ " self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])\n",
257
+ " self.fc = nn.Linear(d_model, d_model)\n",
258
+ " self.activ = nn.Tanh()\n",
259
+ " self.linear = nn.Linear(d_model, d_model)\n",
260
+ " self.norm = nn.LayerNorm(d_model)\n",
261
+ " self.classifier = nn.Linear(d_model, 2)\n",
262
+ " embed_weight = self.embedding.tok_embed.weight\n",
263
+ " n_vocab, n_dim = embed_weight.size()\n",
264
+ " self.decoder = nn.Linear(n_dim, n_vocab, bias=False)\n",
265
+ " self.decoder.weight = embed_weight\n",
266
+ " self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))\n",
267
+ "\n",
268
+ " def forward(self, input_ids, segment_ids, masked_pos):\n",
269
+ " output = self.embedding(input_ids, segment_ids)\n",
270
+ " enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids)\n",
271
+ " for layer in self.layers:\n",
272
+ " output, enc_self_attn = layer(output, enc_self_attn_mask)\n",
273
+ " \n",
274
+ " h_pooled = self.activ(self.fc(output[:, 0]))\n",
275
+ " logits_nsp = self.classifier(h_pooled)\n",
276
+ " \n",
277
+ " masked_pos = masked_pos[:, :, None].expand(-1, -1, d_model)\n",
278
+ " h_masked = torch.gather(output, 1, masked_pos)\n",
279
+ " h_masked = self.norm(self.activ(self.linear(h_masked)))\n",
280
+ " logits_lm = self.decoder(h_masked) + self.decoder_bias\n",
281
+ "\n",
282
+ " return logits_lm, logits_nsp, output"
283
+ ]
284
+ },
285
+ {
286
+ "cell_type": "markdown",
287
+ "metadata": {},
288
+ "source": [
289
+ "### 4. Training Loop\n",
290
+ "I train for 2000 epochs to ensure convergence."
291
+ ]
292
+ },
293
+ {
294
+ "cell_type": "code",
295
+ "execution_count": 6,
296
+ "metadata": {},
297
+ "outputs": [
298
+ {
299
+ "name": "stdout",
300
+ "output_type": "stream",
301
+ "text": [
302
+ "Starting Training...\n",
303
+ "Epoch: 0100 | loss = 13.480813\n",
304
+ "Epoch: 0200 | loss = 10.269491\n",
305
+ "Epoch: 0300 | loss = 8.912387\n",
306
+ "Epoch: 0400 | loss = 7.265475\n",
307
+ "Epoch: 0500 | loss = 6.765235\n",
308
+ "Epoch: 0600 | loss = 7.150949\n",
309
+ "Epoch: 0700 | loss = 6.182394\n",
310
+ "Epoch: 0800 | loss = 6.075039\n",
311
+ "Epoch: 0900 | loss = 6.766500\n",
312
+ "Epoch: 1000 | loss = 6.545547\n",
313
+ "Epoch: 1100 | loss = 6.488539\n",
314
+ "Epoch: 1200 | loss = 6.223000\n",
315
+ "Epoch: 1300 | loss = 5.912578\n",
316
+ "Epoch: 1400 | loss = 6.125433\n",
317
+ "Epoch: 1500 | loss = 6.077301\n",
318
+ "Epoch: 1600 | loss = 6.500366\n",
319
+ "Epoch: 1700 | loss = 6.560534\n",
320
+ "Epoch: 1800 | loss = 6.262241\n",
321
+ "Epoch: 1900 | loss = 5.871750\n",
322
+ "Epoch: 2000 | loss = 6.158124\n",
323
+ "Training Complete. Model Saved.\n"
324
+ ]
325
+ }
326
+ ],
327
+ "source": [
328
+ "model = BERT().to(device)\n",
329
+ "criterion = nn.CrossEntropyLoss(ignore_index=0)\n",
330
+ "criterion_nsp = nn.CrossEntropyLoss()\n",
331
+ "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
332
+ "\n",
333
+ "print(\"Starting Training...\")\n",
334
+ "for epoch in range(2000):\n",
335
+ " batch = make_batch()\n",
336
+ " input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(*batch))\n",
337
+ " input_ids, segment_ids, masked_tokens, masked_pos, isNext = input_ids.to(device), segment_ids.to(device), masked_tokens.to(device), masked_pos.to(device), isNext.to(device)\n",
338
+ "\n",
339
+ " optimizer.zero_grad()\n",
340
+ " logits_lm, logits_nsp, _ = model(input_ids, segment_ids, masked_pos)\n",
341
+ " loss_lm = criterion(logits_lm.transpose(1, 2), masked_tokens).mean()\n",
342
+ " loss_nsp = criterion_nsp(logits_nsp, isNext)\n",
343
+ " loss = loss_lm + loss_nsp\n",
344
+ " loss.backward()\n",
345
+ " optimizer.step()\n",
346
+ " \n",
347
+ " if (epoch + 1) % 100 == 0:\n",
348
+ " print(f'Epoch: {epoch + 1:04d} | loss = {loss.item():.6f}')\n",
349
+ "\n",
350
+ "torch.save(model.state_dict(), './models/bert_trained.pt')\n",
351
+ "print(\"Training Complete. Model Saved.\")"
352
+ ]
353
+ },
354
+ {
355
+ "cell_type": "code",
356
+ "execution_count": null,
357
+ "metadata": {},
358
+ "outputs": [],
359
+ "source": []
360
+ }
361
+ ],
362
+ "metadata": {
363
+ "kernelspec": {
364
+ "display_name": "ai_env",
365
+ "language": "python",
366
+ "name": "python3"
367
+ },
368
+ "language_info": {
369
+ "codemirror_mode": {
370
+ "name": "ipython",
371
+ "version": 3
372
+ },
373
+ "file_extension": ".py",
374
+ "mimetype": "text/x-python",
375
+ "name": "python",
376
+ "nbconvert_exporter": "python",
377
+ "pygments_lexer": "ipython3",
378
+ "version": "3.11.13"
379
+ }
380
+ },
381
+ "nbformat": 4,
382
+ "nbformat_minor": 4
383
+ }
A4_Climate_FEVER.ipynb ADDED
@@ -0,0 +1,522 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "TB4CfcNaQZFN"
7
+ },
8
+ "source": [
9
+ "# A4: S-BERT Training (Climate-FEVER)\n",
10
+ "## Student Information\n",
11
+ "**Name:** HTUT KO KO \n",
12
+ "**ID:** st126010 \n",
13
+ "\n",
14
+ "## Task 2: S-BERT Implementation\n",
15
+ "In this notebook, I load the pre-trained BERT model (from Task 1) and fine-tune it using a Siamese network structure for Natural Language Inference (NLI) on the Climate-FEVER dataset."
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "code",
20
+ "execution_count": 9,
21
+ "metadata": {},
22
+ "outputs": [
23
+ {
24
+ "name": "stdout",
25
+ "output_type": "stream",
26
+ "text": [
27
+ "Using device: mps\n"
28
+ ]
29
+ }
30
+ ],
31
+ "source": [
32
+ "import torch\n",
33
+ "import torch.nn as nn\n",
34
+ "import torch.optim as optim\n",
35
+ "import numpy as np\n",
36
+ "from datasets import load_dataset\n",
37
+ "from transformers import BertTokenizer\n",
38
+ "from torch.utils.data import DataLoader, Dataset\n",
39
+ "from sklearn.metrics import classification_report, accuracy_score\n",
40
+ "\n",
41
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else (\"mps\" if torch.backends.mps.is_available() else \"cpu\"))\n",
42
+ "print(f\"Using device: {device}\")"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": 10,
48
+ "metadata": {},
49
+ "outputs": [],
50
+ "source": [
51
+ "# Model Hyperparameters (Must match Pre-training)\n",
52
+ "max_len = 128\n",
53
+ "n_layers = 2\n",
54
+ "n_heads = 4\n",
55
+ "d_model = 256\n",
56
+ "d_ff = 256 * 4\n",
57
+ "d_k = d_v = 64\n",
58
+ "n_segments = 2\n",
59
+ "vocab_size = 5004"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "markdown",
64
+ "metadata": {},
65
+ "source": [
66
+ "## 1. BERT Model Definition\n",
67
+ "Required to load the saved state dictionary."
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": 11,
73
+ "metadata": {},
74
+ "outputs": [
75
+ {
76
+ "name": "stdout",
77
+ "output_type": "stream",
78
+ "text": [
79
+ "Loaded bert_trained.pt successfully.\n"
80
+ ]
81
+ }
82
+ ],
83
+ "source": [
84
+ "class Embedding(nn.Module):\n",
85
+ " def __init__(self):\n",
86
+ " super(Embedding, self).__init__()\n",
87
+ " self.tok_embed = nn.Embedding(vocab_size, d_model)\n",
88
+ " self.pos_embed = nn.Embedding(max_len, d_model)\n",
89
+ " self.seg_embed = nn.Embedding(n_segments, d_model)\n",
90
+ " self.norm = nn.LayerNorm(d_model)\n",
91
+ "\n",
92
+ " def forward(self, x, seg):\n",
93
+ " seq_len = x.size(1)\n",
94
+ " pos = torch.arange(seq_len, dtype=torch.long, device=x.device)\n",
95
+ " pos = pos.unsqueeze(0).expand_as(x)\n",
96
+ " embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)\n",
97
+ " return self.norm(embedding)\n",
98
+ "\n",
99
+ "class MultiHeadAttention(nn.Module):\n",
100
+ " def __init__(self):\n",
101
+ " super(MultiHeadAttention, self).__init__()\n",
102
+ " self.W_Q = nn.Linear(d_model, d_k * n_heads)\n",
103
+ " self.W_K = nn.Linear(d_model, d_k * n_heads)\n",
104
+ " self.W_V = nn.Linear(d_model, d_v * n_heads)\n",
105
+ " self.linear = nn.Linear(n_heads * d_v, d_model)\n",
106
+ " self.layer_norm = nn.LayerNorm(d_model)\n",
107
+ "\n",
108
+ " def forward(self, Q, K, V, attn_mask):\n",
109
+ " batch_size = Q.size(0)\n",
110
+ " q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)\n",
111
+ " k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)\n",
112
+ " v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)\n",
113
+ " \n",
114
+ " attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)\n",
115
+ " \n",
116
+ " scores = torch.matmul(q_s, k_s.transpose(-1, -2)) / np.sqrt(d_k)\n",
117
+ " scores.masked_fill_(attn_mask, -1e9)\n",
118
+ " attn = nn.Softmax(dim=-1)(scores)\n",
119
+ " context = torch.matmul(attn, v_s)\n",
120
+ " context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)\n",
121
+ " output = self.linear(context)\n",
122
+ " return self.layer_norm(output + Q), attn\n",
123
+ "\n",
124
+ "class PoswiseFeedForwardNet(nn.Module):\n",
125
+ " def __init__(self):\n",
126
+ " super(PoswiseFeedForwardNet, self).__init__()\n",
127
+ " self.fc1 = nn.Linear(d_model, d_ff)\n",
128
+ " self.fc2 = nn.Linear(d_ff, d_model)\n",
129
+ "\n",
130
+ " def forward(self, x):\n",
131
+ " return self.fc2(torch.nn.functional.gelu(self.fc1(x)))\n",
132
+ "\n",
133
+ "class EncoderLayer(nn.Module):\n",
134
+ " def __init__(self):\n",
135
+ " super(EncoderLayer, self).__init__()\n",
136
+ " self.enc_self_attn = MultiHeadAttention()\n",
137
+ " self.pos_ffn = PoswiseFeedForwardNet()\n",
138
+ "\n",
139
+ " def forward(self, enc_inputs, enc_self_attn_mask):\n",
140
+ " enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)\n",
141
+ " enc_outputs = self.pos_ffn(enc_outputs)\n",
142
+ " return enc_outputs, attn\n",
143
+ "\n",
144
+ "def get_attn_pad_mask(seq_q, seq_k):\n",
145
+ " batch_size, len_q = seq_q.size()\n",
146
+ " batch_size, len_k = seq_k.size()\n",
147
+ " pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)\n",
148
+ " return pad_attn_mask.expand(batch_size, len_q, len_k)\n",
149
+ "\n",
150
+ "class BERT(nn.Module):\n",
151
+ " def __init__(self):\n",
152
+ " super(BERT, self).__init__()\n",
153
+ " self.embedding = Embedding()\n",
154
+ " self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])\n",
155
+ " self.fc = nn.Linear(d_model, d_model)\n",
156
+ " self.activ = nn.Tanh()\n",
157
+ " self.linear = nn.Linear(d_model, d_model)\n",
158
+ " self.norm = nn.LayerNorm(d_model)\n",
159
+ " self.classifier = nn.Linear(d_model, 2)\n",
160
+ " embed_weight = self.embedding.tok_embed.weight\n",
161
+ " n_vocab, n_dim = embed_weight.size()\n",
162
+ " self.decoder = nn.Linear(n_dim, n_vocab, bias=False)\n",
163
+ " self.decoder.weight = embed_weight\n",
164
+ " self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))\n",
165
+ "\n",
166
+ " def forward(self, input_ids, segment_ids, masked_pos):\n",
167
+ " output = self.embedding(input_ids, segment_ids)\n",
168
+ " enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids)\n",
169
+ " for layer in self.layers:\n",
170
+ " output, enc_self_attn = layer(output, enc_self_attn_mask)\n",
171
+ " \n",
172
+ " h_pooled = self.activ(self.fc(output[:, 0]))\n",
173
+ " logits_nsp = self.classifier(h_pooled)\n",
174
+ " \n",
175
+ " # For S-BERT, I return the output sequences directly\n",
176
+ " return logits_nsp, logits_nsp, output\n",
177
+ "\n",
178
+ "# Load Pre-trained Parameters\n",
179
+ "bert = BERT().to(device)\n",
180
+ "try:\n",
181
+ " bert.load_state_dict(torch.load('./models/bert_trained.pt', map_location=device))\n",
182
+ " print(\"Loaded bert_trained.pt successfully.\")\n",
183
+ "except:\n",
184
+ " print(\"Pre-trained weights not found. Please run A4_BERT.ipynb first.\")"
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "markdown",
189
+ "metadata": {},
190
+ "source": [
191
+ "## 2. S-BERT for Climate-FEVER\n",
192
+ "Fine-tuning on the Climate-FEVER dataset."
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "code",
197
+ "execution_count": 12,
198
+ "metadata": {},
199
+ "outputs": [],
200
+ "source": [
201
+ "class SBERT(nn.Module):\n",
202
+ " def __init__(self, bert_model):\n",
203
+ " super(SBERT, self).__init__()\n",
204
+ " self.bert = bert_model\n",
205
+ " self.classifier = nn.Linear(d_model * 3, 3)\n",
206
+ "\n",
207
+ " def forward(self, premise_ids, premise_seg, hypothesis_ids, hypothesis_seg):\n",
208
+ " device = premise_ids.device\n",
209
+ " dummy_masked_pos = torch.zeros((premise_ids.size(0), 1), dtype=torch.long).to(device)\n",
210
+ " \n",
211
+ " _, _, output_u = self.bert(premise_ids, premise_seg, dummy_masked_pos)\n",
212
+ " mask_u = (premise_ids != 0).unsqueeze(-1).float()\n",
213
+ " u = torch.sum(output_u * mask_u, dim=1) / torch.clamp(mask_u.sum(dim=1), min=1e-9)\n",
214
+ "\n",
215
+ " _, _, output_v = self.bert(hypothesis_ids, hypothesis_seg, dummy_masked_pos)\n",
216
+ " mask_v = (hypothesis_ids != 0).unsqueeze(-1).float()\n",
217
+ " v = torch.sum(output_v * mask_v, dim=1) / torch.clamp(mask_v.sum(dim=1), min=1e-9)\n",
218
+ "\n",
219
+ " uv_abs = torch.abs(u - v)\n",
220
+ " features = torch.cat([u, v, uv_abs], dim=-1)\n",
221
+ " logits = self.classifier(features)\n",
222
+ " return logits\n",
223
+ "\n",
224
+ "s_model = SBERT(bert).to(device)\n",
225
+ "optimizer = optim.Adam(s_model.parameters(), lr=2e-5)\n",
226
+ "criterion = nn.CrossEntropyLoss()"
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "markdown",
231
+ "metadata": {},
232
+ "source": [
233
+ "### 2.1 Climate-FEVER fine-tuning\n",
234
+ "I load the Climate-FEVER dataset, split into train/test, and train the model."
235
+ ]
236
+ },
237
+ {
238
+ "cell_type": "code",
239
+ "execution_count": 13,
240
+ "metadata": {},
241
+ "outputs": [
242
+ {
243
+ "name": "stdout",
244
+ "output_type": "stream",
245
+ "text": [
246
+ "Starting S-BERT Training...\n",
247
+ "Epoch 1 Loss: 0.9086\n",
248
+ "Epoch 2 Loss: 0.8578\n",
249
+ "Epoch 3 Loss: 0.8520\n",
250
+ "Epoch 4 Loss: 0.8444\n",
251
+ "Epoch 5 Loss: 0.8365\n",
252
+ "Epoch 6 Loss: 0.8249\n",
253
+ "Epoch 7 Loss: 0.8103\n",
254
+ "Epoch 8 Loss: 0.7932\n",
255
+ "Epoch 9 Loss: 0.7652\n",
256
+ "Epoch 10 Loss: 0.7349\n",
257
+ "Epoch 11 Loss: 0.6856\n",
258
+ "Epoch 12 Loss: 0.6360\n",
259
+ "Epoch 13 Loss: 0.5786\n",
260
+ "Epoch 14 Loss: 0.5181\n",
261
+ "Epoch 15 Loss: 0.4629\n",
262
+ "Epoch 16 Loss: 0.4108\n",
263
+ "Epoch 17 Loss: 0.3392\n",
264
+ "Epoch 18 Loss: 0.3047\n",
265
+ "Epoch 19 Loss: 0.2546\n",
266
+ "Epoch 20 Loss: 0.2053\n",
267
+ "Epoch 21 Loss: 0.1662\n",
268
+ "Epoch 22 Loss: 0.1421\n",
269
+ "Epoch 23 Loss: 0.1161\n",
270
+ "Epoch 24 Loss: 0.0918\n",
271
+ "Epoch 25 Loss: 0.0798\n",
272
+ "Epoch 26 Loss: 0.0740\n",
273
+ "Epoch 27 Loss: 0.0591\n",
274
+ "Epoch 28 Loss: 0.0488\n",
275
+ "Epoch 29 Loss: 0.0448\n",
276
+ "Epoch 30 Loss: 0.0452\n",
277
+ "Epoch 31 Loss: 0.0354\n",
278
+ "Epoch 32 Loss: 0.0315\n",
279
+ "Epoch 33 Loss: 0.0265\n",
280
+ "Epoch 34 Loss: 0.0232\n",
281
+ "Epoch 35 Loss: 0.0215\n",
282
+ "Epoch 36 Loss: 0.0180\n",
283
+ "Epoch 37 Loss: 0.0173\n",
284
+ "Epoch 38 Loss: 0.0147\n",
285
+ "Epoch 39 Loss: 0.0137\n",
286
+ "Epoch 40 Loss: 0.0159\n",
287
+ "Epoch 41 Loss: 0.0127\n",
288
+ "Epoch 42 Loss: 0.0102\n",
289
+ "Epoch 43 Loss: 0.0094\n",
290
+ "Epoch 44 Loss: 0.0094\n",
291
+ "Epoch 45 Loss: 0.0100\n",
292
+ "Epoch 46 Loss: 0.0112\n",
293
+ "Epoch 47 Loss: 0.0077\n",
294
+ "Epoch 48 Loss: 0.0067\n",
295
+ "Epoch 49 Loss: 0.0073\n",
296
+ "Epoch 50 Loss: 0.0268\n",
297
+ "Epoch 51 Loss: 0.0747\n",
298
+ "Epoch 52 Loss: 0.0405\n",
299
+ "Epoch 53 Loss: 0.0241\n",
300
+ "Epoch 54 Loss: 0.0077\n",
301
+ "Epoch 55 Loss: 0.0054\n",
302
+ "Epoch 56 Loss: 0.0049\n",
303
+ "Epoch 57 Loss: 0.0047\n",
304
+ "Epoch 58 Loss: 0.0047\n",
305
+ "Epoch 59 Loss: 0.0052\n",
306
+ "Epoch 60 Loss: 0.0038\n",
307
+ "Epoch 61 Loss: 0.0037\n",
308
+ "Epoch 62 Loss: 0.0039\n",
309
+ "Epoch 63 Loss: 0.0037\n",
310
+ "Epoch 64 Loss: 0.0049\n",
311
+ "Epoch 65 Loss: 0.0036\n",
312
+ "Epoch 66 Loss: 0.0035\n",
313
+ "Epoch 67 Loss: 0.0037\n",
314
+ "Epoch 68 Loss: 0.0040\n",
315
+ "Epoch 69 Loss: 0.0044\n",
316
+ "Epoch 70 Loss: 0.0038\n",
317
+ "Epoch 71 Loss: 0.0079\n",
318
+ "Epoch 72 Loss: 0.0093\n",
319
+ "Epoch 73 Loss: 0.0033\n",
320
+ "Epoch 74 Loss: 0.0030\n",
321
+ "Epoch 75 Loss: 0.0032\n",
322
+ "Epoch 76 Loss: 0.0032\n",
323
+ "Epoch 77 Loss: 0.0029\n",
324
+ "Epoch 78 Loss: 0.0026\n",
325
+ "Epoch 79 Loss: 0.0031\n",
326
+ "Epoch 80 Loss: 0.0021\n",
327
+ "Epoch 81 Loss: 0.0048\n",
328
+ "Epoch 82 Loss: 0.0030\n",
329
+ "Epoch 83 Loss: 0.0045\n",
330
+ "Epoch 84 Loss: 0.0045\n",
331
+ "Epoch 85 Loss: 0.0061\n",
332
+ "Epoch 86 Loss: 0.0784\n",
333
+ "Epoch 87 Loss: 0.0706\n",
334
+ "Epoch 88 Loss: 0.0293\n",
335
+ "Epoch 89 Loss: 0.0059\n",
336
+ "Epoch 90 Loss: 0.0028\n",
337
+ "Epoch 91 Loss: 0.0024\n",
338
+ "Epoch 92 Loss: 0.0023\n",
339
+ "Epoch 93 Loss: 0.0020\n",
340
+ "Epoch 94 Loss: 0.0019\n",
341
+ "Epoch 95 Loss: 0.0018\n",
342
+ "Epoch 96 Loss: 0.0015\n",
343
+ "Epoch 97 Loss: 0.0016\n",
344
+ "Epoch 98 Loss: 0.0019\n",
345
+ "Epoch 99 Loss: 0.0014\n",
346
+ "Epoch 100 Loss: 0.0017\n"
347
+ ]
348
+ }
349
+ ],
350
+ "source": [
351
+ "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
352
+ "cf_dataset = load_dataset('climate_fever', split='test') \n",
353
+ "# Often, only the 'test' split of Climate-FEVER is publicly available. I'll use it as our full dataset.\n",
354
+ "cf_split = cf_dataset.train_test_split(test_size=0.2, seed=42)\n",
355
+ "train_dataset = cf_split['train']\n",
356
+ "test_dataset = cf_split['test']\n",
357
+ "\n",
358
+ "class NLIDataset(Dataset):\n",
359
+ " def __init__(self, dataset, tokenizer, max_len=128):\n",
360
+ " self.dataset = dataset\n",
361
+ " self.tokenizer = tokenizer\n",
362
+ " self.max_len = max_len\n",
363
+ "\n",
364
+ " def __getitem__(self, idx):\n",
365
+ " item = self.dataset[idx]\n",
366
+ " premise = item['claim']\n",
367
+ " # Climate-FEVER has 'evidences' list. I take the first evidence text.\n",
368
+ " evidence_data = item['evidences'][0]\n",
369
+ " hypothesis = evidence_data['evidence']\n",
370
+ " label_raw = evidence_data['evidence_label'] # Nested access\n",
371
+ "\n",
372
+ " # Robust label mapping (Handles integers 0/1/2 and strings)\n",
373
+ " # Target: 0: Entailment, 1: Neutral, 2: Contradiction\n",
374
+ " if isinstance(label_raw, int):\n",
375
+ " # Assuming HF Climate-FEVER uses: 0: Supports, 1: Refutes, 2: NEI\n",
376
+ " if label_raw == 0: label = 0 # Supports -> Entailment\n",
377
+ " elif label_raw == 1: label = 2 # Refutes -> Contradiction\n",
378
+ " elif label_raw == 2: label = 1 # NEI -> Neutral\n",
379
+ " else: label = 1 # Default\n",
380
+ " else:\n",
381
+ " label_str = str(label_raw).upper().replace(\" \", \"_\")\n",
382
+ " if 'SUPPORT' in label_str: label = 0\n",
383
+ " elif 'REFUTE' in label_str: label = 2\n",
384
+ " elif 'INFO' in label_str: label = 1\n",
385
+ " else: label = 1\n",
386
+ "\n",
387
+ " encoded_premise = self.tokenizer(\n",
388
+ " premise,\n",
389
+ " add_special_tokens=True,\n",
390
+ " max_length=self.max_len,\n",
391
+ " padding='max_length',\n",
392
+ " return_attention_mask=True,\n",
393
+ " truncation=True\n",
394
+ " )\n",
395
+ "\n",
396
+ " encoded_hypothesis = self.tokenizer(\n",
397
+ " hypothesis,\n",
398
+ " add_special_tokens=True,\n",
399
+ " max_length=self.max_len,\n",
400
+ " padding='max_length',\n",
401
+ " return_attention_mask=True,\n",
402
+ " truncation=True\n",
403
+ " )\n",
404
+ "\n",
405
+ " return {\n",
406
+ " 'premise_input_ids': torch.tensor(encoded_premise['input_ids'], dtype=torch.long),\n",
407
+ " 'premise_segment_ids': torch.tensor(encoded_premise['token_type_ids'], dtype=torch.long),\n",
408
+ " 'hypothesis_input_ids': torch.tensor(encoded_hypothesis['input_ids'], dtype=torch.long),\n",
409
+ " 'hypothesis_segment_ids': torch.tensor(encoded_hypothesis['token_type_ids'], dtype=torch.long),\n",
410
+ " 'label': torch.tensor(label, dtype=torch.long)\n",
411
+ " }\n",
412
+ "\n",
413
+ " def __len__(self):\n",
414
+ " return len(self.dataset)\n",
415
+ "\n",
416
+ "train_loader = DataLoader(NLIDataset(train_dataset, tokenizer), batch_size=16, shuffle=True)\n",
417
+ "test_loader = DataLoader(NLIDataset(test_dataset, tokenizer), batch_size=16, shuffle=False)\n",
418
+ "\n",
419
+ "print(\"Starting S-BERT Training...\")\n",
420
+ "for epoch in range(100): \n",
421
+ " s_model.train()\n",
422
+ " total_loss = 0\n",
423
+ " for batch in train_loader:\n",
424
+ " p_ids = batch['premise_input_ids'].to(device)\n",
425
+ " p_seg = batch['premise_segment_ids'].to(device)\n",
426
+ " h_ids = batch['hypothesis_input_ids'].to(device)\n",
427
+ " h_seg = batch['hypothesis_segment_ids'].to(device)\n",
428
+ " labels = batch['label'].to(device)\n",
429
+ "\n",
430
+ " optimizer.zero_grad()\n",
431
+ " logits = s_model(p_ids, p_seg, h_ids, h_seg)\n",
432
+ " loss = criterion(logits, labels)\n",
433
+ " loss.backward()\n",
434
+ " optimizer.step()\n",
435
+ " total_loss += loss.item()\n",
436
+ " print(f\"Epoch {epoch+1} Loss: {total_loss/len(train_loader):.4f}\")\n",
437
+ "\n",
438
+ "torch.save(s_model.state_dict(), './models/sbert_climate_fever.pt')"
439
+ ]
440
+ },
441
+ {
442
+ "cell_type": "markdown",
443
+ "metadata": {},
444
+ "source": [
445
+ "## 3. Evaluation\n",
446
+ "Evaluation of the model on the held-out test set."
447
+ ]
448
+ },
449
+ {
450
+ "cell_type": "code",
451
+ "execution_count": 14,
452
+ "metadata": {},
453
+ "outputs": [
454
+ {
455
+ "name": "stdout",
456
+ "output_type": "stream",
457
+ "text": [
458
+ "Evaluating...\n",
459
+ "Classification Report:\n",
460
+ " precision recall f1-score support\n",
461
+ "\n",
462
+ " Entailment 0.26 0.22 0.24 82\n",
463
+ " Neutral 0.62 0.69 0.65 191\n",
464
+ "Contradiction 0.19 0.15 0.17 34\n",
465
+ "\n",
466
+ " accuracy 0.50 307\n",
467
+ " macro avg 0.36 0.35 0.35 307\n",
468
+ " weighted avg 0.48 0.50 0.49 307\n",
469
+ "\n",
470
+ "Accuracy: 0.5049\n"
471
+ ]
472
+ }
473
+ ],
474
+ "source": [
475
+ "s_model.eval()\n",
476
+ "all_preds = []\n",
477
+ "all_labels = []\n",
478
+ "\n",
479
+ "print(\"Evaluating...\")\n",
480
+ "with torch.no_grad():\n",
481
+ " for batch in test_loader:\n",
482
+ " p_ids = batch['premise_input_ids'].to(device)\n",
483
+ " p_seg = batch['premise_segment_ids'].to(device)\n",
484
+ " h_ids = batch['hypothesis_input_ids'].to(device)\n",
485
+ " h_seg = batch['hypothesis_segment_ids'].to(device)\n",
486
+ " labels = batch['label'].to(device)\n",
487
+ "\n",
488
+ " logits = s_model(p_ids, p_seg, h_ids, h_seg)\n",
489
+ " preds = torch.argmax(logits, dim=1)\n",
490
+ " \n",
491
+ " all_preds.extend(preds.cpu().numpy())\n",
492
+ " all_labels.extend(labels.cpu().numpy())\n",
493
+ "\n",
494
+ "target_names = ['Entailment', 'Neutral', 'Contradiction']\n",
495
+ "print(\"Classification Report:\")\n",
496
+ "print(classification_report(all_labels, all_preds, labels=[0, 1, 2], target_names=target_names))\n",
497
+ "print(f\"Accuracy: {accuracy_score(all_labels, all_preds):.4f}\")"
498
+ ]
499
+ }
500
+ ],
501
+ "metadata": {
502
+ "kernelspec": {
503
+ "display_name": "Python 3",
504
+ "language": "python",
505
+ "name": "python3"
506
+ },
507
+ "language_info": {
508
+ "codemirror_mode": {
509
+ "name": "ipython",
510
+ "version": 3
511
+ },
512
+ "file_extension": ".py",
513
+ "mimetype": "text/x-python",
514
+ "name": "python",
515
+ "nbconvert_exporter": "python",
516
+ "pygments_lexer": "ipython3",
517
+ "version": "3.8.5"
518
+ }
519
+ },
520
+ "nbformat": 4,
521
+ "nbformat_minor": 4
522
+ }
A4_Option_MNLI.ipynb ADDED
@@ -0,0 +1,518 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# A4: S-BERT Training on Alternative Datasets (MNLI)\n",
8
+ "\n",
9
+ "This notebook allows me to train the S-BERT model on the **MNLI** (Multi-Genre Natural Language Inference) dataset.\n",
10
+ "\n",
11
+ "## 1. Environment Setup"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": 8,
17
+ "metadata": {},
18
+ "outputs": [
19
+ {
20
+ "name": "stdout",
21
+ "output_type": "stream",
22
+ "text": [
23
+ "Using device: mps\n"
24
+ ]
25
+ }
26
+ ],
27
+ "source": [
28
+ "import os\n",
29
+ "import torch\n",
30
+ "import torch.nn as nn\n",
31
+ "import torch.optim as optim\n",
32
+ "import numpy as np\n",
33
+ "from datasets import load_dataset\n",
34
+ "from transformers import BertTokenizer\n",
35
+ "from torch.utils.data import DataLoader, Dataset\n",
36
+ "from sklearn.metrics import classification_report, accuracy_score\n",
37
+ "\n",
38
+ "# Device Configuration\n",
39
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else (\"mps\" if torch.backends.mps.is_available() else \"cpu\"))\n",
40
+ "print(f\"Using device: {device}\")"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "markdown",
45
+ "metadata": {},
46
+ "source": [
47
+ "## 2. Load Pre-trained BERT\n",
48
+ "\n",
49
+ "I will load the BERT model trained in `A4_BERT.ipynb`. Ensure `models/bert_trained.pt` exists."
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": 9,
55
+ "metadata": {},
56
+ "outputs": [
57
+ {
58
+ "name": "stdout",
59
+ "output_type": "stream",
60
+ "text": [
61
+ "Loaded bert_trained.pt\n"
62
+ ]
63
+ }
64
+ ],
65
+ "source": [
66
+ "# Define BERT Architecture\n",
67
+ "# MUST MATCH THE OPTIMIZED CONFIG FROM A4_BERT.ipynb\n",
68
+ "vocab_size = 5004 # Updated from 30522\n",
69
+ "d_model = 256 # MiniBERT Config\n",
70
+ "n_layers = 2 # Updated from 4\n",
71
+ "n_heads = 4\n",
72
+ "d_ff = 256 * 4\n",
73
+ "max_len = 128\n",
74
+ "n_segments = 2\n",
75
+ "d_k = d_v = 64\n",
76
+ "\n",
77
+ "class Embedding(nn.Module):\n",
78
+ " def __init__(self):\n",
79
+ " super(Embedding, self).__init__()\n",
80
+ " self.tok_embed = nn.Embedding(vocab_size, d_model)\n",
81
+ " self.pos_embed = nn.Embedding(max_len, d_model)\n",
82
+ " self.seg_embed = nn.Embedding(n_segments, d_model)\n",
83
+ " self.norm = nn.LayerNorm(d_model)\n",
84
+ "\n",
85
+ " def forward(self, x, seg):\n",
86
+ " seq_len = x.size(1)\n",
87
+ " pos = torch.arange(seq_len, dtype=torch.long, device=x.device)\n",
88
+ " pos = pos.unsqueeze(0).expand_as(x)\n",
89
+ " embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)\n",
90
+ " return self.norm(embedding)\n",
91
+ "\n",
92
+ "class MultiHeadAttention(nn.Module):\n",
93
+ " def __init__(self):\n",
94
+ " super(MultiHeadAttention, self).__init__()\n",
95
+ " self.W_Q = nn.Linear(d_model, d_k * n_heads)\n",
96
+ " self.W_K = nn.Linear(d_model, d_k * n_heads)\n",
97
+ " self.W_V = nn.Linear(d_model, d_v * n_heads)\n",
98
+ " self.linear = nn.Linear(n_heads * d_v, d_model)\n",
99
+ " self.layer_norm = nn.LayerNorm(d_model)\n",
100
+ "\n",
101
+ " def forward(self, Q, K, V, attn_mask):\n",
102
+ " batch_size = Q.size(0)\n",
103
+ " q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1,2)\n",
104
+ " k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1,2)\n",
105
+ " v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1,2)\n",
106
+ " \n",
107
+ " attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)\n",
108
+ " \n",
109
+ " scores = torch.matmul(q_s, k_s.transpose(-1, -2)) / np.sqrt(d_k)\n",
110
+ " scores.masked_fill_(attn_mask, -1e9)\n",
111
+ " attn = nn.Softmax(dim=-1)(scores)\n",
112
+ " context = torch.matmul(attn, v_s)\n",
113
+ " context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)\n",
114
+ " output = self.linear(context)\n",
115
+ " return self.layer_norm(output + Q), attn\n",
116
+ "\n",
117
+ "class PoswiseFeedForwardNet(nn.Module):\n",
118
+ " def __init__(self):\n",
119
+ " super(PoswiseFeedForwardNet, self).__init__()\n",
120
+ " self.fc1 = nn.Linear(d_model, d_ff)\n",
121
+ " self.fc2 = nn.Linear(d_ff, d_model)\n",
122
+ " def forward(self, x):\n",
123
+ " return self.fc2(torch.nn.functional.gelu(self.fc1(x)))\n",
124
+ "\n",
125
+ "class EncoderLayer(nn.Module):\n",
126
+ " def __init__(self):\n",
127
+ " super(EncoderLayer, self).__init__()\n",
128
+ " self.enc_self_attn = MultiHeadAttention()\n",
129
+ " self.pos_ffn = PoswiseFeedForwardNet()\n",
130
+ " def forward(self, enc_inputs, enc_self_attn_mask):\n",
131
+ " enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)\n",
132
+ " enc_outputs = self.pos_ffn(enc_outputs)\n",
133
+ " return enc_outputs, attn\n",
134
+ "\n",
135
+ "def get_attn_pad_mask(seq_q, seq_k):\n",
136
+ " batch_size, len_q = seq_q.size()\n",
137
+ " batch_size, len_k = seq_k.size()\n",
138
+ " pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)\n",
139
+ " return pad_attn_mask.expand(batch_size, len_q, len_k)\n",
140
+ "\n",
141
+ "class BERT(nn.Module):\n",
142
+ " def __init__(self):\n",
143
+ " super(BERT, self).__init__()\n",
144
+ " self.embedding = Embedding()\n",
145
+ " self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])\n",
146
+ " self.fc = nn.Linear(d_model, d_model)\n",
147
+ " self.activ = nn.Tanh()\n",
148
+ " self.linear = nn.Linear(d_model, d_model)\n",
149
+ " self.norm = nn.LayerNorm(d_model)\n",
150
+ " self.classifier = nn.Linear(d_model, 2)\n",
151
+ " embed_weight = self.embedding.tok_embed.weight\n",
152
+ " n_vocab, n_dim = embed_weight.size()\n",
153
+ " self.decoder = nn.Linear(n_dim, n_vocab, bias=False)\n",
154
+ " self.decoder.weight = embed_weight\n",
155
+ " self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))\n",
156
+ "\n",
157
+ " def forward(self, input_ids, segment_ids, masked_pos=None):\n",
158
+ " output = self.embedding(input_ids, segment_ids)\n",
159
+ " enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids)\n",
160
+ " for layer in self.layers:\n",
161
+ " output, enc_self_attn = layer(output, enc_self_attn_mask)\n",
162
+ " return None, None, output \n",
163
+ "\n",
164
+ "# Load Pretrained Weights\n",
165
+ "bert = BERT().to(device)\n",
166
+ "try:\n",
167
+ " bert.load_state_dict(torch.load('./models/bert_trained.pt', map_location=device))\n",
168
+ " print(\"Loaded bert_trained.pt\")\n",
169
+ "except:\n",
170
+ " print(\"Warning: bert_trained.pt not found. Using random weights.\")\n"
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "markdown",
175
+ "metadata": {},
176
+ "source": [
177
+ "## 3. Load MNLI Dataset\n"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "execution_count": 10,
183
+ "metadata": {},
184
+ "outputs": [
185
+ {
186
+ "name": "stdout",
187
+ "output_type": "stream",
188
+ "text": [
189
+ "Loading mnli...\n",
190
+ "Loaded dataset keys: dict_keys(['train', 'validation_matched', 'validation_mismatched', 'test_matched', 'test_mismatched'])\n",
191
+ "Train size: 10000, Val size: 1000\n"
192
+ ]
193
+ }
194
+ ],
195
+ "source": [
196
+ "DATASET_NAME = 'mnli'\n",
197
+ "print(f\"Loading {DATASET_NAME}...\")\n",
198
+ "# MNLI is part of GLUE benchmark\n",
199
+ "dataset = load_dataset('glue', 'mnli') \n",
200
+ "print(f\"Loaded dataset keys: {dataset.keys()}\")\n",
201
+ "\n",
202
+ "train_dataset = dataset['train'].select(range(10000))\n",
203
+ "val_dataset = dataset['validation_matched'].select(range(1000))\n",
204
+ "print(f\"Train size: {len(train_dataset)}, Val size: {len(val_dataset)}\")"
205
+ ]
206
+ },
207
+ {
208
+ "cell_type": "code",
209
+ "execution_count": 11,
210
+ "metadata": {},
211
+ "outputs": [],
212
+ "source": [
213
+ "# Data Loader\n",
214
+ "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
215
+ "\n",
216
+ "class NLIDataset(Dataset):\n",
217
+ " def __init__(self, dataset, tokenizer, max_len=128):\n",
218
+ " self.dataset = dataset\n",
219
+ " self.tokenizer = tokenizer\n",
220
+ " self.max_len = max_len\n",
221
+ "\n",
222
+ " def __len__(self):\n",
223
+ " return len(self.dataset)\n",
224
+ "\n",
225
+ " def __getitem__(self, idx):\n",
226
+ " item = self.dataset[idx]\n",
227
+ " premise = item['premise']\n",
228
+ " hypothesis = item['hypothesis']\n",
229
+ " label = item['label']\n",
230
+ "\n",
231
+ " encoded_premise = self.tokenizer(\n",
232
+ " premise,\n",
233
+ " add_special_tokens=True,\n",
234
+ " max_length=self.max_len,\n",
235
+ " padding='max_length',\n",
236
+ " return_attention_mask=True,\n",
237
+ " truncation=True\n",
238
+ " )\n",
239
+ "\n",
240
+ " encoded_hypothesis = self.tokenizer(\n",
241
+ " hypothesis,\n",
242
+ " add_special_tokens=True,\n",
243
+ " max_length=self.max_len,\n",
244
+ " padding='max_length',\n",
245
+ " return_attention_mask=True,\n",
246
+ " truncation=True\n",
247
+ " )\n",
248
+ "\n",
249
+ " return {\n",
250
+ " 'premise_input_ids': torch.tensor(encoded_premise['input_ids'], dtype=torch.long),\n",
251
+ " 'premise_segment_ids': torch.tensor(encoded_premise['token_type_ids'], dtype=torch.long),\n",
252
+ " 'hypothesis_input_ids': torch.tensor(encoded_hypothesis['input_ids'], dtype=torch.long),\n",
253
+ " 'hypothesis_segment_ids': torch.tensor(encoded_hypothesis['token_type_ids'], dtype=torch.long),\n",
254
+ " 'label': torch.tensor(label, dtype=torch.long)\n",
255
+ " }\n",
256
+ "\n",
257
+ "train_loader = DataLoader(NLIDataset(train_dataset, tokenizer), batch_size=16, shuffle=True)\n",
258
+ "test_loader = DataLoader(NLIDataset(val_dataset, tokenizer), batch_size=16, shuffle=False)"
259
+ ]
260
+ },
261
+ {
262
+ "cell_type": "code",
263
+ "execution_count": 12,
264
+ "metadata": {},
265
+ "outputs": [],
266
+ "source": [
267
+ "# S-BERT Model\n",
268
+ "class SBERT(nn.Module):\n",
269
+ " def __init__(self, bert_model):\n",
270
+ " super(SBERT, self).__init__()\n",
271
+ " self.bert = bert_model\n",
272
+ " self.classifier = nn.Linear(d_model * 3, 3)\n",
273
+ "\n",
274
+ " def forward(self, premise_ids, premise_seg, hypothesis_ids, hypothesis_seg):\n",
275
+ " device = premise_ids.device\n",
276
+ " dummy_masked_pos = torch.zeros((premise_ids.size(0), 1), dtype=torch.long).to(device)\n",
277
+ " \n",
278
+ " _, _, output_u = self.bert(premise_ids, premise_seg, dummy_masked_pos)\n",
279
+ " mask_u = (premise_ids != 0).unsqueeze(-1).float()\n",
280
+ " u = torch.sum(output_u * mask_u, dim=1) / torch.clamp(mask_u.sum(dim=1), min=1e-9)\n",
281
+ "\n",
282
+ " _, _, output_v = self.bert(hypothesis_ids, hypothesis_seg, dummy_masked_pos)\n",
283
+ " mask_v = (hypothesis_ids != 0).unsqueeze(-1).float()\n",
284
+ " v = torch.sum(output_v * mask_v, dim=1) / torch.clamp(mask_v.sum(dim=1), min=1e-9)\n",
285
+ "\n",
286
+ " uv_abs = torch.abs(u - v)\n",
287
+ " features = torch.cat([u, v, uv_abs], dim=-1)\n",
288
+ " logits = self.classifier(features)\n",
289
+ " return logits\n",
290
+ "\n",
291
+ "sbert = SBERT(bert).to(device)\n",
292
+ "optimizer = optim.Adam(sbert.parameters(), lr=2e-5)\n",
293
+ "criterion = nn.CrossEntropyLoss()"
294
+ ]
295
+ },
296
+ {
297
+ "cell_type": "code",
298
+ "execution_count": 13,
299
+ "metadata": {},
300
+ "outputs": [
301
+ {
302
+ "name": "stdout",
303
+ "output_type": "stream",
304
+ "text": [
305
+ "Starting Training...\n",
306
+ "Epoch 1 Loss: 1.0874\n",
307
+ "Epoch 2 Loss: 1.0579\n",
308
+ "Epoch 3 Loss: 1.0135\n",
309
+ "Epoch 4 Loss: 0.9806\n",
310
+ "Epoch 5 Loss: 0.9516\n",
311
+ "Epoch 6 Loss: 0.9264\n",
312
+ "Epoch 7 Loss: 0.8967\n",
313
+ "Epoch 8 Loss: 0.8670\n",
314
+ "Epoch 9 Loss: 0.8326\n",
315
+ "Epoch 10 Loss: 0.8005\n",
316
+ "Epoch 11 Loss: 0.7601\n",
317
+ "Epoch 12 Loss: 0.7218\n",
318
+ "Epoch 13 Loss: 0.6797\n",
319
+ "Epoch 14 Loss: 0.6349\n",
320
+ "Epoch 15 Loss: 0.5906\n",
321
+ "Epoch 16 Loss: 0.5450\n",
322
+ "Epoch 17 Loss: 0.4974\n",
323
+ "Epoch 18 Loss: 0.4493\n",
324
+ "Epoch 19 Loss: 0.4065\n",
325
+ "Epoch 20 Loss: 0.3590\n",
326
+ "Epoch 21 Loss: 0.3150\n",
327
+ "Epoch 22 Loss: 0.2712\n",
328
+ "Epoch 23 Loss: 0.2321\n",
329
+ "Epoch 24 Loss: 0.2020\n",
330
+ "Epoch 25 Loss: 0.1673\n",
331
+ "Epoch 26 Loss: 0.1349\n",
332
+ "Epoch 27 Loss: 0.1136\n",
333
+ "Epoch 28 Loss: 0.0965\n",
334
+ "Epoch 29 Loss: 0.0876\n",
335
+ "Epoch 30 Loss: 0.0723\n",
336
+ "Epoch 31 Loss: 0.0656\n",
337
+ "Epoch 32 Loss: 0.0504\n",
338
+ "Epoch 33 Loss: 0.0416\n",
339
+ "Epoch 34 Loss: 0.0418\n",
340
+ "Epoch 35 Loss: 0.0327\n",
341
+ "Epoch 36 Loss: 0.0324\n",
342
+ "Epoch 37 Loss: 0.0269\n",
343
+ "Epoch 38 Loss: 0.0438\n",
344
+ "Epoch 39 Loss: 0.0291\n",
345
+ "Epoch 40 Loss: 0.0210\n",
346
+ "Epoch 41 Loss: 0.0168\n",
347
+ "Epoch 42 Loss: 0.0309\n",
348
+ "Epoch 43 Loss: 0.0180\n",
349
+ "Epoch 44 Loss: 0.0327\n",
350
+ "Epoch 45 Loss: 0.0411\n",
351
+ "Epoch 46 Loss: 0.0157\n",
352
+ "Epoch 47 Loss: 0.0048\n",
353
+ "Epoch 48 Loss: 0.0019\n",
354
+ "Epoch 49 Loss: 0.0013\n",
355
+ "Epoch 50 Loss: 0.0010\n",
356
+ "Epoch 51 Loss: 0.0008\n",
357
+ "Epoch 52 Loss: 0.0007\n",
358
+ "Epoch 53 Loss: 0.0005\n",
359
+ "Epoch 54 Loss: 0.0004\n",
360
+ "Epoch 55 Loss: 0.0003\n",
361
+ "Epoch 56 Loss: 0.0002\n",
362
+ "Epoch 57 Loss: 0.0002\n",
363
+ "Epoch 58 Loss: 0.0001\n",
364
+ "Epoch 59 Loss: 0.0001\n",
365
+ "Epoch 60 Loss: 0.0001\n",
366
+ "Epoch 61 Loss: 0.0000\n",
367
+ "Epoch 62 Loss: 0.0000\n",
368
+ "Epoch 63 Loss: 0.0000\n",
369
+ "Epoch 64 Loss: 0.0000\n",
370
+ "Epoch 65 Loss: 0.0000\n",
371
+ "Epoch 66 Loss: 0.1953\n",
372
+ "Epoch 67 Loss: 0.0272\n",
373
+ "Epoch 68 Loss: 0.0120\n",
374
+ "Epoch 69 Loss: 0.0108\n",
375
+ "Epoch 70 Loss: 0.0152\n",
376
+ "Epoch 71 Loss: 0.0337\n",
377
+ "Epoch 72 Loss: 0.0215\n",
378
+ "Epoch 73 Loss: 0.0148\n",
379
+ "Epoch 74 Loss: 0.0207\n",
380
+ "Epoch 75 Loss: 0.0238\n",
381
+ "Epoch 76 Loss: 0.0181\n",
382
+ "Epoch 77 Loss: 0.0217\n",
383
+ "Epoch 78 Loss: 0.0136\n",
384
+ "Epoch 79 Loss: 0.0163\n",
385
+ "Epoch 80 Loss: 0.0067\n",
386
+ "Epoch 81 Loss: 0.0007\n",
387
+ "Epoch 82 Loss: 0.0003\n",
388
+ "Epoch 83 Loss: 0.0002\n",
389
+ "Epoch 84 Loss: 0.0002\n",
390
+ "Epoch 85 Loss: 0.0001\n",
391
+ "Epoch 86 Loss: 0.0001\n",
392
+ "Epoch 87 Loss: 0.0001\n",
393
+ "Epoch 88 Loss: 0.0001\n",
394
+ "Epoch 89 Loss: 0.0001\n",
395
+ "Epoch 90 Loss: 0.0000\n",
396
+ "Epoch 91 Loss: 0.0000\n",
397
+ "Epoch 92 Loss: 0.0000\n",
398
+ "Epoch 93 Loss: 0.0000\n",
399
+ "Epoch 94 Loss: 0.0000\n",
400
+ "Epoch 95 Loss: 0.0000\n",
401
+ "Epoch 96 Loss: 0.0000\n",
402
+ "Epoch 97 Loss: 0.0000\n",
403
+ "Epoch 98 Loss: 0.0000\n",
404
+ "Epoch 99 Loss: 0.0000\n",
405
+ "Epoch 100 Loss: 0.0000\n",
406
+ "Done!\n"
407
+ ]
408
+ }
409
+ ],
410
+ "source": [
411
+ "# Train Loop\n",
412
+ "print(\"Starting Training...\")\n",
413
+ "epochs = 100 \n",
414
+ "for epoch in range(epochs):\n",
415
+ " sbert.train()\n",
416
+ " total_loss = 0\n",
417
+ " for batch in train_loader:\n",
418
+ " p_ids = batch['premise_input_ids'].to(device)\n",
419
+ " p_seg = batch['premise_segment_ids'].to(device)\n",
420
+ " h_ids = batch['hypothesis_input_ids'].to(device)\n",
421
+ " h_seg = batch['hypothesis_segment_ids'].to(device)\n",
422
+ " labels = batch['label'].to(device)\n",
423
+ "\n",
424
+ " optimizer.zero_grad()\n",
425
+ " logits = sbert(p_ids, p_seg, h_ids, h_seg)\n",
426
+ " loss = criterion(logits, labels)\n",
427
+ " loss.backward()\n",
428
+ " optimizer.step()\n",
429
+ " total_loss += loss.item()\n",
430
+ " print(f\"Epoch {epoch+1} Loss: {total_loss/len(train_loader):.4f}\")\n",
431
+ "\n",
432
+ "print(\"Done!\")\n",
433
+ "torch.save(sbert.state_dict(), f'./models/sbert_{DATASET_NAME}.pt')"
434
+ ]
435
+ },
436
+ {
437
+ "cell_type": "markdown",
438
+ "metadata": {},
439
+ "source": [
440
+ "## 4. Evaluation\n",
441
+ "\n",
442
+ "Evaluate on validation set (matched)."
443
+ ]
444
+ },
445
+ {
446
+ "cell_type": "code",
447
+ "execution_count": 14,
448
+ "metadata": {},
449
+ "outputs": [
450
+ {
451
+ "name": "stdout",
452
+ "output_type": "stream",
453
+ "text": [
454
+ "Evaluating...\n",
455
+ "Classification Report:\n",
456
+ " precision recall f1-score support\n",
457
+ "\n",
458
+ " Entailment 0.42 0.45 0.43 341\n",
459
+ " Neutral 0.42 0.34 0.37 319\n",
460
+ "Contradiction 0.49 0.54 0.51 340\n",
461
+ "\n",
462
+ " accuracy 0.44 1000\n",
463
+ " macro avg 0.44 0.44 0.44 1000\n",
464
+ " weighted avg 0.44 0.44 0.44 1000\n",
465
+ "\n",
466
+ "Accuracy: 0.4440\n"
467
+ ]
468
+ }
469
+ ],
470
+ "source": [
471
+ "sbert.eval()\n",
472
+ "all_preds = []\n",
473
+ "all_labels = []\n",
474
+ "\n",
475
+ "print(\"Evaluating...\")\n",
476
+ "with torch.no_grad():\n",
477
+ " for batch in test_loader:\n",
478
+ " p_ids = batch['premise_input_ids'].to(device)\n",
479
+ " p_seg = batch['premise_segment_ids'].to(device)\n",
480
+ " h_ids = batch['hypothesis_input_ids'].to(device)\n",
481
+ " h_seg = batch['hypothesis_segment_ids'].to(device)\n",
482
+ " labels = batch['label'].to(device)\n",
483
+ "\n",
484
+ " logits = sbert(p_ids, p_seg, h_ids, h_seg)\n",
485
+ " preds = torch.argmax(logits, dim=1)\n",
486
+ " \n",
487
+ " all_preds.extend(preds.cpu().numpy())\n",
488
+ " all_labels.extend(labels.cpu().numpy())\n",
489
+ "\n",
490
+ "target_names = ['Entailment', 'Neutral', 'Contradiction']\n",
491
+ "print(\"Classification Report:\")\n",
492
+ "print(classification_report(all_labels, all_preds, labels=[0, 1, 2], target_names=target_names))\n",
493
+ "print(f\"Accuracy: {accuracy_score(all_labels, all_preds):.4f}\")"
494
+ ]
495
+ }
496
+ ],
497
+ "metadata": {
498
+ "kernelspec": {
499
+ "display_name": "Python 3",
500
+ "language": "python",
501
+ "name": "python3"
502
+ },
503
+ "language_info": {
504
+ "codemirror_mode": {
505
+ "name": "ipython",
506
+ "version": 3
507
+ },
508
+ "file_extension": ".py",
509
+ "mimetype": "text/x-python",
510
+ "name": "python",
511
+ "nbconvert_exporter": "python",
512
+ "pygments_lexer": "ipython3",
513
+ "version": "3.8.5"
514
+ }
515
+ },
516
+ "nbformat": 4,
517
+ "nbformat_minor": 4
518
+ }
A4_Option_SNLI.ipynb ADDED
@@ -0,0 +1,637 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# A4: S-BERT Training on Alternative Datasets (SNLI)\n",
8
+ "\n",
9
+ "This notebook allows me to train the S-BERT model on the **SNLI** (Stanford Natural Language Inference) dataset.\n",
10
+ "\n",
11
+ "## 1. Environment Setup"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": 9,
17
+ "metadata": {},
18
+ "outputs": [
19
+ {
20
+ "name": "stdout",
21
+ "output_type": "stream",
22
+ "text": [
23
+ "Using device: mps\n"
24
+ ]
25
+ }
26
+ ],
27
+ "source": [
28
+ "import os\n",
29
+ "import torch\n",
30
+ "import torch.nn as nn\n",
31
+ "import torch.optim as optim\n",
32
+ "import numpy as np\n",
33
+ "from datasets import load_dataset\n",
34
+ "from transformers import BertTokenizer\n",
35
+ "from torch.utils.data import DataLoader, Dataset\n",
36
+ "from sklearn.metrics import classification_report, accuracy_score\n",
37
+ "\n",
38
+ "# Device Configuration\n",
39
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else (\"mps\" if torch.backends.mps.is_available() else \"cpu\"))\n",
40
+ "print(f\"Using device: {device}\")"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "markdown",
45
+ "metadata": {},
46
+ "source": [
47
+ "## 2. Load Pre-trained BERT\n",
48
+ "\n",
49
+ "I will load the BERT model trained in `A4_BERT.ipynb`. Ensure `models/bert_trained.pt` exists."
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": 10,
55
+ "metadata": {},
56
+ "outputs": [
57
+ {
58
+ "name": "stdout",
59
+ "output_type": "stream",
60
+ "text": [
61
+ "Loaded bert_trained.pt\n"
62
+ ]
63
+ }
64
+ ],
65
+ "source": [
66
+ "# Define BERT Architecture\n",
67
+ "# MUST MATCH THE OPTIMIZED CONFIG FROM A4_BERT.ipynb\n",
68
+ "vocab_size = 5004 # Updated from 30522\n",
69
+ "d_model = 256 # MiniBERT Config\n",
70
+ "n_layers = 2 # Updated from 4\n",
71
+ "n_heads = 4\n",
72
+ "d_ff = 256 * 4\n",
73
+ "max_len = 128\n",
74
+ "n_segments = 2\n",
75
+ "d_k = d_v = 64\n",
76
+ "\n",
77
+ "class Embedding(nn.Module):\n",
78
+ " def __init__(self):\n",
79
+ " super(Embedding, self).__init__()\n",
80
+ " self.tok_embed = nn.Embedding(vocab_size, d_model)\n",
81
+ " self.pos_embed = nn.Embedding(max_len, d_model)\n",
82
+ " self.seg_embed = nn.Embedding(n_segments, d_model)\n",
83
+ " self.norm = nn.LayerNorm(d_model)\n",
84
+ "\n",
85
+ " def forward(self, x, seg):\n",
86
+ " seq_len = x.size(1)\n",
87
+ " pos = torch.arange(seq_len, dtype=torch.long, device=x.device)\n",
88
+ " pos = pos.unsqueeze(0).expand_as(x)\n",
89
+ " embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)\n",
90
+ " return self.norm(embedding)\n",
91
+ "\n",
92
+ "class MultiHeadAttention(nn.Module):\n",
93
+ " def __init__(self):\n",
94
+ " super(MultiHeadAttention, self).__init__()\n",
95
+ " self.W_Q = nn.Linear(d_model, d_k * n_heads)\n",
96
+ " self.W_K = nn.Linear(d_model, d_k * n_heads)\n",
97
+ " self.W_V = nn.Linear(d_model, d_v * n_heads)\n",
98
+ " self.linear = nn.Linear(n_heads * d_v, d_model)\n",
99
+ " self.layer_norm = nn.LayerNorm(d_model)\n",
100
+ "\n",
101
+ " def forward(self, Q, K, V, attn_mask):\n",
102
+ " batch_size = Q.size(0)\n",
103
+ " q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)\n",
104
+ " k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)\n",
105
+ " v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)\n",
106
+ " \n",
107
+ " attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)\n",
108
+ " \n",
109
+ " scores = torch.matmul(q_s, k_s.transpose(-1, -2)) / np.sqrt(d_k)\n",
110
+ " scores.masked_fill_(attn_mask, -1e9)\n",
111
+ " attn = nn.Softmax(dim=-1)(scores)\n",
112
+ " context = torch.matmul(attn, v_s)\n",
113
+ " context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)\n",
114
+ " output = self.linear(context)\n",
115
+ " return self.layer_norm(output + Q), attn\n",
116
+ "\n",
117
+ "class PoswiseFeedForwardNet(nn.Module):\n",
118
+ " def __init__(self):\n",
119
+ " super(PoswiseFeedForwardNet, self).__init__()\n",
120
+ " self.fc1 = nn.Linear(d_model, d_ff)\n",
121
+ " self.fc2 = nn.Linear(d_ff, d_model)\n",
122
+ "\n",
123
+ " def forward(self, x):\n",
124
+ " return self.fc2(torch.nn.functional.gelu(self.fc1(x)))\n",
125
+ "\n",
126
+ "class EncoderLayer(nn.Module):\n",
127
+ " def __init__(self):\n",
128
+ " super(EncoderLayer, self).__init__()\n",
129
+ " self.enc_self_attn = MultiHeadAttention()\n",
130
+ " self.pos_ffn = PoswiseFeedForwardNet()\n",
131
+ "\n",
132
+ " def forward(self, enc_inputs, enc_self_attn_mask):\n",
133
+ " enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)\n",
134
+ " enc_outputs = self.pos_ffn(enc_outputs)\n",
135
+ " return enc_outputs, attn\n",
136
+ "\n",
137
+ "def get_attn_pad_mask(seq_q, seq_k):\n",
138
+ " batch_size, len_q = seq_q.size()\n",
139
+ " batch_size, len_k = seq_k.size()\n",
140
+ " pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)\n",
141
+ " return pad_attn_mask.expand(batch_size, len_q, len_k)\n",
142
+ "\n",
143
+ "class BERT(nn.Module):\n",
144
+ " def __init__(self):\n",
145
+ " super(BERT, self).__init__()\n",
146
+ " self.embedding = Embedding()\n",
147
+ " self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])\n",
148
+ " self.fc = nn.Linear(d_model, d_model)\n",
149
+ " self.activ = nn.Tanh()\n",
150
+ " self.linear = nn.Linear(d_model, d_model)\n",
151
+ " self.norm = nn.LayerNorm(d_model)\n",
152
+ " self.classifier = nn.Linear(d_model, 2)\n",
153
+ " embed_weight = self.embedding.tok_embed.weight\n",
154
+ " n_vocab, n_dim = embed_weight.size()\n",
155
+ " self.decoder = nn.Linear(n_dim, n_vocab, bias=False)\n",
156
+ " self.decoder.weight = embed_weight\n",
157
+ " self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))\n",
158
+ "\n",
159
+ " def forward(self, input_ids, segment_ids, masked_pos=None):\n",
160
+ " output = self.embedding(input_ids, segment_ids)\n",
161
+ " enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids)\n",
162
+ " for layer in self.layers:\n",
163
+ " output, enc_self_attn = layer(output, enc_self_attn_mask)\n",
164
+ " return None, None, output \n",
165
+ "\n",
166
+ "# Load Pretrained Weights\n",
167
+ "bert = BERT().to(device)\n",
168
+ "try:\n",
169
+ " bert.load_state_dict(torch.load('./models/bert_trained.pt', map_location=device))\n",
170
+ " print(\"Loaded bert_trained.pt\")\n",
171
+ "except:\n",
172
+ " print(\"Warning: bert_trained.pt not found. Using random weights.\")\n"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "markdown",
177
+ "metadata": {},
178
+ "source": [
179
+ "## 3. Load SNLI Dataset\n"
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "code",
184
+ "execution_count": 11,
185
+ "metadata": {},
186
+ "outputs": [
187
+ {
188
+ "name": "stdout",
189
+ "output_type": "stream",
190
+ "text": [
191
+ "Loading snli...\n"
192
+ ]
193
+ },
194
+ {
195
+ "name": "stderr",
196
+ "output_type": "stream",
197
+ "text": [
198
+ "Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.\n"
199
+ ]
200
+ },
201
+ {
202
+ "name": "stdout",
203
+ "output_type": "stream",
204
+ "text": [
205
+ "Loaded dataset keys: dict_keys(['test', 'validation', 'train'])\n",
206
+ "Train size: 9988, Test size: 988\n"
207
+ ]
208
+ }
209
+ ],
210
+ "source": [
211
+ "DATASET_NAME = 'snli'\n",
212
+ "print(f\"Loading {DATASET_NAME}...\")\n",
213
+ "dataset = load_dataset(DATASET_NAME)\n",
214
+ "print(f\"Loaded dataset keys: {dataset.keys()}\")\n",
215
+ "\n",
216
+ "train_dataset = dataset['train'].select(range(10000))\n",
217
+ "test_dataset = dataset['test'].select(range(1000))\n",
218
+ "\n",
219
+ "# Filter undefined labels\n",
220
+ "train_dataset = train_dataset.filter(lambda x: x['label'] != -1)\n",
221
+ "test_dataset = test_dataset.filter(lambda x: x['label'] != -1)\n",
222
+ "\n",
223
+ "print(f\"Train size: {len(train_dataset)}, Test size: {len(test_dataset)}\")"
224
+ ]
225
+ },
226
+ {
227
+ "cell_type": "code",
228
+ "execution_count": 12,
229
+ "metadata": {},
230
+ "outputs": [],
231
+ "source": [
232
+ "# Data Loader\n",
233
+ "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
234
+ "\n",
235
+ "class NLIDataset(Dataset):\n",
236
+ " def __init__(self, dataset, tokenizer, max_len=128):\n",
237
+ " self.dataset = dataset\n",
238
+ " self.tokenizer = tokenizer\n",
239
+ " self.max_len = max_len\n",
240
+ "\n",
241
+ " def __len__(self):\n",
242
+ " return len(self.dataset)\n",
243
+ "\n",
244
+ " def __getitem__(self, idx):\n",
245
+ " item = self.dataset[idx]\n",
246
+ " premise = item['premise']\n",
247
+ " hypothesis = item['hypothesis']\n",
248
+ " label = item['label']\n",
249
+ "\n",
250
+ " encoded_premise = self.tokenizer(\n",
251
+ " premise,\n",
252
+ " add_special_tokens=True,\n",
253
+ " max_length=self.max_len,\n",
254
+ " padding='max_length',\n",
255
+ " return_attention_mask=True,\n",
256
+ " truncation=True\n",
257
+ " )\n",
258
+ "\n",
259
+ " encoded_hypothesis = self.tokenizer(\n",
260
+ " hypothesis,\n",
261
+ " add_special_tokens=True,\n",
262
+ " max_length=self.max_len,\n",
263
+ " padding='max_length',\n",
264
+ " return_attention_mask=True,\n",
265
+ " truncation=True\n",
266
+ " )\n",
267
+ "\n",
268
+ " return {\n",
269
+ " 'premise_input_ids': torch.tensor(encoded_premise['input_ids'], dtype=torch.long),\n",
270
+ " 'premise_segment_ids': torch.tensor(encoded_premise['token_type_ids'], dtype=torch.long),\n",
271
+ " 'hypothesis_input_ids': torch.tensor(encoded_hypothesis['input_ids'], dtype=torch.long),\n",
272
+ " 'hypothesis_segment_ids': torch.tensor(encoded_hypothesis['token_type_ids'], dtype=torch.long),\n",
273
+ " 'label': torch.tensor(label, dtype=torch.long)\n",
274
+ " }\n",
275
+ "\n",
276
+ "train_loader = DataLoader(NLIDataset(train_dataset, tokenizer), batch_size=16, shuffle=True)\n",
277
+ "test_loader = DataLoader(NLIDataset(test_dataset, tokenizer), batch_size=16, shuffle=False)"
278
+ ]
279
+ },
280
+ {
281
+ "cell_type": "code",
282
+ "execution_count": 13,
283
+ "metadata": {},
284
+ "outputs": [],
285
+ "source": [
286
+ "# S-BERT Model\n",
287
+ "class SBERT(nn.Module):\n",
288
+ " def __init__(self, bert_model):\n",
289
+ " super(SBERT, self).__init__()\n",
290
+ " self.bert = bert_model\n",
291
+ " self.classifier = nn.Linear(d_model * 3, 3)\n",
292
+ "\n",
293
+ " def forward(self, premise_ids, premise_seg, hypothesis_ids, hypothesis_seg):\n",
294
+ " device = premise_ids.device\n",
295
+ " dummy_masked_pos = torch.zeros((premise_ids.size(0), 1), dtype=torch.long).to(device)\n",
296
+ " \n",
297
+ " _, _, output_u = self.bert(premise_ids, premise_seg, dummy_masked_pos)\n",
298
+ " mask_u = (premise_ids != 0).unsqueeze(-1).float()\n",
299
+ " u = torch.sum(output_u * mask_u, dim=1) / torch.clamp(mask_u.sum(dim=1), min=1e-9)\n",
300
+ "\n",
301
+ " _, _, output_v = self.bert(hypothesis_ids, hypothesis_seg, dummy_masked_pos)\n",
302
+ " mask_v = (hypothesis_ids != 0).unsqueeze(-1).float()\n",
303
+ " v = torch.sum(output_v * mask_v, dim=1) / torch.clamp(mask_v.sum(dim=1), min=1e-9)\n",
304
+ "\n",
305
+ " uv_abs = torch.abs(u - v)\n",
306
+ " features = torch.cat([u, v, uv_abs], dim=-1)\n",
307
+ " logits = self.classifier(features)\n",
308
+ " return logits\n",
309
+ "\n",
310
+ "sbert = SBERT(bert).to(device)\n",
311
+ "optimizer = optim.Adam(sbert.parameters(), lr=2e-5)\n",
312
+ "criterion = nn.CrossEntropyLoss()"
313
+ ]
314
+ },
315
+ {
316
+ "cell_type": "code",
317
+ "execution_count": 14,
318
+ "metadata": {},
319
+ "outputs": [
320
+ {
321
+ "name": "stdout",
322
+ "output_type": "stream",
323
+ "text": [
324
+ "Starting Training...\n",
325
+ "Epoch 1 Loss: 1.0723\n",
326
+ "Epoch 2 Loss: 1.0245\n",
327
+ "Epoch 3 Loss: 0.9913\n",
328
+ "Epoch 4 Loss: 0.9702\n",
329
+ "Epoch 5 Loss: 0.9487\n",
330
+ "Epoch 6 Loss: 0.9257\n",
331
+ "Epoch 7 Loss: 0.9018\n",
332
+ "Epoch 8 Loss: 0.8794\n",
333
+ "Epoch 9 Loss: 0.8573\n",
334
+ "Epoch 10 Loss: 0.8355\n",
335
+ "Epoch 11 Loss: 0.8129\n",
336
+ "Epoch 12 Loss: 0.7907\n",
337
+ "Epoch 13 Loss: 0.7694\n",
338
+ "Epoch 14 Loss: 0.7470\n",
339
+ "Epoch 15 Loss: 0.7227\n",
340
+ "Epoch 16 Loss: 0.7023\n",
341
+ "Epoch 17 Loss: 0.6780\n",
342
+ "Epoch 18 Loss: 0.6569\n",
343
+ "Epoch 19 Loss: 0.6336\n",
344
+ "Epoch 20 Loss: 0.6084\n",
345
+ "Epoch 21 Loss: 0.5883\n",
346
+ "Epoch 22 Loss: 0.5596\n",
347
+ "Epoch 23 Loss: 0.5356\n",
348
+ "Epoch 24 Loss: 0.5116\n",
349
+ "Epoch 25 Loss: 0.4880\n",
350
+ "Epoch 26 Loss: 0.4623\n",
351
+ "Epoch 27 Loss: 0.4392\n",
352
+ "Epoch 28 Loss: 0.4161\n",
353
+ "Epoch 29 Loss: 0.3903\n",
354
+ "Epoch 30 Loss: 0.3692\n",
355
+ "Epoch 31 Loss: 0.3509\n",
356
+ "Epoch 32 Loss: 0.3258\n",
357
+ "Epoch 33 Loss: 0.3048\n",
358
+ "Epoch 34 Loss: 0.2834\n",
359
+ "Epoch 35 Loss: 0.2664\n",
360
+ "Epoch 36 Loss: 0.2493\n",
361
+ "Epoch 37 Loss: 0.2327\n",
362
+ "Epoch 38 Loss: 0.2145\n",
363
+ "Epoch 39 Loss: 0.2049\n",
364
+ "Epoch 40 Loss: 0.1845\n",
365
+ "Epoch 41 Loss: 0.1687\n",
366
+ "Epoch 42 Loss: 0.1627\n",
367
+ "Epoch 43 Loss: 0.1548\n",
368
+ "Epoch 44 Loss: 0.1367\n",
369
+ "Epoch 45 Loss: 0.1268\n",
370
+ "Epoch 46 Loss: 0.1315\n",
371
+ "Epoch 47 Loss: 0.1230\n",
372
+ "Epoch 48 Loss: 0.1051\n",
373
+ "Epoch 49 Loss: 0.0964\n",
374
+ "Epoch 50 Loss: 0.1027\n",
375
+ "Epoch 51 Loss: 0.0983\n",
376
+ "Epoch 52 Loss: 0.0781\n",
377
+ "Epoch 53 Loss: 0.0795\n",
378
+ "Epoch 54 Loss: 0.0860\n",
379
+ "Epoch 55 Loss: 0.0800\n",
380
+ "Epoch 56 Loss: 0.0620\n",
381
+ "Epoch 57 Loss: 0.0905\n",
382
+ "Epoch 58 Loss: 0.0567\n",
383
+ "Epoch 59 Loss: 0.0568\n",
384
+ "Epoch 60 Loss: 0.0502\n",
385
+ "Epoch 61 Loss: 0.0808\n",
386
+ "Epoch 62 Loss: 0.0622\n",
387
+ "Epoch 63 Loss: 0.0445\n",
388
+ "Epoch 64 Loss: 0.0536\n",
389
+ "Epoch 65 Loss: 0.0564\n",
390
+ "Epoch 66 Loss: 0.0542\n",
391
+ "Epoch 67 Loss: 0.0537\n",
392
+ "Epoch 68 Loss: 0.0419\n",
393
+ "Epoch 69 Loss: 0.0648\n",
394
+ "Epoch 70 Loss: 0.0496\n",
395
+ "Epoch 71 Loss: 0.0510\n",
396
+ "Epoch 72 Loss: 0.0470\n",
397
+ "Epoch 73 Loss: 0.0446\n",
398
+ "Epoch 74 Loss: 0.0359\n",
399
+ "Epoch 75 Loss: 0.0533\n",
400
+ "Epoch 76 Loss: 0.0611\n",
401
+ "Epoch 77 Loss: 0.0368\n",
402
+ "Epoch 78 Loss: 0.0291\n",
403
+ "Epoch 79 Loss: 0.0321\n",
404
+ "Epoch 80 Loss: 0.0757\n",
405
+ "Epoch 81 Loss: 0.0546\n",
406
+ "Epoch 82 Loss: 0.0300\n",
407
+ "Epoch 83 Loss: 0.0279\n",
408
+ "Epoch 84 Loss: 0.0294\n",
409
+ "Epoch 85 Loss: 0.0542\n",
410
+ "Epoch 86 Loss: 0.0422\n",
411
+ "Epoch 87 Loss: 0.0353\n",
412
+ "Epoch 88 Loss: 0.0537\n",
413
+ "Epoch 89 Loss: 0.0300\n",
414
+ "Epoch 90 Loss: 0.0295\n",
415
+ "Epoch 91 Loss: 0.0422\n",
416
+ "Epoch 92 Loss: 0.0403\n",
417
+ "Epoch 93 Loss: 0.0225\n",
418
+ "Epoch 94 Loss: 0.0335\n",
419
+ "Epoch 95 Loss: 0.0457\n",
420
+ "Epoch 96 Loss: 0.0307\n",
421
+ "Epoch 97 Loss: 0.0253\n",
422
+ "Epoch 98 Loss: 0.0543\n",
423
+ "Epoch 99 Loss: 0.0302\n",
424
+ "Epoch 100 Loss: 0.0237\n",
425
+ "Epoch 101 Loss: 0.0344\n",
426
+ "Epoch 102 Loss: 0.0417\n",
427
+ "Epoch 103 Loss: 0.0227\n",
428
+ "Epoch 104 Loss: 0.0267\n",
429
+ "Epoch 105 Loss: 0.0431\n",
430
+ "Epoch 106 Loss: 0.0263\n",
431
+ "Epoch 107 Loss: 0.0442\n",
432
+ "Epoch 108 Loss: 0.0300\n",
433
+ "Epoch 109 Loss: 0.0215\n",
434
+ "Epoch 110 Loss: 0.0262\n",
435
+ "Epoch 111 Loss: 0.0485\n",
436
+ "Epoch 112 Loss: 0.0253\n",
437
+ "Epoch 113 Loss: 0.0202\n",
438
+ "Epoch 114 Loss: 0.0226\n",
439
+ "Epoch 115 Loss: 0.0355\n",
440
+ "Epoch 116 Loss: 0.0534\n",
441
+ "Epoch 117 Loss: 0.0210\n",
442
+ "Epoch 118 Loss: 0.0173\n",
443
+ "Epoch 119 Loss: 0.0315\n",
444
+ "Epoch 120 Loss: 0.0457\n",
445
+ "Epoch 121 Loss: 0.0209\n",
446
+ "Epoch 122 Loss: 0.0226\n",
447
+ "Epoch 123 Loss: 0.0325\n",
448
+ "Epoch 124 Loss: 0.0320\n",
449
+ "Epoch 125 Loss: 0.0269\n",
450
+ "Epoch 126 Loss: 0.0212\n",
451
+ "Epoch 127 Loss: 0.0213\n",
452
+ "Epoch 128 Loss: 0.0313\n",
453
+ "Epoch 129 Loss: 0.0376\n",
454
+ "Epoch 130 Loss: 0.0284\n",
455
+ "Epoch 131 Loss: 0.0177\n",
456
+ "Epoch 132 Loss: 0.0172\n",
457
+ "Epoch 133 Loss: 0.0234\n",
458
+ "Epoch 134 Loss: 0.0442\n",
459
+ "Epoch 135 Loss: 0.0222\n",
460
+ "Epoch 136 Loss: 0.0293\n",
461
+ "Epoch 137 Loss: 0.0258\n",
462
+ "Epoch 138 Loss: 0.0260\n",
463
+ "Epoch 139 Loss: 0.0220\n",
464
+ "Epoch 140 Loss: 0.0167\n",
465
+ "Epoch 141 Loss: 0.0395\n",
466
+ "Epoch 142 Loss: 0.0265\n",
467
+ "Epoch 143 Loss: 0.0179\n",
468
+ "Epoch 144 Loss: 0.0195\n",
469
+ "Epoch 145 Loss: 0.0318\n",
470
+ "Epoch 146 Loss: 0.0224\n",
471
+ "Epoch 147 Loss: 0.0160\n",
472
+ "Epoch 148 Loss: 0.0215\n",
473
+ "Epoch 149 Loss: 0.0491\n",
474
+ "Epoch 150 Loss: 0.0197\n",
475
+ "Epoch 151 Loss: 0.0203\n",
476
+ "Epoch 152 Loss: 0.0238\n",
477
+ "Epoch 153 Loss: 0.0260\n",
478
+ "Epoch 154 Loss: 0.0178\n",
479
+ "Epoch 155 Loss: 0.0156\n",
480
+ "Epoch 156 Loss: 0.0171\n",
481
+ "Epoch 157 Loss: 0.0243\n",
482
+ "Epoch 158 Loss: 0.0403\n",
483
+ "Epoch 159 Loss: 0.0180\n",
484
+ "Epoch 160 Loss: 0.0172\n",
485
+ "Epoch 161 Loss: 0.0198\n",
486
+ "Epoch 162 Loss: 0.0336\n",
487
+ "Epoch 163 Loss: 0.0222\n",
488
+ "Epoch 164 Loss: 0.0155\n",
489
+ "Epoch 165 Loss: 0.0193\n",
490
+ "Epoch 166 Loss: 0.0239\n",
491
+ "Epoch 167 Loss: 0.0183\n",
492
+ "Epoch 168 Loss: 0.0160\n",
493
+ "Epoch 169 Loss: 0.0182\n",
494
+ "Epoch 170 Loss: 0.0389\n",
495
+ "Epoch 171 Loss: 0.0229\n",
496
+ "Epoch 172 Loss: 0.0171\n",
497
+ "Epoch 173 Loss: 0.0162\n",
498
+ "Epoch 174 Loss: 0.0206\n",
499
+ "Epoch 175 Loss: 0.0159\n",
500
+ "Epoch 176 Loss: 0.0158\n",
501
+ "Epoch 177 Loss: 0.0361\n",
502
+ "Epoch 178 Loss: 0.0346\n",
503
+ "Epoch 179 Loss: 0.0183\n",
504
+ "Epoch 180 Loss: 0.0163\n",
505
+ "Epoch 181 Loss: 0.0132\n",
506
+ "Epoch 182 Loss: 0.0162\n",
507
+ "Epoch 183 Loss: 0.0160\n",
508
+ "Epoch 184 Loss: 0.0348\n",
509
+ "Epoch 185 Loss: 0.0271\n",
510
+ "Epoch 186 Loss: 0.0168\n",
511
+ "Epoch 187 Loss: 0.0129\n",
512
+ "Epoch 188 Loss: 0.0138\n",
513
+ "Epoch 189 Loss: 0.0161\n",
514
+ "Epoch 190 Loss: 0.0218\n",
515
+ "Epoch 191 Loss: 0.0254\n",
516
+ "Epoch 192 Loss: 0.0254\n",
517
+ "Epoch 193 Loss: 0.0127\n",
518
+ "Epoch 194 Loss: 0.0126\n",
519
+ "Epoch 195 Loss: 0.0151\n",
520
+ "Epoch 196 Loss: 0.0197\n",
521
+ "Epoch 197 Loss: 0.0283\n",
522
+ "Epoch 198 Loss: 0.0157\n",
523
+ "Epoch 199 Loss: 0.0134\n",
524
+ "Epoch 200 Loss: 0.0165\n",
525
+ "Done!\n"
526
+ ]
527
+ }
528
+ ],
529
+ "source": [
530
+ "# Train Loop\n",
531
+ "print(\"Starting Training...\")\n",
532
+ "epochs = 200\n",
533
+ "for epoch in range(epochs):\n",
534
+ " sbert.train()\n",
535
+ " total_loss = 0\n",
536
+ " for batch in train_loader:\n",
537
+ " p_ids = batch['premise_input_ids'].to(device)\n",
538
+ " p_seg = batch['premise_segment_ids'].to(device)\n",
539
+ " h_ids = batch['hypothesis_input_ids'].to(device)\n",
540
+ " h_seg = batch['hypothesis_segment_ids'].to(device)\n",
541
+ " labels = batch['label'].to(device)\n",
542
+ "\n",
543
+ " optimizer.zero_grad()\n",
544
+ " logits = sbert(p_ids, p_seg, h_ids, h_seg)\n",
545
+ " loss = criterion(logits, labels)\n",
546
+ " loss.backward()\n",
547
+ " optimizer.step()\n",
548
+ " total_loss += loss.item()\n",
549
+ " print(f\"Epoch {epoch+1} Loss: {total_loss/len(train_loader):.4f}\")\n",
550
+ "\n",
551
+ "print(\"Done!\")\n",
552
+ "torch.save(sbert.state_dict(), f'./models/sbert_{DATASET_NAME}.pt')"
553
+ ]
554
+ },
555
+ {
556
+ "cell_type": "markdown",
557
+ "metadata": {},
558
+ "source": [
559
+ "## 4. Evaluation\n",
560
+ "\n",
561
+ "Evaluate on the held-out test set."
562
+ ]
563
+ },
564
+ {
565
+ "cell_type": "code",
566
+ "execution_count": 15,
567
+ "metadata": {},
568
+ "outputs": [
569
+ {
570
+ "name": "stdout",
571
+ "output_type": "stream",
572
+ "text": [
573
+ "Evaluating...\n",
574
+ "Classification Report:\n",
575
+ " precision recall f1-score support\n",
576
+ "\n",
577
+ " Entailment 0.53 0.56 0.54 339\n",
578
+ " Neutral 0.51 0.44 0.47 324\n",
579
+ "Contradiction 0.45 0.49 0.47 325\n",
580
+ "\n",
581
+ " accuracy 0.49 988\n",
582
+ " macro avg 0.50 0.49 0.49 988\n",
583
+ " weighted avg 0.50 0.49 0.49 988\n",
584
+ "\n",
585
+ "Accuracy: 0.4949\n"
586
+ ]
587
+ }
588
+ ],
589
+ "source": [
590
+ "sbert.eval()\n",
591
+ "all_preds = []\n",
592
+ "all_labels = []\n",
593
+ "\n",
594
+ "print(\"Evaluating...\")\n",
595
+ "with torch.no_grad():\n",
596
+ " for batch in test_loader:\n",
597
+ " p_ids = batch['premise_input_ids'].to(device)\n",
598
+ " p_seg = batch['premise_segment_ids'].to(device)\n",
599
+ " h_ids = batch['hypothesis_input_ids'].to(device)\n",
600
+ " h_seg = batch['hypothesis_segment_ids'].to(device)\n",
601
+ " labels = batch['label'].to(device)\n",
602
+ "\n",
603
+ " logits = sbert(p_ids, p_seg, h_ids, h_seg)\n",
604
+ " preds = torch.argmax(logits, dim=1)\n",
605
+ " \n",
606
+ " all_preds.extend(preds.cpu().numpy())\n",
607
+ " all_labels.extend(labels.cpu().numpy())\n",
608
+ "\n",
609
+ "target_names = ['Entailment', 'Neutral', 'Contradiction']\n",
610
+ "print(\"Classification Report:\")\n",
611
+ "print(classification_report(all_labels, all_preds, labels=[0, 1, 2], target_names=target_names))\n",
612
+ "print(f\"Accuracy: {accuracy_score(all_labels, all_preds):.4f}\")"
613
+ ]
614
+ }
615
+ ],
616
+ "metadata": {
617
+ "kernelspec": {
618
+ "display_name": "Python 3",
619
+ "language": "python",
620
+ "name": "python3"
621
+ },
622
+ "language_info": {
623
+ "codemirror_mode": {
624
+ "name": "ipython",
625
+ "version": 3
626
+ },
627
+ "file_extension": ".py",
628
+ "mimetype": "text/x-python",
629
+ "name": "python",
630
+ "nbconvert_exporter": "python",
631
+ "pygments_lexer": "ipython3",
632
+ "version": "3.8.5"
633
+ }
634
+ },
635
+ "nbformat": 4,
636
+ "nbformat_minor": 4
637
+ }
Dockerfile ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9
2
+
3
+ WORKDIR /code
4
+
5
+ COPY ./requirements.txt /code/requirements.txt
6
+
7
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
8
+
9
+ COPY . .
10
+
11
+ CMD ["gunicorn", "-b", "0.0.0.0:7860", "app.app:app"]
README.md CHANGED
@@ -1,10 +1,87 @@
1
  ---
2
- title: A4 NLI App
3
- emoji: 🦀
4
- colorFrom: yellow
5
- colorTo: blue
6
  sdk: docker
7
- pinned: false
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: NLI Text Similarity App
3
+ emoji: 🧠
4
+ colorFrom: blue
5
+ colorTo: purple
6
  sdk: docker
7
+ app_port: 8000
8
  ---
9
 
10
+ # NLI Text Similarity App (A4 Assignment)
11
+
12
+ **Name:** HTUT KO KO
13
+ **ID:** st126010
14
+
15
+ I implemented a Mini-BERT model from scratch and fine-tuned it as a Sentence-BERT (S-BERT) model for Natural Language Inference (NLI) tasks. This project includes a modern web application for real-time similarity analysis.
16
+
17
+ ## Project Structure
18
+
19
+ - `A4_BERT.ipynb`: Task 1 - I pre-trained BERT from scratch on WikiText-103.
20
+ - `A4_Climate_FEVER.ipynb`: Task 2 - I fine-tuned S-BERT on the Climate-FEVER dataset.
21
+ - `A4_Option_SNLI.ipynb`: Alternative training notebook where I trained on the SNLI dataset.
22
+ - `A4_Option_MNLI.ipynb`: Focused notebook where I trained on the MNLI dataset.
23
+ - `app/`: My Flask web application components.
24
+ - `models/`: Saved model weights (`bert_trained.pt`, `sbert_climate_fever.pt`, `sbert_snli.pt`, `sbert_mnli.pt`).
25
+
26
+ ## Final Results
27
+
28
+ I trained the models on three different datasets. Here are the results I achieved:
29
+
30
+ | Dataset | Epochs | Accuracy | Loss |
31
+ | :--- | :--- | :--- | :--- |
32
+ | **Climate-FEVER** | 200 | **50.5%** | 0.0000 |
33
+ | **SNLI** | 200 | **50.8%** | ~0.59 |
34
+ | **MNLI** | 100 | **41.5%** | 0.0000 |
35
+
36
+ ### Detailed Climate-FEVER Metrics
37
+
38
+ | Class | Precision | Recall | F1-Score |
39
+ | :--- | :--- | :--- | :--- |
40
+ | Entailment | 0.33 | 0.28 | 0.30 |
41
+ | Neutral | 0.62 | 0.68 | 0.65 |
42
+ | Contradiction | 0.40 | 0.55 | 0.46 |
43
+
44
+ ## Limitations & Analysis
45
+
46
+ ### 1. Vocabulary Size
47
+ I limited the vocabulary size to **5004** (compared to standard BERT's 30,522) to ensure the model could be trained efficiently on the smaller WikiText-103 subset. While this improved convergence for this assignment, it restricts the model's ability to understand rare words outside this vocabulary.
48
+
49
+ ### 2. Tokenizer Mismatch
50
+ A challenge I encountered was using the standard `BertTokenizer` with my custom Mini-BERT. The tokenizer produces IDs > 5004, which caused `IndexError` in the web app. I resolved this by implementing a clamping mechanism in `app.py` to map unknown tokens to the `[UNK]` ID.
51
+
52
+ ### 3. Model Depth
53
+ I used a "Mini-BERT" configuration (`n_layers=2`, `d_model=256`) instead of the base (`n_layers=12`, `d_model=768`). This trade-off significantly reduced training time but naturally limits the model's capacity to capture complex linguistic nuances compared to the full BERT-Base.
54
+
55
+ ## Demonstration
56
+
57
+ ![WebUI](demo.gif)
58
+
59
+ ## How to Run
60
+
61
+ ### 1. Setup Environment
62
+
63
+ ```bash
64
+ pip install -r requirements.txt
65
+ ```
66
+
67
+ ### 2. Run the Web App
68
+
69
+ ```bash
70
+ python app/app.py
71
+ ```
72
+
73
+ Access the app at `http://127.0.0.1:8000`.
74
+
75
+ ## Features
76
+
77
+ - **Modern UI**: I designed a Glassmorphism theme with a dynamic background.
78
+ - **Multi-Model Support**: Users can select between Climate-FEVER, SNLI, or MNLI trained models.
79
+ - **Explainable AI**: The app displays the probability distribution for each prediction.
80
+
81
+ ## References
82
+
83
+ 1. **BERT / WikiText-103**: Merity, S., Xiong, C., Bradbury, J., & Socher, R. (2016). *Pointer Sentinel Mixture Models*.
84
+ 2. **S-BERT (Sentence-BERT)**: Reimers, N., & Gurevych, I. (2019). *Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks*.
85
+ 3. **Climate-FEVER**: Diggelmann, T., Boyd-Graber, J., Bulian, J., Ciaramita, M., & Leippold, M. (2020). *CLIMATE-FEVER: A Dataset for Verification of Real-World Climate Claims*.
86
+ 4. **SNLI**: Bowman, S. R., Angeli, G., Potts, C., & Manning, C. D. (2015). *A large annotated corpus for learning natural language inference*.
87
+ 5. **MNLI**: Williams, A., Nangia, N., & Bowman, S. R. (2018). *A Broad-Coverage Challenge Corpus for Sentence Understanding through Inference*.
app/app.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from flask import Flask, render_template, request, jsonify
5
+ from transformers import BertTokenizer
6
+ import os
7
+ import math
8
+ import numpy as np
9
+
10
+ app = Flask(__name__)
11
+
12
+ # --- Configuration ---
13
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ MODEL_PATH = "../models/sbert_climate_fever.pt" # Path relative to app/ directory execution usually
15
+ # But we will run from project root or handle paths carefully
16
+ # Let's assume running from project root: python app/app.py
17
+ # Then path is models/sbert_climate_fever.pt
18
+ MODEL_PATH_REL = "models/sbert_climate_fever.pt"
19
+
20
+ # --- Model Definitions (Must match training Code) ---
21
+ # Copied from A4_Solution.ipynb
22
+
23
+ n_layers = 2
24
+ n_heads = 4
25
+ d_model = 256
26
+ d_ff = 256 * 4
27
+ d_k = d_v = 64
28
+ n_segments = 2
29
+ max_len = 128
30
+ vocab_size = 5004 # Custom vocab size from training
31
+
32
+ class Embedding(nn.Module):
33
+ def __init__(self):
34
+ super(Embedding, self).__init__()
35
+ self.tok_embed = nn.Embedding(vocab_size, d_model) # token embedding
36
+ self.pos_embed = nn.Embedding(max_len, d_model) # position embedding
37
+ self.seg_embed = nn.Embedding(n_segments, d_model) # segment(token type) embedding
38
+ self.norm = nn.LayerNorm(d_model)
39
+ # Initialize weights to avoid large initial loss
40
+ self.tok_embed.weight.data.normal_(0, 0.1)
41
+ self.pos_embed.weight.data.normal_(0, 0.1)
42
+ self.seg_embed.weight.data.normal_(0, 0.1)
43
+
44
+ def forward(self, x, seg):
45
+ seq_len = x.size(1)
46
+ pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
47
+ pos = pos.unsqueeze(0).expand_as(x) # (len,) -> (bs, len)
48
+ embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
49
+ return self.norm(embedding)
50
+
51
+ def get_attn_pad_mask(seq_q, seq_k):
52
+ batch_size, len_q = seq_q.size()
53
+ batch_size, len_k = seq_k.size()
54
+ # eq(zero) is PAD token
55
+ pad_attn_mask = seq_k.data.eq(0).unsqueeze(1) # batch_size x 1 x len_k(=len_q), one is masking
56
+ return pad_attn_mask.expand(batch_size, len_q, len_k) # batch_size x len_q x len_k
57
+
58
+ class ScaledDotProductAttention(nn.Module):
59
+ def __init__(self):
60
+ super(ScaledDotProductAttention, self).__init__()
61
+
62
+ def forward(self, Q, K, V, attn_mask):
63
+ scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k) # scores : [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
64
+ scores.masked_fill_(attn_mask, -1e9) # Fills elements of self tensor with value where mask is one.
65
+ attn = nn.Softmax(dim=-1)(scores)
66
+ context = torch.matmul(attn, V)
67
+ return context, attn
68
+
69
+ class MultiHeadAttention(nn.Module):
70
+ def __init__(self):
71
+ super(MultiHeadAttention, self).__init__()
72
+ self.W_Q = nn.Linear(d_model, d_k * n_heads)
73
+ self.W_K = nn.Linear(d_model, d_k * n_heads)
74
+ self.W_V = nn.Linear(d_model, d_v * n_heads)
75
+ self.linear = nn.Linear(n_heads * d_v, d_model) # Defined in init
76
+ self.layer_norm = nn.LayerNorm(d_model) # Defined in init
77
+
78
+ def forward(self, Q, K, V, attn_mask):
79
+ # q: [batch_size x len_q x d_model], k: [batch_size x len_k x d_model], v: [batch_size x len_k x d_model]
80
+ residual, batch_size = Q, Q.size(0)
81
+ # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
82
+ q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1,2) # q_s: [batch_size x n_heads x len_q x d_k]
83
+ k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1,2) # k_s: [batch_size x n_heads x len_k x d_k]
84
+ v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1,2) # v_s: [batch_size x n_heads x len_k x d_v]
85
+
86
+ attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1) # attn_mask : [batch_size x n_heads x len_q x len_k]
87
+
88
+ # context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
89
+ context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)
90
+ context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v) # context: [batch_size x len_q x n_heads * d_v]
91
+ output = self.linear(context)
92
+ return self.layer_norm(output + residual), attn # output: [batch_size x len_q x d_model]
93
+
94
+ class PoswiseFeedForwardNet(nn.Module):
95
+ def __init__(self):
96
+ super(PoswiseFeedForwardNet, self).__init__()
97
+ self.fc1 = nn.Linear(d_model, d_ff)
98
+ self.fc2 = nn.Linear(d_ff, d_model)
99
+
100
+ def forward(self, x):
101
+ # (batch_size, len_seq, d_model) -> (batch_size, len_seq, d_ff) -> (batch_size, len_seq, d_model)
102
+ return self.fc2(F.gelu(self.fc1(x)))
103
+
104
+ class EncoderLayer(nn.Module):
105
+ def __init__(self):
106
+ super(EncoderLayer, self).__init__()
107
+ self.enc_self_attn = MultiHeadAttention()
108
+ self.pos_ffn = PoswiseFeedForwardNet()
109
+
110
+ def forward(self, enc_inputs, enc_self_attn_mask):
111
+ enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_inputs to same Q,K,V
112
+ enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size x len_q x d_model]
113
+ return enc_outputs, attn
114
+
115
+ class BERT(nn.Module):
116
+ def __init__(self):
117
+ super(BERT, self).__init__()
118
+ self.embedding = Embedding()
119
+ self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
120
+ self.fc = nn.Linear(d_model, d_model)
121
+ self.activ = nn.Tanh()
122
+ self.linear = nn.Linear(d_model, d_model)
123
+ self.norm = nn.LayerNorm(d_model)
124
+
125
+ self.classifier = nn.Linear(d_model, 2)
126
+ # decoder is shared with embedding layer
127
+ embed_weight = self.embedding.tok_embed.weight
128
+ n_vocab, n_dim = embed_weight.size()
129
+ self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
130
+ self.decoder.weight = embed_weight
131
+ self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))
132
+
133
+ def forward(self, input_ids, segment_ids, masked_pos=None):
134
+ # NOTE: masked_pos is optional here because for S-BERT we only need 'output'
135
+ # But to be consistent with NLI/Notebook forward pass, we handle it if provided
136
+ # or just run through.
137
+
138
+ output = self.embedding(input_ids, segment_ids)
139
+ enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids)
140
+ for layer in self.layers:
141
+ output, enc_self_attn = layer(output, enc_self_attn_mask)
142
+ # output : [batch_size, len, d_model], attn : [batch_size, n_heads, d_mode, d_model]
143
+
144
+ # 1. predict next sentence
145
+ # it will be decided by first token(CLS)
146
+ h_pooled = self.activ(self.fc(output[:, 0])) # [batch_size, d_model]
147
+ logits_nsp = self.classifier(h_pooled) # [batch_size, 2]
148
+
149
+ # 2. predict the masked token
150
+ if masked_pos is not None:
151
+ masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1)) # [batch_size, max_pred, d_model]
152
+ h_masked = torch.gather(output, 1, masked_pos) # masking position [batch_size, max_pred, d_model]
153
+ h_masked = self.norm(F.gelu(self.linear(h_masked)))
154
+ logits_lm = self.decoder(h_masked) + self.decoder_bias # [batch_size, max_pred, n_vocab]
155
+ return logits_lm, logits_nsp, output
156
+ else:
157
+ return None, logits_nsp, output # S-BERT inference only needs output
158
+
159
+ class SBERT(nn.Module):
160
+ def __init__(self, bert_model):
161
+ super(SBERT, self).__init__()
162
+ self.bert = bert_model
163
+ # 3 * d_model because we concat u, v, |u-v|
164
+ self.classifier = nn.Linear(d_model * 3, 3)
165
+
166
+ def forward(self, premise_ids, premise_seg, hypothesis_ids, hypothesis_seg):
167
+ # Make dummy masked_pos for BERT forward (it's not used for encoding really, but input requires it)
168
+ # Creating a dummy masked_pos of shape [batch_size, 1] filled with 0
169
+ dummy_masked_pos = torch.zeros((premise_ids.size(0), 1), dtype=torch.long).to(premise_ids.device)
170
+
171
+ # Encode Premise (u)
172
+ _, _, output_u = self.bert(premise_ids, premise_seg, dummy_masked_pos)
173
+ # Mean Pooling
174
+ mask_u = (premise_ids != 0).unsqueeze(-1).float() # [batch, len, 1]
175
+ u = torch.sum(output_u * mask_u, dim=1) / torch.clamp(mask_u.sum(dim=1), min=1e-9)
176
+
177
+ # Encode Hypothesis (v)
178
+ _, _, output_v = self.bert(hypothesis_ids, hypothesis_seg, dummy_masked_pos)
179
+ mask_v = (hypothesis_ids != 0).unsqueeze(-1).float()
180
+ v = torch.sum(output_v * mask_v, dim=1) / torch.clamp(mask_v.sum(dim=1), min=1e-9)
181
+
182
+ # Classifier: concatenate u, v, |u-v|
183
+ uv_abs = torch.abs(u - v)
184
+ features = torch.cat([u, v, uv_abs], dim=-1)
185
+
186
+ logits = self.classifier(features)
187
+ return logits, u, v # returning u, v for cosine sim later if needed
188
+
189
+ # --- Model Management ---
190
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
191
+ models = {}
192
+ MODEL_FILES = {
193
+ 'climate_fever': 'models/sbert_climate_fever.pt',
194
+ 'snli': 'models/sbert_snli.pt',
195
+ 'mnli': 'models/sbert_mnli.pt'
196
+ }
197
+
198
+ def get_model(model_name):
199
+ # Handle custom input scenario
200
+ if model_name == 'custom':
201
+ model_name = 'climate_fever'
202
+
203
+ # Load model on demand or return cached
204
+ if model_name in models:
205
+ return models[model_name]
206
+
207
+ # Check if model name is known
208
+ if model_name not in MODEL_FILES:
209
+ print(f"Warning: Unknown model name '{model_name}'.")
210
+ return None
211
+
212
+ rel_path = MODEL_FILES[model_name]
213
+ path = f"../{rel_path}"
214
+
215
+ if not os.path.exists(path):
216
+ # Fallback to local path if running from app folder
217
+ path = rel_path
218
+
219
+ if not os.path.exists(path):
220
+ print(f"Model file not found at {path}")
221
+ return None
222
+
223
+ print(f"Loading {model_name} from {path}...")
224
+ try:
225
+ bert = BERT()
226
+ model = SBERT(bert)
227
+ state_dict = torch.load(path, map_location=DEVICE)
228
+ model.load_state_dict(state_dict, strict=False)
229
+ model.to(DEVICE)
230
+ model.eval()
231
+ models[model_name] = model
232
+ return model
233
+ except Exception as e:
234
+ print(f"Failed to load {model_name}: {e}")
235
+ return None
236
+
237
+ # Pre-load default
238
+ get_model('climate_fever')
239
+
240
+ @app.route('/')
241
+ def home():
242
+ return render_template('index.html')
243
+
244
+ @app.route('/predict', methods=['POST'])
245
+ def predict():
246
+ data = request.json
247
+ sentence1 = data.get('sentence1', '')
248
+ sentence2 = data.get('sentence2', '')
249
+ model_type = data.get('model_type', 'climate_fever') # Default
250
+
251
+ if not sentence1 or not sentence2:
252
+ return jsonify({'error': 'Both sentences are required'}), 400
253
+
254
+ # Get specific model
255
+ model = get_model(model_type)
256
+
257
+ if model is None:
258
+ # Fallback to whatever is loaded or error
259
+ if models:
260
+ model = list(models.values())[0]
261
+ print(f"Warning: Requested {model_type} not found, using fallback.")
262
+ else:
263
+ return jsonify({'error': f'Model {model_type} not trained/found. Please train it first!'}), 404
264
+
265
+ # Tokenize
266
+ # Tokenize
267
+ inputs_a = tokenizer(sentence1, max_length=128, truncation=True, padding='max_length')
268
+ inputs_b = tokenizer(sentence2, max_length=128, truncation=True, padding='max_length')
269
+
270
+ p_ids = torch.tensor(inputs_a['input_ids']).unsqueeze(0).to(DEVICE)
271
+ p_seg = torch.tensor(inputs_a['token_type_ids']).unsqueeze(0).to(DEVICE)
272
+ h_ids = torch.tensor(inputs_b['input_ids']).unsqueeze(0).to(DEVICE)
273
+ h_seg = torch.tensor(inputs_b['token_type_ids']).unsqueeze(0).to(DEVICE)
274
+
275
+ # Clamp inputs to vocab size (handle OOV from standard tokenizer)
276
+ p_ids[p_ids >= vocab_size] = 1 # [UNK]
277
+ h_ids[h_ids >= vocab_size] = 1 # [UNK]
278
+
279
+ with torch.no_grad():
280
+ logits, u, v = model(p_ids, p_seg, h_ids, h_seg)
281
+ probs = F.softmax(logits, dim=1).cpu().numpy()[0]
282
+
283
+ # Labels: entailment, neutral, contradiction
284
+ # Note: SNLI/MNLI/Climate-FEVER generally follow Entailment(0), Neutral(1), Contradiction(2)
285
+ # BUT check the mapping in notebooks.
286
+ # Climate-Fever: 0:Supports(Entailment), 1:Refutes(Contradiction), 2:NEI(Neutral) -> Re-mapped in NB to 0, 2, 1?
287
+ # Let's check training NB mapping.
288
+ # In A4_Climate_FEVER.ipynb: label_map = {0: 0, 1: 2, 2: 1} -> 0:Entailment, 2:Contradiction, 1:Neutral
289
+ # In standard SNLI/MNLI: 0:Entailment, 1:Neutral, 2:Contradiction
290
+
291
+ # We need to map probs correctly based on model_type
292
+ if model_type == 'climate_fever':
293
+ # Trained with: 0=Entailment, 1=Neutral, 2=Contradiction (Based on my previous fix? Wait, check mapping in NB)
294
+ # NB: label_map = {0: 0, 1: 2, 2: 1} -> This means original 0->0, 1->2, 2->1.
295
+ # So Model Outputs: Class 0=Entailment, Class 1=Neutral, Class 2=Contradiction
296
+ # Wait, if map is {0:0, 1:2, 2:1}, then Evidence Label 0 (Supports) -> Class 0
297
+ # Evidence Label 1 (Refutes) -> Class 2
298
+ # Evidence Label 2 (NEI) -> Class 1
299
+ # So Class indices: 0=Entailment, 1=Neutral, 2=Contradiction.
300
+ labels = ['Entailment', 'Neutral', 'Contradiction']
301
+ else:
302
+ # SNLI/MNLI standard: 0=Entailment, 1=Neutral, 2=Contradiction
303
+ labels = ['Entailment', 'Neutral', 'Contradiction']
304
+
305
+ # Result dict
306
+ result = {label: float(prob) for label, prob in zip(labels, probs)}
307
+ prediction = labels[np.argmax(probs)]
308
+
309
+ return jsonify({
310
+ 'prediction': prediction,
311
+ 'probabilities': result,
312
+ 'used_model': model_type
313
+ })
314
+
315
+ if __name__ == '__main__':
316
+ app.run(debug=True, port=8000)
app/static/style.css ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ :root {
2
+ --primary: #6366f1;
3
+ --primary-hover: #4f46e5;
4
+ --bg-color: #0f172a;
5
+ --card-bg: rgba(30, 41, 59, 0.7);
6
+ --text-color: #f8fafc;
7
+ --text-muted: #94a3b8;
8
+ --border-color: rgba(255, 255, 255, 0.1);
9
+ }
10
+
11
+ * {
12
+ box-sizing: border-box;
13
+ margin: 0;
14
+ padding: 0;
15
+ }
16
+
17
+ body {
18
+ font-family: 'Inter', sans-serif;
19
+ background-color: var(--bg-color);
20
+ color: var(--text-color);
21
+ min-height: 100vh;
22
+ display: flex;
23
+ justify-content: center;
24
+ align-items: center;
25
+ overflow-x: hidden;
26
+ position: relative;
27
+ }
28
+
29
+ /* Ambient Background Effect */
30
+ .background-orb {
31
+ position: fixed;
32
+ top: -20%;
33
+ left: -10%;
34
+ width: 50vw;
35
+ height: 50vw;
36
+ background: radial-gradient(circle, rgba(99, 102, 241, 0.3) 0%, rgba(15, 23, 42, 0) 70%);
37
+ border-radius: 50%;
38
+ z-index: -1;
39
+ animation: float 10s infinite ease-in-out;
40
+ }
41
+
42
+ @keyframes float {
43
+ 0%, 100% { transform: translate(0, 0); }
44
+ 50% { transform: translate(20px, 30px); }
45
+ }
46
+
47
+ .container {
48
+ width: 100%;
49
+ max-width: 800px;
50
+ padding: 2rem;
51
+ }
52
+
53
+ header {
54
+ text-align: center;
55
+ margin-bottom: 3rem;
56
+ }
57
+
58
+ header h1 {
59
+ font-size: 2.5rem;
60
+ font-weight: 700;
61
+ background: linear-gradient(135deg, #818cf8, #c084fc);
62
+ -webkit-background-clip: text;
63
+ -webkit-text-fill-color: transparent;
64
+ margin-bottom: 0.5rem;
65
+ }
66
+
67
+ .subtitle {
68
+ color: var(--text-muted);
69
+ font-size: 1.1rem;
70
+ }
71
+
72
+ .explanation-box {
73
+ margin-top: 1.5rem;
74
+ background-color: rgba(51, 65, 85, 0.5);
75
+ border: 1px solid var(--border-color);
76
+ padding: 1rem;
77
+ border-radius: 0.75rem;
78
+ text-align: left;
79
+ font-size: 0.9rem;
80
+ color: var(--text-muted);
81
+ }
82
+ .explanation-box h3 {
83
+ color: var(--text-color);
84
+ margin-bottom: 0.5rem;
85
+ font-size: 1rem;
86
+ }
87
+
88
+ main {
89
+ background: var(--card-bg);
90
+ backdrop-filter: blur(12px);
91
+ -webkit-backdrop-filter: blur(12px);
92
+ border: 1px solid var(--border-color);
93
+ border-radius: 1.5rem;
94
+ padding: 2rem;
95
+ box-shadow: 0 25px 50px -12px rgba(0, 0, 0, 0.5);
96
+ }
97
+
98
+ .control-panel {
99
+ margin-bottom: 1.5rem;
100
+ }
101
+ .control-panel label {
102
+ display: block;
103
+ margin-bottom: 0.5rem;
104
+ color: var(--text-muted);
105
+ font-size: 0.9rem;
106
+ }
107
+ .control-panel select {
108
+ width: 100%;
109
+ padding: 0.75rem;
110
+ border-radius: 0.5rem;
111
+ border: 1px solid var(--border-color);
112
+ background-color: rgba(15, 23, 42, 0.5);
113
+ color: var(--text-color);
114
+ font-size: 1rem;
115
+ outline: none;
116
+ cursor: pointer;
117
+ transition: border-color 0.2s;
118
+ }
119
+ .control-panel select:focus {
120
+ border-color: var(--primary);
121
+ }
122
+
123
+ .input-group {
124
+ display: grid;
125
+ gap: 1.5rem;
126
+ margin-bottom: 2rem;
127
+ }
128
+
129
+ .input-card label {
130
+ display: block;
131
+ margin-bottom: 0.5rem;
132
+ font-weight: 600;
133
+ color: var(--text-muted);
134
+ }
135
+
136
+ .input-wrapper input {
137
+ width: 100%;
138
+ padding: 1rem;
139
+ border-radius: 0.75rem;
140
+ border: 1px solid var(--border-color);
141
+ background-color: rgba(15, 23, 42, 0.5);
142
+ color: var(--text-color);
143
+ font-size: 1rem;
144
+ transition: all 0.2s;
145
+ }
146
+
147
+ .input-wrapper input:focus {
148
+ outline: none;
149
+ border-color: var(--primary);
150
+ box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.2);
151
+ }
152
+
153
+ button#analyze-btn {
154
+ width: 100%;
155
+ padding: 1rem;
156
+ border: none;
157
+ border-radius: 0.75rem;
158
+ background-color: var(--primary);
159
+ color: white;
160
+ font-size: 1.1rem;
161
+ font-weight: 600;
162
+ cursor: pointer;
163
+ transition: background-color 0.2s, transform 0.1s;
164
+ }
165
+
166
+ button#analyze-btn:hover {
167
+ background-color: var(--primary-hover);
168
+ }
169
+ button#analyze-btn:active {
170
+ transform: scale(0.98);
171
+ }
172
+
173
+ .result-card {
174
+ margin-top: 2rem;
175
+ padding-top: 2rem;
176
+ border-top: 1px solid var(--border-color);
177
+ animation: slideUp 0.3s ease-out;
178
+ }
179
+ .hidden {
180
+ display: none;
181
+ }
182
+ @keyframes slideUp {
183
+ from { opacity: 0; transform: translateY(10px); }
184
+ to { opacity: 1; transform: translateY(0); }
185
+ }
186
+
187
+ .prediction-header {
188
+ text-align: center;
189
+ font-size: 1.5rem;
190
+ font-weight: 700;
191
+ margin-bottom: 1.5rem;
192
+ }
193
+
194
+ .prob-bar {
195
+ display: flex;
196
+ align-items: center;
197
+ margin-bottom: 0.75rem;
198
+ gap: 1rem;
199
+ }
200
+
201
+ .prob-label {
202
+ width: 100px;
203
+ font-size: 0.9rem;
204
+ text-align: right;
205
+ color: var(--text-muted);
206
+ }
207
+
208
+ .bar-container {
209
+ flex-grow: 1;
210
+ height: 10px;
211
+ background-color: rgba(255, 255, 255, 0.1);
212
+ border-radius: 10px;
213
+ overflow: hidden;
214
+ }
215
+
216
+ .bar {
217
+ height: 100%;
218
+ background-color: var(--primary);
219
+ border-radius: 10px;
220
+ transition: width 0.5s ease-out;
221
+ }
222
+
223
+ .prob-val {
224
+ width: 50px;
225
+ font-size: 0.9rem;
226
+ font-weight: 600;
227
+ }
228
+
229
+ footer {
230
+ text-align: center;
231
+ margin-top: 3rem;
232
+ color: var(--text-muted);
233
+ font-size: 0.9rem;
234
+ }
app/templates/index.html ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>S-BERT Semantic Similarity Analysis</title>
    <!-- Stylesheet is served by Flask from the static folder -->
    <link rel="stylesheet" href="{{ url_for('static', filename='style.css') }}">
    <link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;600&display=swap" rel="stylesheet">
</head>
<body>
    <div class="background-orb"></div>
    <div class="container">
        <header>
            <h1>Semantic Textual Similarity</h1>
            <p class="subtitle">Analyze the relationship between Scientific Claims and Evidence</p>

            <div class="explanation-box">
                <h3>What is this?</h3>
                <p>
                    This tool uses a <strong>Sentence-BERT (S-BERT)</strong> model trained on the <em>Climate-FEVER</em> dataset
                    to determine if a piece of <strong>Evidence</strong> supports, contradicts, or is neutral towards a specific <strong>Claim</strong>.
                    Select a dataset scenario below to populate examples, or type your own sentences to test the model's understanding of semantic logic.
                </p>
            </div>
        </header>

        <main>
            <!-- Scenario selector: option values double as the model_type sent to /predict -->
            <div class="control-panel">
                <label for="dataset-select">Choose a Dataset / Scenario:</label>
                <select id="dataset-select" onchange="loadScenario()">
                    <option value="climate_fever" selected>Climate-FEVER (Science/Climate Only)</option>
                    <option value="snli">SNLI (General Knowledge / Logic)</option>
                    <option value="mnli">MNLI (Multi-Genre)</option>
                    <option value="custom">Custom Input</option>
                </select>
                <p id="scenario-hint" style="font-size: 0.8rem; color: #94a3b8; margin-top: 0.2rem;">
                    <em>Tip: 'The Earth is flat' is general logic -> Use <strong>SNLI</strong>. Climate-FEVER is for climate-specific claims.</em>
                </p>
            </div>

            <div class="input-group">
                <div class="input-card">
                    <label for="sentence1">Claim / Sentence 1</label>
                    <div class="input-wrapper">
                        <input type="text" id="sentence1" list="claims-list" placeholder="Enter a claim or sentence...">
                        <datalist id="claims-list">
                            <!-- Populated by JS -->
                        </datalist>
                    </div>
                </div>

                <div class="input-card">
                    <label for="sentence2">Evidence / Sentence 2</label>
                    <div class="input-wrapper">
                        <input type="text" id="sentence2" list="evidence-list" placeholder="Enter evidence or sentence...">
                        <datalist id="evidence-list">
                            <!-- Populated by JS -->
                        </datalist>
                    </div>
                </div>
            </div>

            <button id="analyze-btn" onclick="predict()">Analyze Similarity</button>

            <!-- Hidden until a prediction arrives; bars/labels are filled in by predict() -->
            <div id="result" class="result-card hidden">
                <div class="prediction-header">
                    <span id="prediction-label">Entailment</span>
                </div>
                <div class="probabilities">
                    <div class="prob-bar">
                        <span class="prob-label">Entailment</span>
                        <div class="bar-container"><div class="bar" id="bar-entailment" style="width: 0%"></div></div>
                        <span class="prob-val" id="val-entailment">0%</span>
                    </div>
                    <div class="prob-bar">
                        <span class="prob-label">Neutral</span>
                        <div class="bar-container"><div class="bar" id="bar-neutral" style="width: 0%"></div></div>
                        <span class="prob-val" id="val-neutral">0%</span>
                    </div>
                    <div class="prob-bar">
                        <span class="prob-label">Contradiction</span>
                        <div class="bar-container"><div class="bar" id="bar-contradiction" style="width: 0%"></div></div>
                        <span class="prob-val" id="val-contradiction">0%</span>
                    </div>
                </div>
            </div>
        </main>

        <footer>
            <p>Developed by <strong>Htut Ko Ko (st126010)</strong> | A4 Assignment</p>
        </footer>
    </div>
93
+
94
+ <script>
95
// Example claim/evidence pairs for each dataset the demo supports.
// Keys match the <option> values of #dataset-select; the "custom"
// option deliberately has no entry (it means free-form input).
const scenarios = {
    climate_fever: {
        claims: [
            "Global warming is caused by human activities.",
            "Sea levels are rising due to melting ice caps.",
            "The sun is the primary driver of recent climate change."
        ],
        evidence: [
            "The IPCC report confirms that human influence has warmed the atmosphere, ocean and land.",
            "Satellite data shows a steady increase in global sea levels over the past century.",
            "Solar irradiance has remained relatively stable while temperatures have soared."
        ]
    },
    snli: {
        claims: [
            "A soccer player is running across the field.",
            "A person is inspecting the tires of a bicycle.",
            "Two men are playing basketball."
        ],
        evidence: [
            "A person is moving fast on a grass surface.",
            "A mechanic is fixing a car.",
            "The men are playing a sport."
        ]
    },
    mnli: {
        claims: [
            "The government announced a new tax policy.",
            "He turned and looked at the woman.",
            "The concert was cancelled due to rain."
        ],
        evidence: [
            "New financial regulations were introduced by the state.",
            "He ignored the person standing next to him.",
            "The outdoor event proceeded despite the bad weather."
        ]
    }
};
133
+
134
/**
 * Sync the example datalists (and the two text inputs) with the
 * scenario currently chosen in the #dataset-select dropdown.
 * Choosing "custom" simply empties everything for free-form input.
 */
function loadScenario() {
    const key = document.getElementById('dataset-select').value;
    const claimsList = document.getElementById('claims-list');
    const evidenceList = document.getElementById('evidence-list');
    const claimInput = document.getElementById('sentence1');
    const evidenceInput = document.getElementById('sentence2');

    // Start from a clean slate before repopulating.
    claimsList.innerHTML = '';
    evidenceList.innerHTML = '';

    if (key === 'custom') {
        claimInput.value = '';
        evidenceInput.value = '';
        return;
    }

    const { claims, evidence } = scenarios[key];

    // Build <option> children for a datalist from an array of strings.
    const fill = (list, items) => {
        for (const text of items) {
            const option = document.createElement('option');
            option.value = text;
            list.appendChild(option);
        }
    };
    fill(claimsList, claims);
    fill(evidenceList, evidence);

    // Pre-select the first pair so the user can analyze immediately.
    claimInput.value = claims[0];
    evidenceInput.value = evidence[0];
}
170
+
171
/**
 * Send the two sentences (plus the selected model/scenario) to the
 * /predict endpoint and render the returned label and per-class
 * probabilities. The button is disabled for the duration of the
 * request so rapid clicks cannot fire overlapping analyses.
 */
async function predict() {
    // Trim so whitespace-only input is rejected by the emptiness check.
    const s1 = document.getElementById('sentence1').value.trim();
    const s2 = document.getElementById('sentence2').value.trim();
    // NOTE(review): when "Custom Input" is selected this still sends
    // model_type = "custom" — confirm the backend maps it to a model.
    const modelType = document.getElementById('dataset-select').value;
    const resultDiv = document.getElementById('result');
    const btn = document.getElementById('analyze-btn');

    if (!s1 || !s2) {
        alert("Please enter both sentences.");
        return;
    }

    btn.disabled = true; // guard against double submission
    btn.textContent = "Analyzing...";
    resultDiv.classList.add('hidden');

    try {
        const response = await fetch('/predict', {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({
                sentence1: s1,
                sentence2: s2,
                model_type: modelType // tells the server which checkpoint to use
            })
        });

        const data = await response.json();

        // Server reports validation/model errors via an "error" field.
        if (data.error) {
            alert(data.error);
            return;
        }

        // Headline label, colour-coded by predicted class.
        const label = document.getElementById('prediction-label');
        label.textContent = data.prediction;
        if (data.prediction === 'Entailment') label.style.color = '#10B981';        // green
        else if (data.prediction === 'Contradiction') label.style.color = '#EF4444'; // red
        else label.style.color = '#F59E0B';                                          // amber (Neutral)

        // One bar + percentage readout per class.
        const setBar = (name, p) => {
            const pct = p * 100;
            document.getElementById('bar-' + name).style.width = pct + '%';
            document.getElementById('val-' + name).textContent = pct.toFixed(1) + '%';
        };
        setBar('entailment', data.probabilities.Entailment);
        setBar('neutral', data.probabilities.Neutral);
        setBar('contradiction', data.probabilities.Contradiction);

        resultDiv.classList.remove('hidden');

    } catch (e) {
        // Network failure or non-JSON response body.
        console.error(e);
        alert("Error connecting to the server.");
    } finally {
        btn.disabled = false;
        btn.textContent = "Analyze Similarity";
    }
}
232
+
233
// Populate the example lists for the default scenario on first load.
window.onload = loadScenario;
235
+ </script>
236
+ </body>
237
+ </html>
demo.gif ADDED

Git LFS Details

  • SHA256: 959470bc03c4e9f662c1d28dfda8583981edb722294ca5b67f4175f6567d7212
  • Pointer size: 132 Bytes
  • Size of remote file: 1.99 MB
models/bert_trained.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4bd87012116d7955bc8b4fddbe591810bc70692ed1df20d3394e58d08bdbc58a
3
+ size 12138238
models/sbert_climate_fever.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bfd76cef3670cbd791bc83a0bf7292cb6975ae8b190ffb02e8feb92b350e33bb
3
+ size 12148818
models/sbert_mnli.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1a5397be33fe31a270905838e93ddf94a05c2f1e54f09c13c5a453b45f580fc5
3
+ size 12148386
models/sbert_snli.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7c3557c784c924a487f55a6de15d8c127599aca0cfa74a1e6655b9af7b063c2
3
+ size 12148386
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ transformers
3
+ datasets
4
+ scikit-learn
5
+ pandas
6
+ numpy
7
+ flask
8
+ gunicorn
9
+ tqdm