Upload folder using huggingface_hub
- .gitattributes +1 -0
- .gitignore +9 -0
- A4_BERT.ipynb +383 -0
- A4_Climate_FEVER.ipynb +522 -0
- A4_Option_MNLI.ipynb +518 -0
- A4_Option_SNLI.ipynb +637 -0
- Dockerfile +11 -0
- README.md +83 -6
- app/app.py +316 -0
- app/static/style.css +234 -0
- app/templates/index.html +237 -0
- demo.gif +3 -0
- models/bert_trained.pt +3 -0
- models/sbert_climate_fever.pt +3 -0
- models/sbert_mnli.pt +3 -0
- models/sbert_snli.pt +3 -0
- requirements.txt +9 -0
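
The commit message above indicates the folder was pushed with the `huggingface_hub` client. For reference, a minimal sketch of such an upload is shown below; the repo id, local path, and ignore patterns are placeholders and are not taken from this repository.

# Sketch only: upload a local project folder to a Hugging Face Space.
from huggingface_hub import upload_folder

upload_folder(
    repo_id="username/space-name",      # placeholder
    repo_type="space",
    folder_path=".",                    # local project folder
    commit_message="Upload folder using huggingface_hub",
    ignore_patterns=["__pycache__/", "*.pyc", ".ipynb_checkpoints/"],
)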
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+demo.gif filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,9 @@
+__pycache__/
+*.pyc
+.ipynb_checkpoints/
+.DS_Store
+*.env
+venv/
+env/
+# models/*.pt
+# If using large models, add them to LFS instead
A4_BERT.ipynb
ADDED
@@ -0,0 +1,383 @@

# A4: BERT Pre-training from Scratch
## Student Information
**Name:** HTUT KO KO
**ID:** st126010

## Task 1: BERT implementation
In this notebook, I implement BERT from scratch and pre-train it on the WikiText-103 dataset.
**Optimization Note:** To reach a lower loss in this small-scale demonstration, I use a smaller subset of the data, a smaller vocabulary, and train for an adequate number of epochs.

[code]
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from random import randrange, shuffle, randint
from datasets import load_dataset
from transformers import BertTokenizer
from torch.utils.data import DataLoader, Dataset
import re
from collections import Counter

device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"))
print(f"Using device: {device}")

[output]
/opt/homebrew/Caskroom/miniforge/base/envs/ai_env/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
Using device: mps

### 1. Data Loading & Preprocessing
I use a smaller vocabulary size to make convergence easier for this assignment.

[code]
# 1. Load Data
dataset = load_dataset('wikitext', 'wikitext-103-raw-v1', split='train')
subset_size = 5000  # Reduce size for faster iteration and better convergence on a small model
dataset = dataset.select(range(subset_size))
raw_text_data = [line for line in dataset['text'] if len(line) > 20]

# 2. Build Custom Vocabulary (crucial for small data)
# Instead of using the 30k BERT vocab, we build one from our data
tokens = [word.lower() for sent in raw_text_data for word in sent.split()]
vocab_counter = Counter(tokens)
vocab = sorted(vocab_counter, key=vocab_counter.get, reverse=True)[:5000]  # Top 5k words
word2id = {w: i + 4 for i, w in enumerate(vocab)}
word2id['[PAD]'] = 0
word2id['[CLS]'] = 1
word2id['[SEP]'] = 2
word2id['[MASK]'] = 3
id2word = {i: w for w, i in word2id.items()}
vocab_size = len(word2id)
print(f"Vocab Size: {vocab_size}")

token_list = []
for sentence in raw_text_data:
    # Simple whitespace tokenization for this demo
    seq = [word2id.get(w.lower(), 0) for w in sentence.split()]
    if len(seq) > 0:
        token_list.append(seq)

[output]
Vocab Size: 5004

### 2. BERT Hyperparameters & Data Loader

[code]
max_len = 128
batch_size = 16  # Small batch size
max_mask = 20
n_layers = 2     # Shallower model for easier training
n_heads = 4
d_model = 256
d_ff = 256 * 4
d_k = d_v = 64
n_segments = 2

def make_batch():
    batch = []
    positive = negative = 0
    while positive != batch_size / 2 or negative != batch_size / 2:
        tokens_a_index, tokens_b_index = randrange(len(token_list)), randrange(len(token_list))
        tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]

        input_ids = [word2id['[CLS]']] + tokens_a + [word2id['[SEP]']] + tokens_b + [word2id['[SEP]']]
        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)

        input_ids = input_ids[:max_len]
        segment_ids = segment_ids[:max_len]

        n_pred = min(max_mask, max(1, int(round(len(input_ids) * 0.15))))
        candidates_masked_pos = [i for i, token in enumerate(input_ids) if token != word2id['[CLS]'] and token != word2id['[SEP]']]
        shuffle(candidates_masked_pos)
        masked_tokens, masked_pos = [], []
        for pos in candidates_masked_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            if random.random() < 0.1:
                input_ids[pos] = randint(0, vocab_size - 1)
            elif random.random() < 0.8:
                input_ids[pos] = word2id['[MASK]']

        n_pad = max_len - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)

        if max_mask > n_pred:
            n_pad = max_mask - n_pred
            masked_tokens.extend([0] * n_pad)
            masked_pos.extend([0] * n_pad)

        if tokens_a_index + 1 == tokens_b_index and positive < batch_size / 2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True])
            positive += 1
        elif tokens_a_index + 1 != tokens_b_index and negative < batch_size / 2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False])
            negative += 1
    return batch

### 3. BERT Model Architecture

[code]
class Embedding(nn.Module):
    def __init__(self):
        super(Embedding, self).__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.seg_embed = nn.Embedding(n_segments, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, seg):
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
        pos = pos.unsqueeze(0).expand_as(x)
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)

class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
        self.linear = nn.Linear(n_heads * d_v, d_model)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, Q, K, V, attn_mask):
        batch_size = Q.size(0)
        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)

        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)

        scores = torch.matmul(q_s, k_s.transpose(-1, -2)) / np.sqrt(d_k)
        scores.masked_fill_(attn_mask, -1e9)
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, v_s)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)
        output = self.linear(context)
        return self.layer_norm(output + Q), attn

class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.fc2(torch.nn.functional.gelu(self.fc1(x)))

class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
        enc_outputs = self.pos_ffn(enc_outputs)
        return enc_outputs, attn

def get_attn_pad_mask(seq_q, seq_k):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)
    return pad_attn_mask.expand(batch_size, len_q, len_k)

class BERT(nn.Module):
    def __init__(self):
        super(BERT, self).__init__()
        self.embedding = Embedding()
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
        self.fc = nn.Linear(d_model, d_model)
        self.activ = nn.Tanh()
        self.linear = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, 2)
        embed_weight = self.embedding.tok_embed.weight
        n_vocab, n_dim = embed_weight.size()
        self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder.weight = embed_weight
        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))

    def forward(self, input_ids, segment_ids, masked_pos):
        output = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids)
        for layer in self.layers:
            output, enc_self_attn = layer(output, enc_self_attn_mask)

        h_pooled = self.activ(self.fc(output[:, 0]))
        logits_nsp = self.classifier(h_pooled)

        masked_pos = masked_pos[:, :, None].expand(-1, -1, d_model)
        h_masked = torch.gather(output, 1, masked_pos)
        h_masked = self.norm(self.activ(self.linear(h_masked)))
        logits_lm = self.decoder(h_masked) + self.decoder_bias

        return logits_lm, logits_nsp, output

### 4. Training Loop
I train for 2000 epochs to ensure convergence.

[code]
model = BERT().to(device)
criterion = nn.CrossEntropyLoss(ignore_index=0)
criterion_nsp = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

print("Starting Training...")
for epoch in range(2000):
    batch = make_batch()
    input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(*batch))
    input_ids, segment_ids, masked_tokens, masked_pos, isNext = input_ids.to(device), segment_ids.to(device), masked_tokens.to(device), masked_pos.to(device), isNext.to(device)

    optimizer.zero_grad()
    logits_lm, logits_nsp, _ = model(input_ids, segment_ids, masked_pos)
    loss_lm = criterion(logits_lm.transpose(1, 2), masked_tokens).mean()
    loss_nsp = criterion_nsp(logits_nsp, isNext)
    loss = loss_lm + loss_nsp
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f'Epoch: {epoch + 1:04d} | loss = {loss.item():.6f}')

torch.save(model.state_dict(), './models/bert_trained.pt')
print("Training Complete. Model Saved.")

[output]
Starting Training...
Epoch: 0100 | loss = 13.480813
Epoch: 0200 | loss = 10.269491
Epoch: 0300 | loss = 8.912387
Epoch: 0400 | loss = 7.265475
Epoch: 0500 | loss = 6.765235
Epoch: 0600 | loss = 7.150949
Epoch: 0700 | loss = 6.182394
Epoch: 0800 | loss = 6.075039
Epoch: 0900 | loss = 6.766500
Epoch: 1000 | loss = 6.545547
Epoch: 1100 | loss = 6.488539
Epoch: 1200 | loss = 6.223000
Epoch: 1300 | loss = 5.912578
Epoch: 1400 | loss = 6.125433
Epoch: 1500 | loss = 6.077301
Epoch: 1600 | loss = 6.500366
Epoch: 1700 | loss = 6.560534
Epoch: 1800 | loss = 6.262241
Epoch: 1900 | loss = 5.871750
Epoch: 2000 | loss = 6.158124
Training Complete. Model Saved.

(Notebook kernel: ai_env, Python 3.11.13; nbformat 4.)
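
As a quick sanity check on the saved checkpoint, the following is a minimal inference sketch; it is not part of the committed notebook and assumes the `BERT` class, `make_batch`, `word2id`/`id2word`, and `device` defined above are in scope.

# Sketch only: reload ./models/bert_trained.pt and inspect masked-token
# predictions for one batch produced by make_batch().
model = BERT().to(device)
model.load_state_dict(torch.load('./models/bert_trained.pt', map_location=device))
model.eval()

batch = make_batch()
input_ids, segment_ids, masked_tokens, masked_pos, is_next = map(torch.LongTensor, zip(*batch))
with torch.no_grad():
    logits_lm, _, _ = model(input_ids.to(device), segment_ids.to(device), masked_pos.to(device))

pred_ids = logits_lm.argmax(dim=-1)[0]   # predicted ids at the masked positions of the first sequence
for pred, true in zip(pred_ids.tolist(), masked_tokens[0].tolist()):
    if true != 0:                        # 0 marks unused (padded) mask slots
        print(f"predicted: {id2word[pred]:<15} actual: {id2word[true]}")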
A4_Climate_FEVER.ipynb
ADDED
@@ -0,0 +1,522 @@

# A4: S-BERT Training (Climate-FEVER)
## Student Information
**Name:** HTUT KO KO
**ID:** st126010

## Task 2: S-BERT Implementation
In this notebook, I load the pre-trained BERT model (from Task 1) and fine-tune it using a Siamese network structure for Natural Language Inference (NLI) on the Climate-FEVER dataset.

[code]
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from datasets import load_dataset
from transformers import BertTokenizer
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import classification_report, accuracy_score

device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"))
print(f"Using device: {device}")

[output]
Using device: mps

[code]
# Model Hyperparameters (must match pre-training)
max_len = 128
n_layers = 2
n_heads = 4
d_model = 256
d_ff = 256 * 4
d_k = d_v = 64
n_segments = 2
vocab_size = 5004

## 1. BERT Model Definition
This definition is required to load the saved state dictionary.

[code]
class Embedding(nn.Module):
    def __init__(self):
        super(Embedding, self).__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.seg_embed = nn.Embedding(n_segments, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, seg):
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
        pos = pos.unsqueeze(0).expand_as(x)
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)

class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
        self.linear = nn.Linear(n_heads * d_v, d_model)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, Q, K, V, attn_mask):
        batch_size = Q.size(0)
        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)

        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)

        scores = torch.matmul(q_s, k_s.transpose(-1, -2)) / np.sqrt(d_k)
        scores.masked_fill_(attn_mask, -1e9)
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, v_s)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)
        output = self.linear(context)
        return self.layer_norm(output + Q), attn

class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.fc2(torch.nn.functional.gelu(self.fc1(x)))

class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
        enc_outputs = self.pos_ffn(enc_outputs)
        return enc_outputs, attn

def get_attn_pad_mask(seq_q, seq_k):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)
    return pad_attn_mask.expand(batch_size, len_q, len_k)

class BERT(nn.Module):
    def __init__(self):
        super(BERT, self).__init__()
        self.embedding = Embedding()
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
        self.fc = nn.Linear(d_model, d_model)
        self.activ = nn.Tanh()
        self.linear = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, 2)
        embed_weight = self.embedding.tok_embed.weight
        n_vocab, n_dim = embed_weight.size()
        self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder.weight = embed_weight
        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))

    def forward(self, input_ids, segment_ids, masked_pos):
        output = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids)
        for layer in self.layers:
            output, enc_self_attn = layer(output, enc_self_attn_mask)

        h_pooled = self.activ(self.fc(output[:, 0]))
        logits_nsp = self.classifier(h_pooled)

        # For S-BERT, I return the output sequences directly
        return logits_nsp, logits_nsp, output

# Load Pre-trained Parameters
bert = BERT().to(device)
try:
    bert.load_state_dict(torch.load('./models/bert_trained.pt', map_location=device))
    print("Loaded bert_trained.pt successfully.")
except:
    print("Pre-trained weights not found. Please run A4_BERT.ipynb first.")

[output]
Loaded bert_trained.pt successfully.

## 2. S-BERT for Climate-FEVER
Fine-tuning on the Climate-FEVER dataset.

[code]
class SBERT(nn.Module):
    def __init__(self, bert_model):
        super(SBERT, self).__init__()
        self.bert = bert_model
        self.classifier = nn.Linear(d_model * 3, 3)

    def forward(self, premise_ids, premise_seg, hypothesis_ids, hypothesis_seg):
        device = premise_ids.device
        dummy_masked_pos = torch.zeros((premise_ids.size(0), 1), dtype=torch.long).to(device)

        _, _, output_u = self.bert(premise_ids, premise_seg, dummy_masked_pos)
        mask_u = (premise_ids != 0).unsqueeze(-1).float()
        u = torch.sum(output_u * mask_u, dim=1) / torch.clamp(mask_u.sum(dim=1), min=1e-9)

        _, _, output_v = self.bert(hypothesis_ids, hypothesis_seg, dummy_masked_pos)
        mask_v = (hypothesis_ids != 0).unsqueeze(-1).float()
        v = torch.sum(output_v * mask_v, dim=1) / torch.clamp(mask_v.sum(dim=1), min=1e-9)

        uv_abs = torch.abs(u - v)
        features = torch.cat([u, v, uv_abs], dim=-1)
        logits = self.classifier(features)
        return logits

s_model = SBERT(bert).to(device)
optimizer = optim.Adam(s_model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

### 2.1 Climate-FEVER fine-tuning
I load the Climate-FEVER dataset, split it into train/test sets, and train the model.

[code]
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
cf_dataset = load_dataset('climate_fever', split='test')
# Climate-FEVER is often published with only a 'test' split, so I use it as the full dataset.
cf_split = cf_dataset.train_test_split(test_size=0.2, seed=42)
train_dataset = cf_split['train']
test_dataset = cf_split['test']

class NLIDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_len=128):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __getitem__(self, idx):
        item = self.dataset[idx]
        premise = item['claim']
        # Climate-FEVER has an 'evidences' list. I take the first evidence text.
        evidence_data = item['evidences'][0]
        hypothesis = evidence_data['evidence']
        label_raw = evidence_data['evidence_label']  # Nested access

        # Robust label mapping (handles integers 0/1/2 and strings)
        # Target: 0: Entailment, 1: Neutral, 2: Contradiction
        if isinstance(label_raw, int):
            # Assuming HF Climate-FEVER uses: 0: Supports, 1: Refutes, 2: NEI
            if label_raw == 0: label = 0    # Supports -> Entailment
            elif label_raw == 1: label = 2  # Refutes -> Contradiction
            elif label_raw == 2: label = 1  # NEI -> Neutral
            else: label = 1                 # Default
        else:
            label_str = str(label_raw).upper().replace(" ", "_")
            if 'SUPPORT' in label_str: label = 0
            elif 'REFUTE' in label_str: label = 2
            elif 'INFO' in label_str: label = 1
            else: label = 1

        encoded_premise = self.tokenizer(
            premise,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_attention_mask=True,
            truncation=True
        )

        encoded_hypothesis = self.tokenizer(
            hypothesis,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_attention_mask=True,
            truncation=True
        )

        return {
            'premise_input_ids': torch.tensor(encoded_premise['input_ids'], dtype=torch.long),
            'premise_segment_ids': torch.tensor(encoded_premise['token_type_ids'], dtype=torch.long),
            'hypothesis_input_ids': torch.tensor(encoded_hypothesis['input_ids'], dtype=torch.long),
            'hypothesis_segment_ids': torch.tensor(encoded_hypothesis['token_type_ids'], dtype=torch.long),
            'label': torch.tensor(label, dtype=torch.long)
        }

    def __len__(self):
        return len(self.dataset)

train_loader = DataLoader(NLIDataset(train_dataset, tokenizer), batch_size=16, shuffle=True)
test_loader = DataLoader(NLIDataset(test_dataset, tokenizer), batch_size=16, shuffle=False)

print("Starting S-BERT Training...")
for epoch in range(100):
    s_model.train()
    total_loss = 0
    for batch in train_loader:
        p_ids = batch['premise_input_ids'].to(device)
        p_seg = batch['premise_segment_ids'].to(device)
        h_ids = batch['hypothesis_input_ids'].to(device)
        h_seg = batch['hypothesis_segment_ids'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        logits = s_model(p_ids, p_seg, h_ids, h_seg)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1} Loss: {total_loss/len(train_loader):.4f}")

torch.save(s_model.state_dict(), './models/sbert_climate_fever.pt')

[output]
Starting S-BERT Training...
Epoch 1 Loss: 0.9086
Epoch 2 Loss: 0.8578
Epoch 3 Loss: 0.8520
Epoch 4 Loss: 0.8444
Epoch 5 Loss: 0.8365
Epoch 6 Loss: 0.8249
Epoch 7 Loss: 0.8103
Epoch 8 Loss: 0.7932
Epoch 9 Loss: 0.7652
Epoch 10 Loss: 0.7349
Epoch 11 Loss: 0.6856
Epoch 12 Loss: 0.6360
Epoch 13 Loss: 0.5786
Epoch 14 Loss: 0.5181
Epoch 15 Loss: 0.4629
Epoch 16 Loss: 0.4108
Epoch 17 Loss: 0.3392
Epoch 18 Loss: 0.3047
Epoch 19 Loss: 0.2546
Epoch 20 Loss: 0.2053
Epoch 21 Loss: 0.1662
Epoch 22 Loss: 0.1421
Epoch 23 Loss: 0.1161
Epoch 24 Loss: 0.0918
Epoch 25 Loss: 0.0798
Epoch 26 Loss: 0.0740
Epoch 27 Loss: 0.0591
Epoch 28 Loss: 0.0488
Epoch 29 Loss: 0.0448
Epoch 30 Loss: 0.0452
Epoch 31 Loss: 0.0354
Epoch 32 Loss: 0.0315
Epoch 33 Loss: 0.0265
Epoch 34 Loss: 0.0232
Epoch 35 Loss: 0.0215
Epoch 36 Loss: 0.0180
Epoch 37 Loss: 0.0173
Epoch 38 Loss: 0.0147
Epoch 39 Loss: 0.0137
Epoch 40 Loss: 0.0159
Epoch 41 Loss: 0.0127
Epoch 42 Loss: 0.0102
Epoch 43 Loss: 0.0094
Epoch 44 Loss: 0.0094
Epoch 45 Loss: 0.0100
Epoch 46 Loss: 0.0112
Epoch 47 Loss: 0.0077
Epoch 48 Loss: 0.0067
Epoch 49 Loss: 0.0073
Epoch 50 Loss: 0.0268
Epoch 51 Loss: 0.0747
Epoch 52 Loss: 0.0405
Epoch 53 Loss: 0.0241
Epoch 54 Loss: 0.0077
Epoch 55 Loss: 0.0054
Epoch 56 Loss: 0.0049
Epoch 57 Loss: 0.0047
Epoch 58 Loss: 0.0047
Epoch 59 Loss: 0.0052
Epoch 60 Loss: 0.0038
Epoch 61 Loss: 0.0037
Epoch 62 Loss: 0.0039
Epoch 63 Loss: 0.0037
Epoch 64 Loss: 0.0049
Epoch 65 Loss: 0.0036
Epoch 66 Loss: 0.0035
Epoch 67 Loss: 0.0037
Epoch 68 Loss: 0.0040
Epoch 69 Loss: 0.0044
Epoch 70 Loss: 0.0038
Epoch 71 Loss: 0.0079
Epoch 72 Loss: 0.0093
Epoch 73 Loss: 0.0033
Epoch 74 Loss: 0.0030
Epoch 75 Loss: 0.0032
Epoch 76 Loss: 0.0032
Epoch 77 Loss: 0.0029
Epoch 78 Loss: 0.0026
Epoch 79 Loss: 0.0031
Epoch 80 Loss: 0.0021
Epoch 81 Loss: 0.0048
Epoch 82 Loss: 0.0030
Epoch 83 Loss: 0.0045
Epoch 84 Loss: 0.0045
Epoch 85 Loss: 0.0061
Epoch 86 Loss: 0.0784
Epoch 87 Loss: 0.0706
Epoch 88 Loss: 0.0293
Epoch 89 Loss: 0.0059
Epoch 90 Loss: 0.0028
Epoch 91 Loss: 0.0024
Epoch 92 Loss: 0.0023
Epoch 93 Loss: 0.0020
Epoch 94 Loss: 0.0019
Epoch 95 Loss: 0.0018
Epoch 96 Loss: 0.0015
Epoch 97 Loss: 0.0016
Epoch 98 Loss: 0.0019
Epoch 99 Loss: 0.0014
Epoch 100 Loss: 0.0017

## 3. Evaluation
Evaluation of the model on the held-out test set.

[code]
s_model.eval()
all_preds = []
all_labels = []

print("Evaluating...")
with torch.no_grad():
    for batch in test_loader:
        p_ids = batch['premise_input_ids'].to(device)
        p_seg = batch['premise_segment_ids'].to(device)
        h_ids = batch['hypothesis_input_ids'].to(device)
        h_seg = batch['hypothesis_segment_ids'].to(device)
        labels = batch['label'].to(device)

        logits = s_model(p_ids, p_seg, h_ids, h_seg)
        preds = torch.argmax(logits, dim=1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

target_names = ['Entailment', 'Neutral', 'Contradiction']
print("Classification Report:")
print(classification_report(all_labels, all_preds, labels=[0, 1, 2], target_names=target_names))
print(f"Accuracy: {accuracy_score(all_labels, all_preds):.4f}")

[output]
Evaluating...
Classification Report:
               precision    recall  f1-score   support

   Entailment       0.26      0.22      0.24        82
      Neutral       0.62      0.69      0.65       191
Contradiction       0.19      0.15      0.17        34

     accuracy                           0.50       307
    macro avg       0.36      0.35      0.35       307
 weighted avg       0.48      0.50      0.49       307

Accuracy: 0.5049

(Notebook kernel: Python 3; language version 3.8.5; nbformat 4.)
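
For completeness, here is a minimal usage sketch for the saved classifier; it is not part of the committed notebook, assumes `s_model`, `tokenizer`, `max_len`, and `device` from the notebook above are in scope, and the claim/evidence strings are made-up examples.

# Sketch only: predict the NLI label for a single claim/evidence pair,
# mirroring the evaluation loop above.
label_names = ['Entailment', 'Neutral', 'Contradiction']

def predict(claim, evidence):
    enc_p = tokenizer(claim, max_length=max_len, padding='max_length', truncation=True, return_tensors='pt')
    enc_h = tokenizer(evidence, max_length=max_len, padding='max_length', truncation=True, return_tensors='pt')
    s_model.eval()
    with torch.no_grad():
        logits = s_model(enc_p['input_ids'].to(device), enc_p['token_type_ids'].to(device),
                         enc_h['input_ids'].to(device), enc_h['token_type_ids'].to(device))
    return label_names[logits.argmax(dim=1).item()]

print(predict("Global temperatures have risen over the past century.",
              "The planet has warmed since the early 1900s."))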
A4_Option_MNLI.ipynb
ADDED
@@ -0,0 +1,518 @@

# A4: S-BERT Training on Alternative Datasets (MNLI)

This notebook allows me to train the S-BERT model on the **MNLI** (Multi-Genre Natural Language Inference) dataset.

## 1. Environment Setup

[code]
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from datasets import load_dataset
from transformers import BertTokenizer
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import classification_report, accuracy_score

# Device Configuration
device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"))
print(f"Using device: {device}")

[output]
Using device: mps

## 2. Load Pre-trained BERT

I will load the BERT model trained in `A4_BERT.ipynb`. Ensure `models/bert_trained.pt` exists.

[code]
# Define BERT Architecture
# MUST MATCH THE OPTIMIZED CONFIG FROM A4_BERT.ipynb
vocab_size = 5004  # Updated from 30522
d_model = 256      # MiniBERT Config
n_layers = 2       # Updated from 4
n_heads = 4
d_ff = 256 * 4
max_len = 128
n_segments = 2
d_k = d_v = 64

class Embedding(nn.Module):
    def __init__(self):
        super(Embedding, self).__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.seg_embed = nn.Embedding(n_segments, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, seg):
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
        pos = pos.unsqueeze(0).expand_as(x)
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)

class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
        self.linear = nn.Linear(n_heads * d_v, d_model)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, Q, K, V, attn_mask):
        batch_size = Q.size(0)
        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)

        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)

        scores = torch.matmul(q_s, k_s.transpose(-1, -2)) / np.sqrt(d_k)
        scores.masked_fill_(attn_mask, -1e9)
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, v_s)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)
        output = self.linear(context)
        return self.layer_norm(output + Q), attn

class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.fc2(torch.nn.functional.gelu(self.fc1(x)))

class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
        enc_outputs = self.pos_ffn(enc_outputs)
        return enc_outputs, attn

def get_attn_pad_mask(seq_q, seq_k):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)
    return pad_attn_mask.expand(batch_size, len_q, len_k)

class BERT(nn.Module):
    def __init__(self):
        super(BERT, self).__init__()
        self.embedding = Embedding()
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
        self.fc = nn.Linear(d_model, d_model)
        self.activ = nn.Tanh()
        self.linear = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, 2)
        embed_weight = self.embedding.tok_embed.weight
        n_vocab, n_dim = embed_weight.size()
        self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder.weight = embed_weight
        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))

    def forward(self, input_ids, segment_ids, masked_pos=None):
        output = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids)
        for layer in self.layers:
            output, enc_self_attn = layer(output, enc_self_attn_mask)
        return None, None, output

# Load Pretrained Weights
bert = BERT().to(device)
try:
    bert.load_state_dict(torch.load('./models/bert_trained.pt', map_location=device))
    print("Loaded bert_trained.pt")
except:
    print("Warning: bert_trained.pt not found. Using random weights.")

[output]
Loaded bert_trained.pt

## 3. Load MNLI Dataset

[code]
DATASET_NAME = 'mnli'
print(f"Loading {DATASET_NAME}...")
# MNLI is part of GLUE benchmark
dataset = load_dataset('glue', 'mnli')
print(f"Loaded dataset keys: {dataset.keys()}")

train_dataset = dataset['train'].select(range(10000))
val_dataset = dataset['validation_matched'].select(range(1000))
print(f"Train size: {len(train_dataset)}, Val size: {len(val_dataset)}")

[output]
Loading mnli...
Loaded dataset keys: dict_keys(['train', 'validation_matched', 'validation_mismatched', 'test_matched', 'test_mismatched'])
Train size: 10000, Val size: 1000

[code]
# Data Loader
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

class NLIDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_len=128):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        premise = item['premise']
        hypothesis = item['hypothesis']
        label = item['label']

        encoded_premise = self.tokenizer(
            premise,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_attention_mask=True,
            truncation=True
        )

        encoded_hypothesis = self.tokenizer(
            hypothesis,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_attention_mask=True,
            truncation=True
        )

        return {
            'premise_input_ids': torch.tensor(encoded_premise['input_ids'], dtype=torch.long),
            'premise_segment_ids': torch.tensor(encoded_premise['token_type_ids'], dtype=torch.long),
            'hypothesis_input_ids': torch.tensor(encoded_hypothesis['input_ids'], dtype=torch.long),
            'hypothesis_segment_ids': torch.tensor(encoded_hypothesis['token_type_ids'], dtype=torch.long),
            'label': torch.tensor(label, dtype=torch.long)
        }

train_loader = DataLoader(NLIDataset(train_dataset, tokenizer), batch_size=16, shuffle=True)
test_loader = DataLoader(NLIDataset(val_dataset, tokenizer), batch_size=16, shuffle=False)

[code]
# S-BERT Model
class SBERT(nn.Module):
    def __init__(self, bert_model):
        super(SBERT, self).__init__()
        self.bert = bert_model
        self.classifier = nn.Linear(d_model * 3, 3)

    def forward(self, premise_ids, premise_seg, hypothesis_ids, hypothesis_seg):
        device = premise_ids.device
|
| 276 |
+
" dummy_masked_pos = torch.zeros((premise_ids.size(0), 1), dtype=torch.long).to(device)\n",
|
| 277 |
+
" \n",
|
| 278 |
+
" _, _, output_u = self.bert(premise_ids, premise_seg, dummy_masked_pos)\n",
|
| 279 |
+
" mask_u = (premise_ids != 0).unsqueeze(-1).float()\n",
|
| 280 |
+
" u = torch.sum(output_u * mask_u, dim=1) / torch.clamp(mask_u.sum(dim=1), min=1e-9)\n",
|
| 281 |
+
"\n",
|
| 282 |
+
" _, _, output_v = self.bert(hypothesis_ids, hypothesis_seg, dummy_masked_pos)\n",
|
| 283 |
+
" mask_v = (hypothesis_ids != 0).unsqueeze(-1).float()\n",
|
| 284 |
+
" v = torch.sum(output_v * mask_v, dim=1) / torch.clamp(mask_v.sum(dim=1), min=1e-9)\n",
|
| 285 |
+
"\n",
|
| 286 |
+
" uv_abs = torch.abs(u - v)\n",
|
| 287 |
+
" features = torch.cat([u, v, uv_abs], dim=-1)\n",
|
| 288 |
+
" logits = self.classifier(features)\n",
|
| 289 |
+
" return logits\n",
|
| 290 |
+
"\n",
|
| 291 |
+
"sbert = SBERT(bert).to(device)\n",
|
| 292 |
+
"optimizer = optim.Adam(sbert.parameters(), lr=2e-5)\n",
|
| 293 |
+
"criterion = nn.CrossEntropyLoss()"
|
| 294 |
+
]
|
| 295 |
+
},
|
| 296 |
+
{
|
| 297 |
+
"cell_type": "code",
|
| 298 |
+
"execution_count": 13,
|
| 299 |
+
"metadata": {},
|
| 300 |
+
"outputs": [
|
| 301 |
+
{
|
| 302 |
+
"name": "stdout",
|
| 303 |
+
"output_type": "stream",
|
| 304 |
+
"text": [
|
| 305 |
+
"Starting Training...\n",
|
| 306 |
+
"Epoch 1 Loss: 1.0874\n",
|
| 307 |
+
"Epoch 2 Loss: 1.0579\n",
|
| 308 |
+
"Epoch 3 Loss: 1.0135\n",
|
| 309 |
+
"Epoch 4 Loss: 0.9806\n",
|
| 310 |
+
"Epoch 5 Loss: 0.9516\n",
|
| 311 |
+
"Epoch 6 Loss: 0.9264\n",
|
| 312 |
+
"Epoch 7 Loss: 0.8967\n",
|
| 313 |
+
"Epoch 8 Loss: 0.8670\n",
|
| 314 |
+
"Epoch 9 Loss: 0.8326\n",
|
| 315 |
+
"Epoch 10 Loss: 0.8005\n",
|
| 316 |
+
"Epoch 11 Loss: 0.7601\n",
|
| 317 |
+
"Epoch 12 Loss: 0.7218\n",
|
| 318 |
+
"Epoch 13 Loss: 0.6797\n",
|
| 319 |
+
"Epoch 14 Loss: 0.6349\n",
|
| 320 |
+
"Epoch 15 Loss: 0.5906\n",
|
| 321 |
+
"Epoch 16 Loss: 0.5450\n",
|
| 322 |
+
"Epoch 17 Loss: 0.4974\n",
|
| 323 |
+
"Epoch 18 Loss: 0.4493\n",
|
| 324 |
+
"Epoch 19 Loss: 0.4065\n",
|
| 325 |
+
"Epoch 20 Loss: 0.3590\n",
|
| 326 |
+
"Epoch 21 Loss: 0.3150\n",
|
| 327 |
+
"Epoch 22 Loss: 0.2712\n",
|
| 328 |
+
"Epoch 23 Loss: 0.2321\n",
|
| 329 |
+
"Epoch 24 Loss: 0.2020\n",
|
| 330 |
+
"Epoch 25 Loss: 0.1673\n",
|
| 331 |
+
"Epoch 26 Loss: 0.1349\n",
|
| 332 |
+
"Epoch 27 Loss: 0.1136\n",
|
| 333 |
+
"Epoch 28 Loss: 0.0965\n",
|
| 334 |
+
"Epoch 29 Loss: 0.0876\n",
|
| 335 |
+
"Epoch 30 Loss: 0.0723\n",
|
| 336 |
+
"Epoch 31 Loss: 0.0656\n",
|
| 337 |
+
"Epoch 32 Loss: 0.0504\n",
|
| 338 |
+
"Epoch 33 Loss: 0.0416\n",
|
| 339 |
+
"Epoch 34 Loss: 0.0418\n",
|
| 340 |
+
"Epoch 35 Loss: 0.0327\n",
|
| 341 |
+
"Epoch 36 Loss: 0.0324\n",
|
| 342 |
+
"Epoch 37 Loss: 0.0269\n",
|
| 343 |
+
"Epoch 38 Loss: 0.0438\n",
|
| 344 |
+
"Epoch 39 Loss: 0.0291\n",
|
| 345 |
+
"Epoch 40 Loss: 0.0210\n",
|
| 346 |
+
"Epoch 41 Loss: 0.0168\n",
|
| 347 |
+
"Epoch 42 Loss: 0.0309\n",
|
| 348 |
+
"Epoch 43 Loss: 0.0180\n",
|
| 349 |
+
"Epoch 44 Loss: 0.0327\n",
|
| 350 |
+
"Epoch 45 Loss: 0.0411\n",
|
| 351 |
+
"Epoch 46 Loss: 0.0157\n",
|
| 352 |
+
"Epoch 47 Loss: 0.0048\n",
|
| 353 |
+
"Epoch 48 Loss: 0.0019\n",
|
| 354 |
+
"Epoch 49 Loss: 0.0013\n",
|
| 355 |
+
"Epoch 50 Loss: 0.0010\n",
|
| 356 |
+
"Epoch 51 Loss: 0.0008\n",
|
| 357 |
+
"Epoch 52 Loss: 0.0007\n",
|
| 358 |
+
"Epoch 53 Loss: 0.0005\n",
|
| 359 |
+
"Epoch 54 Loss: 0.0004\n",
|
| 360 |
+
"Epoch 55 Loss: 0.0003\n",
|
| 361 |
+
"Epoch 56 Loss: 0.0002\n",
|
| 362 |
+
"Epoch 57 Loss: 0.0002\n",
|
| 363 |
+
"Epoch 58 Loss: 0.0001\n",
|
| 364 |
+
"Epoch 59 Loss: 0.0001\n",
|
| 365 |
+
"Epoch 60 Loss: 0.0001\n",
|
| 366 |
+
"Epoch 61 Loss: 0.0000\n",
|
| 367 |
+
"Epoch 62 Loss: 0.0000\n",
|
| 368 |
+
"Epoch 63 Loss: 0.0000\n",
|
| 369 |
+
"Epoch 64 Loss: 0.0000\n",
|
| 370 |
+
"Epoch 65 Loss: 0.0000\n",
|
| 371 |
+
"Epoch 66 Loss: 0.1953\n",
|
| 372 |
+
"Epoch 67 Loss: 0.0272\n",
|
| 373 |
+
"Epoch 68 Loss: 0.0120\n",
|
| 374 |
+
"Epoch 69 Loss: 0.0108\n",
|
| 375 |
+
"Epoch 70 Loss: 0.0152\n",
|
| 376 |
+
"Epoch 71 Loss: 0.0337\n",
|
| 377 |
+
"Epoch 72 Loss: 0.0215\n",
|
| 378 |
+
"Epoch 73 Loss: 0.0148\n",
|
| 379 |
+
"Epoch 74 Loss: 0.0207\n",
|
| 380 |
+
"Epoch 75 Loss: 0.0238\n",
|
| 381 |
+
"Epoch 76 Loss: 0.0181\n",
|
| 382 |
+
"Epoch 77 Loss: 0.0217\n",
|
| 383 |
+
"Epoch 78 Loss: 0.0136\n",
|
| 384 |
+
"Epoch 79 Loss: 0.0163\n",
|
| 385 |
+
"Epoch 80 Loss: 0.0067\n",
|
| 386 |
+
"Epoch 81 Loss: 0.0007\n",
|
| 387 |
+
"Epoch 82 Loss: 0.0003\n",
|
| 388 |
+
"Epoch 83 Loss: 0.0002\n",
|
| 389 |
+
"Epoch 84 Loss: 0.0002\n",
|
| 390 |
+
"Epoch 85 Loss: 0.0001\n",
|
| 391 |
+
"Epoch 86 Loss: 0.0001\n",
|
| 392 |
+
"Epoch 87 Loss: 0.0001\n",
|
| 393 |
+
"Epoch 88 Loss: 0.0001\n",
|
| 394 |
+
"Epoch 89 Loss: 0.0001\n",
|
| 395 |
+
"Epoch 90 Loss: 0.0000\n",
|
| 396 |
+
"Epoch 91 Loss: 0.0000\n",
|
| 397 |
+
"Epoch 92 Loss: 0.0000\n",
|
| 398 |
+
"Epoch 93 Loss: 0.0000\n",
|
| 399 |
+
"Epoch 94 Loss: 0.0000\n",
|
| 400 |
+
"Epoch 95 Loss: 0.0000\n",
|
| 401 |
+
"Epoch 96 Loss: 0.0000\n",
|
| 402 |
+
"Epoch 97 Loss: 0.0000\n",
|
| 403 |
+
"Epoch 98 Loss: 0.0000\n",
|
| 404 |
+
"Epoch 99 Loss: 0.0000\n",
|
| 405 |
+
"Epoch 100 Loss: 0.0000\n",
|
| 406 |
+
"Done!\n"
|
| 407 |
+
]
|
| 408 |
+
}
|
| 409 |
+
],
|
| 410 |
+
"source": [
|
| 411 |
+
"# Train Loop\n",
|
| 412 |
+
"print(\"Starting Training...\")\n",
|
| 413 |
+
"epochs = 100 \n",
|
| 414 |
+
"for epoch in range(epochs):\n",
|
| 415 |
+
" sbert.train()\n",
|
| 416 |
+
" total_loss = 0\n",
|
| 417 |
+
" for batch in train_loader:\n",
|
| 418 |
+
" p_ids = batch['premise_input_ids'].to(device)\n",
|
| 419 |
+
" p_seg = batch['premise_segment_ids'].to(device)\n",
|
| 420 |
+
" h_ids = batch['hypothesis_input_ids'].to(device)\n",
|
| 421 |
+
" h_seg = batch['hypothesis_segment_ids'].to(device)\n",
|
| 422 |
+
" labels = batch['label'].to(device)\n",
|
| 423 |
+
"\n",
|
| 424 |
+
" optimizer.zero_grad()\n",
|
| 425 |
+
" logits = sbert(p_ids, p_seg, h_ids, h_seg)\n",
|
| 426 |
+
" loss = criterion(logits, labels)\n",
|
| 427 |
+
" loss.backward()\n",
|
| 428 |
+
" optimizer.step()\n",
|
| 429 |
+
" total_loss += loss.item()\n",
|
| 430 |
+
" print(f\"Epoch {epoch+1} Loss: {total_loss/len(train_loader):.4f}\")\n",
|
| 431 |
+
"\n",
|
| 432 |
+
"print(\"Done!\")\n",
|
| 433 |
+
"torch.save(sbert.state_dict(), f'./models/sbert_{DATASET_NAME}.pt')"
|
| 434 |
+
]
|
| 435 |
+
},
|
| 436 |
+
{
|
| 437 |
+
"cell_type": "markdown",
|
| 438 |
+
"metadata": {},
|
| 439 |
+
"source": [
|
| 440 |
+
"## 4. Evaluation\n",
|
| 441 |
+
"\n",
|
| 442 |
+
"Evaluate on validation set (matched)."
|
| 443 |
+
]
|
| 444 |
+
},
|
| 445 |
+
{
|
| 446 |
+
"cell_type": "code",
|
| 447 |
+
"execution_count": 14,
|
| 448 |
+
"metadata": {},
|
| 449 |
+
"outputs": [
|
| 450 |
+
{
|
| 451 |
+
"name": "stdout",
|
| 452 |
+
"output_type": "stream",
|
| 453 |
+
"text": [
|
| 454 |
+
"Evaluating...\n",
|
| 455 |
+
"Classification Report:\n",
|
| 456 |
+
" precision recall f1-score support\n",
|
| 457 |
+
"\n",
|
| 458 |
+
" Entailment 0.42 0.45 0.43 341\n",
|
| 459 |
+
" Neutral 0.42 0.34 0.37 319\n",
|
| 460 |
+
"Contradiction 0.49 0.54 0.51 340\n",
|
| 461 |
+
"\n",
|
| 462 |
+
" accuracy 0.44 1000\n",
|
| 463 |
+
" macro avg 0.44 0.44 0.44 1000\n",
|
| 464 |
+
" weighted avg 0.44 0.44 0.44 1000\n",
|
| 465 |
+
"\n",
|
| 466 |
+
"Accuracy: 0.4440\n"
|
| 467 |
+
]
|
| 468 |
+
}
|
| 469 |
+
],
|
| 470 |
+
"source": [
|
| 471 |
+
"sbert.eval()\n",
|
| 472 |
+
"all_preds = []\n",
|
| 473 |
+
"all_labels = []\n",
|
| 474 |
+
"\n",
|
| 475 |
+
"print(\"Evaluating...\")\n",
|
| 476 |
+
"with torch.no_grad():\n",
|
| 477 |
+
" for batch in test_loader:\n",
|
| 478 |
+
" p_ids = batch['premise_input_ids'].to(device)\n",
|
| 479 |
+
" p_seg = batch['premise_segment_ids'].to(device)\n",
|
| 480 |
+
" h_ids = batch['hypothesis_input_ids'].to(device)\n",
|
| 481 |
+
" h_seg = batch['hypothesis_segment_ids'].to(device)\n",
|
| 482 |
+
" labels = batch['label'].to(device)\n",
|
| 483 |
+
"\n",
|
| 484 |
+
" logits = sbert(p_ids, p_seg, h_ids, h_seg)\n",
|
| 485 |
+
" preds = torch.argmax(logits, dim=1)\n",
|
| 486 |
+
" \n",
|
| 487 |
+
" all_preds.extend(preds.cpu().numpy())\n",
|
| 488 |
+
" all_labels.extend(labels.cpu().numpy())\n",
|
| 489 |
+
"\n",
|
| 490 |
+
"target_names = ['Entailment', 'Neutral', 'Contradiction']\n",
|
| 491 |
+
"print(\"Classification Report:\")\n",
|
| 492 |
+
"print(classification_report(all_labels, all_preds, labels=[0, 1, 2], target_names=target_names))\n",
|
| 493 |
+
"print(f\"Accuracy: {accuracy_score(all_labels, all_preds):.4f}\")"
|
| 494 |
+
]
|
| 495 |
+
}
|
| 496 |
+
],
|
| 497 |
+
"metadata": {
|
| 498 |
+
"kernelspec": {
|
| 499 |
+
"display_name": "Python 3",
|
| 500 |
+
"language": "python",
|
| 501 |
+
"name": "python3"
|
| 502 |
+
},
|
| 503 |
+
"language_info": {
|
| 504 |
+
"codemirror_mode": {
|
| 505 |
+
"name": "ipython",
|
| 506 |
+
"version": 3
|
| 507 |
+
},
|
| 508 |
+
"file_extension": ".py",
|
| 509 |
+
"mimetype": "text/x-python",
|
| 510 |
+
"name": "python",
|
| 511 |
+
"nbconvert_exporter": "python",
|
| 512 |
+
"pygments_lexer": "ipython3",
|
| 513 |
+
"version": "3.8.5"
|
| 514 |
+
}
|
| 515 |
+
},
|
| 516 |
+
"nbformat": 4,
|
| 517 |
+
"nbformat_minor": 4
|
| 518 |
+
}
|
A4_Option_SNLI.ipynb
ADDED
|
@@ -0,0 +1,637 @@
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# A4: S-BERT Training on Alternative Datasets (SNLI)\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"This notebook allows me to train the S-BERT model on the **SNLI** (Stanford Natural Language Inference) dataset.\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"## 1. Environment Setup"
|
| 12 |
+
]
|
| 13 |
+
},
|
| 14 |
+
{
|
| 15 |
+
"cell_type": "code",
|
| 16 |
+
"execution_count": 9,
|
| 17 |
+
"metadata": {},
|
| 18 |
+
"outputs": [
|
| 19 |
+
{
|
| 20 |
+
"name": "stdout",
|
| 21 |
+
"output_type": "stream",
|
| 22 |
+
"text": [
|
| 23 |
+
"Using device: mps\n"
|
| 24 |
+
]
|
| 25 |
+
}
|
| 26 |
+
],
|
| 27 |
+
"source": [
|
| 28 |
+
"import os\n",
|
| 29 |
+
"import torch\n",
|
| 30 |
+
"import torch.nn as nn\n",
|
| 31 |
+
"import torch.optim as optim\n",
|
| 32 |
+
"import numpy as np\n",
|
| 33 |
+
"from datasets import load_dataset\n",
|
| 34 |
+
"from transformers import BertTokenizer\n",
|
| 35 |
+
"from torch.utils.data import DataLoader, Dataset\n",
|
| 36 |
+
"from sklearn.metrics import classification_report, accuracy_score\n",
|
| 37 |
+
"\n",
|
| 38 |
+
"# Device Configuration\n",
|
| 39 |
+
"device = torch.device(\"cuda\" if torch.cuda.is_available() else (\"mps\" if torch.backends.mps.is_available() else \"cpu\"))\n",
|
| 40 |
+
"print(f\"Using device: {device}\")"
|
| 41 |
+
]
|
| 42 |
+
},
|
| 43 |
+
{
|
| 44 |
+
"cell_type": "markdown",
|
| 45 |
+
"metadata": {},
|
| 46 |
+
"source": [
|
| 47 |
+
"## 2. Load Pre-trained BERT\n",
|
| 48 |
+
"\n",
|
| 49 |
+
"I will load the BERT model trained in `A4_BERT.ipynb`. Ensure `models/bert_trained.pt` exists."
|
| 50 |
+
]
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
"cell_type": "code",
|
| 54 |
+
"execution_count": 10,
|
| 55 |
+
"metadata": {},
|
| 56 |
+
"outputs": [
|
| 57 |
+
{
|
| 58 |
+
"name": "stdout",
|
| 59 |
+
"output_type": "stream",
|
| 60 |
+
"text": [
|
| 61 |
+
"Loaded bert_trained.pt\n"
|
| 62 |
+
]
|
| 63 |
+
}
|
| 64 |
+
],
|
| 65 |
+
"source": [
|
| 66 |
+
"# Define BERT Architecture\n",
|
| 67 |
+
"# MUST MATCH THE OPTIMIZED CONFIG FROM A4_BERT.ipynb\n",
|
| 68 |
+
"vocab_size = 5004 # Updated from 30522\n",
|
| 69 |
+
"d_model = 256 # MiniBERT Config\n",
|
| 70 |
+
"n_layers = 2 # Updated from 4\n",
|
| 71 |
+
"n_heads = 4\n",
|
| 72 |
+
"d_ff = 256 * 4\n",
|
| 73 |
+
"max_len = 128\n",
|
| 74 |
+
"n_segments = 2\n",
|
| 75 |
+
"d_k = d_v = 64\n",
|
| 76 |
+
"\n",
|
| 77 |
+
"class Embedding(nn.Module):\n",
|
| 78 |
+
" def __init__(self):\n",
|
| 79 |
+
" super(Embedding, self).__init__()\n",
|
| 80 |
+
" self.tok_embed = nn.Embedding(vocab_size, d_model)\n",
|
| 81 |
+
" self.pos_embed = nn.Embedding(max_len, d_model)\n",
|
| 82 |
+
" self.seg_embed = nn.Embedding(n_segments, d_model)\n",
|
| 83 |
+
" self.norm = nn.LayerNorm(d_model)\n",
|
| 84 |
+
"\n",
|
| 85 |
+
" def forward(self, x, seg):\n",
|
| 86 |
+
" seq_len = x.size(1)\n",
|
| 87 |
+
" pos = torch.arange(seq_len, dtype=torch.long, device=x.device)\n",
|
| 88 |
+
" pos = pos.unsqueeze(0).expand_as(x)\n",
|
| 89 |
+
" embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)\n",
|
| 90 |
+
" return self.norm(embedding)\n",
|
| 91 |
+
"\n",
|
| 92 |
+
"class MultiHeadAttention(nn.Module):\n",
|
| 93 |
+
" def __init__(self):\n",
|
| 94 |
+
" super(MultiHeadAttention, self).__init__()\n",
|
| 95 |
+
" self.W_Q = nn.Linear(d_model, d_k * n_heads)\n",
|
| 96 |
+
" self.W_K = nn.Linear(d_model, d_k * n_heads)\n",
|
| 97 |
+
" self.W_V = nn.Linear(d_model, d_v * n_heads)\n",
|
| 98 |
+
" self.linear = nn.Linear(n_heads * d_v, d_model)\n",
|
| 99 |
+
" self.layer_norm = nn.LayerNorm(d_model)\n",
|
| 100 |
+
"\n",
|
| 101 |
+
" def forward(self, Q, K, V, attn_mask):\n",
|
| 102 |
+
" batch_size = Q.size(0)\n",
|
| 103 |
+
" q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)\n",
|
| 104 |
+
" k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)\n",
|
| 105 |
+
" v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)\n",
|
| 106 |
+
" \n",
|
| 107 |
+
" attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)\n",
|
| 108 |
+
" \n",
|
| 109 |
+
" scores = torch.matmul(q_s, k_s.transpose(-1, -2)) / np.sqrt(d_k)\n",
|
| 110 |
+
" scores.masked_fill_(attn_mask, -1e9)\n",
|
| 111 |
+
" attn = nn.Softmax(dim=-1)(scores)\n",
|
| 112 |
+
" context = torch.matmul(attn, v_s)\n",
|
| 113 |
+
" context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)\n",
|
| 114 |
+
" output = self.linear(context)\n",
|
| 115 |
+
" return self.layer_norm(output + Q), attn\n",
|
| 116 |
+
"\n",
|
| 117 |
+
"class PoswiseFeedForwardNet(nn.Module):\n",
|
| 118 |
+
" def __init__(self):\n",
|
| 119 |
+
" super(PoswiseFeedForwardNet, self).__init__()\n",
|
| 120 |
+
" self.fc1 = nn.Linear(d_model, d_ff)\n",
|
| 121 |
+
" self.fc2 = nn.Linear(d_ff, d_model)\n",
|
| 122 |
+
"\n",
|
| 123 |
+
" def forward(self, x):\n",
|
| 124 |
+
" return self.fc2(torch.nn.functional.gelu(self.fc1(x)))\n",
|
| 125 |
+
"\n",
|
| 126 |
+
"class EncoderLayer(nn.Module):\n",
|
| 127 |
+
" def __init__(self):\n",
|
| 128 |
+
" super(EncoderLayer, self).__init__()\n",
|
| 129 |
+
" self.enc_self_attn = MultiHeadAttention()\n",
|
| 130 |
+
" self.pos_ffn = PoswiseFeedForwardNet()\n",
|
| 131 |
+
"\n",
|
| 132 |
+
" def forward(self, enc_inputs, enc_self_attn_mask):\n",
|
| 133 |
+
" enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)\n",
|
| 134 |
+
" enc_outputs = self.pos_ffn(enc_outputs)\n",
|
| 135 |
+
" return enc_outputs, attn\n",
|
| 136 |
+
"\n",
|
| 137 |
+
"def get_attn_pad_mask(seq_q, seq_k):\n",
|
| 138 |
+
" batch_size, len_q = seq_q.size()\n",
|
| 139 |
+
" batch_size, len_k = seq_k.size()\n",
|
| 140 |
+
" pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)\n",
|
| 141 |
+
" return pad_attn_mask.expand(batch_size, len_q, len_k)\n",
|
| 142 |
+
"\n",
|
| 143 |
+
"class BERT(nn.Module):\n",
|
| 144 |
+
" def __init__(self):\n",
|
| 145 |
+
" super(BERT, self).__init__()\n",
|
| 146 |
+
" self.embedding = Embedding()\n",
|
| 147 |
+
" self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])\n",
|
| 148 |
+
" self.fc = nn.Linear(d_model, d_model)\n",
|
| 149 |
+
" self.activ = nn.Tanh()\n",
|
| 150 |
+
" self.linear = nn.Linear(d_model, d_model)\n",
|
| 151 |
+
" self.norm = nn.LayerNorm(d_model)\n",
|
| 152 |
+
" self.classifier = nn.Linear(d_model, 2)\n",
|
| 153 |
+
" embed_weight = self.embedding.tok_embed.weight\n",
|
| 154 |
+
" n_vocab, n_dim = embed_weight.size()\n",
|
| 155 |
+
" self.decoder = nn.Linear(n_dim, n_vocab, bias=False)\n",
|
| 156 |
+
" self.decoder.weight = embed_weight\n",
|
| 157 |
+
" self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))\n",
|
| 158 |
+
"\n",
|
| 159 |
+
" def forward(self, input_ids, segment_ids, masked_pos=None):\n",
|
| 160 |
+
" output = self.embedding(input_ids, segment_ids)\n",
|
| 161 |
+
" enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids)\n",
|
| 162 |
+
" for layer in self.layers:\n",
|
| 163 |
+
" output, enc_self_attn = layer(output, enc_self_attn_mask)\n",
|
| 164 |
+
" return None, None, output \n",
|
| 165 |
+
"\n",
|
| 166 |
+
"# Load Pretrained Weights\n",
|
| 167 |
+
"bert = BERT().to(device)\n",
|
| 168 |
+
"try:\n",
|
| 169 |
+
" bert.load_state_dict(torch.load('./models/bert_trained.pt', map_location=device))\n",
|
| 170 |
+
" print(\"Loaded bert_trained.pt\")\n",
|
| 171 |
+
"except:\n",
|
| 172 |
+
" print(\"Warning: bert_trained.pt not found. Using random weights.\")\n"
|
| 173 |
+
]
|
| 174 |
+
},
|
| 175 |
+
{
|
| 176 |
+
"cell_type": "markdown",
|
| 177 |
+
"metadata": {},
|
| 178 |
+
"source": [
|
| 179 |
+
"## 3. Load SNLI Dataset\n"
|
| 180 |
+
]
|
| 181 |
+
},
|
| 182 |
+
{
|
| 183 |
+
"cell_type": "code",
|
| 184 |
+
"execution_count": 11,
|
| 185 |
+
"metadata": {},
|
| 186 |
+
"outputs": [
|
| 187 |
+
{
|
| 188 |
+
"name": "stdout",
|
| 189 |
+
"output_type": "stream",
|
| 190 |
+
"text": [
|
| 191 |
+
"Loading snli...\n"
|
| 192 |
+
]
|
| 193 |
+
},
|
| 194 |
+
{
|
| 195 |
+
"name": "stderr",
|
| 196 |
+
"output_type": "stream",
|
| 197 |
+
"text": [
|
| 198 |
+
"Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.\n"
|
| 199 |
+
]
|
| 200 |
+
},
|
| 201 |
+
{
|
| 202 |
+
"name": "stdout",
|
| 203 |
+
"output_type": "stream",
|
| 204 |
+
"text": [
|
| 205 |
+
"Loaded dataset keys: dict_keys(['test', 'validation', 'train'])\n",
|
| 206 |
+
"Train size: 9988, Test size: 988\n"
|
| 207 |
+
]
|
| 208 |
+
}
|
| 209 |
+
],
|
| 210 |
+
"source": [
|
| 211 |
+
"DATASET_NAME = 'snli'\n",
|
| 212 |
+
"print(f\"Loading {DATASET_NAME}...\")\n",
|
| 213 |
+
"dataset = load_dataset(DATASET_NAME)\n",
|
| 214 |
+
"print(f\"Loaded dataset keys: {dataset.keys()}\")\n",
|
| 215 |
+
"\n",
|
| 216 |
+
"train_dataset = dataset['train'].select(range(10000))\n",
|
| 217 |
+
"test_dataset = dataset['test'].select(range(1000))\n",
|
| 218 |
+
"\n",
|
| 219 |
+
"# Filter undefined labels\n",
|
| 220 |
+
"train_dataset = train_dataset.filter(lambda x: x['label'] != -1)\n",
|
| 221 |
+
"test_dataset = test_dataset.filter(lambda x: x['label'] != -1)\n",
|
| 222 |
+
"\n",
|
| 223 |
+
"print(f\"Train size: {len(train_dataset)}, Test size: {len(test_dataset)}\")"
|
| 224 |
+
]
|
| 225 |
+
},
|
| 226 |
+
{
|
| 227 |
+
"cell_type": "code",
|
| 228 |
+
"execution_count": 12,
|
| 229 |
+
"metadata": {},
|
| 230 |
+
"outputs": [],
|
| 231 |
+
"source": [
|
| 232 |
+
"# Data Loader\n",
|
| 233 |
+
"tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
|
| 234 |
+
"\n",
|
| 235 |
+
"class NLIDataset(Dataset):\n",
|
| 236 |
+
" def __init__(self, dataset, tokenizer, max_len=128):\n",
|
| 237 |
+
" self.dataset = dataset\n",
|
| 238 |
+
" self.tokenizer = tokenizer\n",
|
| 239 |
+
" self.max_len = max_len\n",
|
| 240 |
+
"\n",
|
| 241 |
+
" def __len__(self):\n",
|
| 242 |
+
" return len(self.dataset)\n",
|
| 243 |
+
"\n",
|
| 244 |
+
" def __getitem__(self, idx):\n",
|
| 245 |
+
" item = self.dataset[idx]\n",
|
| 246 |
+
" premise = item['premise']\n",
|
| 247 |
+
" hypothesis = item['hypothesis']\n",
|
| 248 |
+
" label = item['label']\n",
|
| 249 |
+
"\n",
|
| 250 |
+
" encoded_premise = self.tokenizer(\n",
|
| 251 |
+
" premise,\n",
|
| 252 |
+
" add_special_tokens=True,\n",
|
| 253 |
+
" max_length=self.max_len,\n",
|
| 254 |
+
" padding='max_length',\n",
|
| 255 |
+
" return_attention_mask=True,\n",
|
| 256 |
+
" truncation=True\n",
|
| 257 |
+
" )\n",
|
| 258 |
+
"\n",
|
| 259 |
+
" encoded_hypothesis = self.tokenizer(\n",
|
| 260 |
+
" hypothesis,\n",
|
| 261 |
+
" add_special_tokens=True,\n",
|
| 262 |
+
" max_length=self.max_len,\n",
|
| 263 |
+
" padding='max_length',\n",
|
| 264 |
+
" return_attention_mask=True,\n",
|
| 265 |
+
" truncation=True\n",
|
| 266 |
+
" )\n",
|
| 267 |
+
"\n",
|
| 268 |
+
" return {\n",
|
| 269 |
+
" 'premise_input_ids': torch.tensor(encoded_premise['input_ids'], dtype=torch.long),\n",
|
| 270 |
+
" 'premise_segment_ids': torch.tensor(encoded_premise['token_type_ids'], dtype=torch.long),\n",
|
| 271 |
+
" 'hypothesis_input_ids': torch.tensor(encoded_hypothesis['input_ids'], dtype=torch.long),\n",
|
| 272 |
+
" 'hypothesis_segment_ids': torch.tensor(encoded_hypothesis['token_type_ids'], dtype=torch.long),\n",
|
| 273 |
+
" 'label': torch.tensor(label, dtype=torch.long)\n",
|
| 274 |
+
" }\n",
|
| 275 |
+
"\n",
|
| 276 |
+
"train_loader = DataLoader(NLIDataset(train_dataset, tokenizer), batch_size=16, shuffle=True)\n",
|
| 277 |
+
"test_loader = DataLoader(NLIDataset(test_dataset, tokenizer), batch_size=16, shuffle=False)"
|
| 278 |
+
]
|
| 279 |
+
},
|
| 280 |
+
{
|
| 281 |
+
"cell_type": "code",
|
| 282 |
+
"execution_count": 13,
|
| 283 |
+
"metadata": {},
|
| 284 |
+
"outputs": [],
|
| 285 |
+
"source": [
|
| 286 |
+
"# S-BERT Model\n",
|
| 287 |
+
"class SBERT(nn.Module):\n",
|
| 288 |
+
" def __init__(self, bert_model):\n",
|
| 289 |
+
" super(SBERT, self).__init__()\n",
|
| 290 |
+
" self.bert = bert_model\n",
|
| 291 |
+
" self.classifier = nn.Linear(d_model * 3, 3)\n",
|
| 292 |
+
"\n",
|
| 293 |
+
" def forward(self, premise_ids, premise_seg, hypothesis_ids, hypothesis_seg):\n",
|
| 294 |
+
" device = premise_ids.device\n",
|
| 295 |
+
" dummy_masked_pos = torch.zeros((premise_ids.size(0), 1), dtype=torch.long).to(device)\n",
|
| 296 |
+
" \n",
|
| 297 |
+
" _, _, output_u = self.bert(premise_ids, premise_seg, dummy_masked_pos)\n",
|
| 298 |
+
" mask_u = (premise_ids != 0).unsqueeze(-1).float()\n",
|
| 299 |
+
" u = torch.sum(output_u * mask_u, dim=1) / torch.clamp(mask_u.sum(dim=1), min=1e-9)\n",
|
| 300 |
+
"\n",
|
| 301 |
+
" _, _, output_v = self.bert(hypothesis_ids, hypothesis_seg, dummy_masked_pos)\n",
|
| 302 |
+
" mask_v = (hypothesis_ids != 0).unsqueeze(-1).float()\n",
|
| 303 |
+
" v = torch.sum(output_v * mask_v, dim=1) / torch.clamp(mask_v.sum(dim=1), min=1e-9)\n",
|
| 304 |
+
"\n",
|
| 305 |
+
" uv_abs = torch.abs(u - v)\n",
|
| 306 |
+
" features = torch.cat([u, v, uv_abs], dim=-1)\n",
|
| 307 |
+
" logits = self.classifier(features)\n",
|
| 308 |
+
" return logits\n",
|
| 309 |
+
"\n",
|
| 310 |
+
"sbert = SBERT(bert).to(device)\n",
|
| 311 |
+
"optimizer = optim.Adam(sbert.parameters(), lr=2e-5)\n",
|
| 312 |
+
"criterion = nn.CrossEntropyLoss()"
|
| 313 |
+
]
|
| 314 |
+
},
|
| 315 |
+
{
|
| 316 |
+
"cell_type": "code",
|
| 317 |
+
"execution_count": 14,
|
| 318 |
+
"metadata": {},
|
| 319 |
+
"outputs": [
|
| 320 |
+
{
|
| 321 |
+
"name": "stdout",
|
| 322 |
+
"output_type": "stream",
|
| 323 |
+
"text": [
|
| 324 |
+
"Starting Training...\n",
|
| 325 |
+
"Epoch 1 Loss: 1.0723\n",
|
| 326 |
+
"Epoch 2 Loss: 1.0245\n",
|
| 327 |
+
"Epoch 3 Loss: 0.9913\n",
|
| 328 |
+
"Epoch 4 Loss: 0.9702\n",
|
| 329 |
+
"Epoch 5 Loss: 0.9487\n",
|
| 330 |
+
"Epoch 6 Loss: 0.9257\n",
|
| 331 |
+
"Epoch 7 Loss: 0.9018\n",
|
| 332 |
+
"Epoch 8 Loss: 0.8794\n",
|
| 333 |
+
"Epoch 9 Loss: 0.8573\n",
|
| 334 |
+
"Epoch 10 Loss: 0.8355\n",
|
| 335 |
+
"Epoch 11 Loss: 0.8129\n",
|
| 336 |
+
"Epoch 12 Loss: 0.7907\n",
|
| 337 |
+
"Epoch 13 Loss: 0.7694\n",
|
| 338 |
+
"Epoch 14 Loss: 0.7470\n",
|
| 339 |
+
"Epoch 15 Loss: 0.7227\n",
|
| 340 |
+
"Epoch 16 Loss: 0.7023\n",
|
| 341 |
+
"Epoch 17 Loss: 0.6780\n",
|
| 342 |
+
"Epoch 18 Loss: 0.6569\n",
|
| 343 |
+
"Epoch 19 Loss: 0.6336\n",
|
| 344 |
+
"Epoch 20 Loss: 0.6084\n",
|
| 345 |
+
"Epoch 21 Loss: 0.5883\n",
|
| 346 |
+
"Epoch 22 Loss: 0.5596\n",
|
| 347 |
+
"Epoch 23 Loss: 0.5356\n",
|
| 348 |
+
"Epoch 24 Loss: 0.5116\n",
|
| 349 |
+
"Epoch 25 Loss: 0.4880\n",
|
| 350 |
+
"Epoch 26 Loss: 0.4623\n",
|
| 351 |
+
"Epoch 27 Loss: 0.4392\n",
|
| 352 |
+
"Epoch 28 Loss: 0.4161\n",
|
| 353 |
+
"Epoch 29 Loss: 0.3903\n",
|
| 354 |
+
"Epoch 30 Loss: 0.3692\n",
|
| 355 |
+
"Epoch 31 Loss: 0.3509\n",
|
| 356 |
+
"Epoch 32 Loss: 0.3258\n",
|
| 357 |
+
"Epoch 33 Loss: 0.3048\n",
|
| 358 |
+
"Epoch 34 Loss: 0.2834\n",
|
| 359 |
+
"Epoch 35 Loss: 0.2664\n",
|
| 360 |
+
"Epoch 36 Loss: 0.2493\n",
|
| 361 |
+
"Epoch 37 Loss: 0.2327\n",
|
| 362 |
+
"Epoch 38 Loss: 0.2145\n",
|
| 363 |
+
"Epoch 39 Loss: 0.2049\n",
|
| 364 |
+
"Epoch 40 Loss: 0.1845\n",
|
| 365 |
+
"Epoch 41 Loss: 0.1687\n",
|
| 366 |
+
"Epoch 42 Loss: 0.1627\n",
|
| 367 |
+
"Epoch 43 Loss: 0.1548\n",
|
| 368 |
+
"Epoch 44 Loss: 0.1367\n",
|
| 369 |
+
"Epoch 45 Loss: 0.1268\n",
|
| 370 |
+
"Epoch 46 Loss: 0.1315\n",
|
| 371 |
+
"Epoch 47 Loss: 0.1230\n",
|
| 372 |
+
"Epoch 48 Loss: 0.1051\n",
|
| 373 |
+
"Epoch 49 Loss: 0.0964\n",
|
| 374 |
+
"Epoch 50 Loss: 0.1027\n",
|
| 375 |
+
"Epoch 51 Loss: 0.0983\n",
|
| 376 |
+
"Epoch 52 Loss: 0.0781\n",
|
| 377 |
+
"Epoch 53 Loss: 0.0795\n",
|
| 378 |
+
"Epoch 54 Loss: 0.0860\n",
|
| 379 |
+
"Epoch 55 Loss: 0.0800\n",
|
| 380 |
+
"Epoch 56 Loss: 0.0620\n",
|
| 381 |
+
"Epoch 57 Loss: 0.0905\n",
|
| 382 |
+
"Epoch 58 Loss: 0.0567\n",
|
| 383 |
+
"Epoch 59 Loss: 0.0568\n",
|
| 384 |
+
"Epoch 60 Loss: 0.0502\n",
|
| 385 |
+
"Epoch 61 Loss: 0.0808\n",
|
| 386 |
+
"Epoch 62 Loss: 0.0622\n",
|
| 387 |
+
"Epoch 63 Loss: 0.0445\n",
|
| 388 |
+
"Epoch 64 Loss: 0.0536\n",
|
| 389 |
+
"Epoch 65 Loss: 0.0564\n",
|
| 390 |
+
"Epoch 66 Loss: 0.0542\n",
|
| 391 |
+
"Epoch 67 Loss: 0.0537\n",
|
| 392 |
+
"Epoch 68 Loss: 0.0419\n",
|
| 393 |
+
"Epoch 69 Loss: 0.0648\n",
|
| 394 |
+
"Epoch 70 Loss: 0.0496\n",
|
| 395 |
+
"Epoch 71 Loss: 0.0510\n",
|
| 396 |
+
"Epoch 72 Loss: 0.0470\n",
|
| 397 |
+
"Epoch 73 Loss: 0.0446\n",
|
| 398 |
+
"Epoch 74 Loss: 0.0359\n",
|
| 399 |
+
"Epoch 75 Loss: 0.0533\n",
|
| 400 |
+
"Epoch 76 Loss: 0.0611\n",
|
| 401 |
+
"Epoch 77 Loss: 0.0368\n",
|
| 402 |
+
"Epoch 78 Loss: 0.0291\n",
|
| 403 |
+
"Epoch 79 Loss: 0.0321\n",
|
| 404 |
+
"Epoch 80 Loss: 0.0757\n",
|
| 405 |
+
"Epoch 81 Loss: 0.0546\n",
|
| 406 |
+
"Epoch 82 Loss: 0.0300\n",
|
| 407 |
+
"Epoch 83 Loss: 0.0279\n",
|
| 408 |
+
"Epoch 84 Loss: 0.0294\n",
|
| 409 |
+
"Epoch 85 Loss: 0.0542\n",
|
| 410 |
+
"Epoch 86 Loss: 0.0422\n",
|
| 411 |
+
"Epoch 87 Loss: 0.0353\n",
|
| 412 |
+
"Epoch 88 Loss: 0.0537\n",
|
| 413 |
+
"Epoch 89 Loss: 0.0300\n",
|
| 414 |
+
"Epoch 90 Loss: 0.0295\n",
|
| 415 |
+
"Epoch 91 Loss: 0.0422\n",
|
| 416 |
+
"Epoch 92 Loss: 0.0403\n",
|
| 417 |
+
"Epoch 93 Loss: 0.0225\n",
|
| 418 |
+
"Epoch 94 Loss: 0.0335\n",
|
| 419 |
+
"Epoch 95 Loss: 0.0457\n",
|
| 420 |
+
"Epoch 96 Loss: 0.0307\n",
|
| 421 |
+
"Epoch 97 Loss: 0.0253\n",
|
| 422 |
+
"Epoch 98 Loss: 0.0543\n",
|
| 423 |
+
"Epoch 99 Loss: 0.0302\n",
|
| 424 |
+
"Epoch 100 Loss: 0.0237\n",
|
| 425 |
+
"Epoch 101 Loss: 0.0344\n",
|
| 426 |
+
"Epoch 102 Loss: 0.0417\n",
|
| 427 |
+
"Epoch 103 Loss: 0.0227\n",
|
| 428 |
+
"Epoch 104 Loss: 0.0267\n",
|
| 429 |
+
"Epoch 105 Loss: 0.0431\n",
|
| 430 |
+
"Epoch 106 Loss: 0.0263\n",
|
| 431 |
+
"Epoch 107 Loss: 0.0442\n",
|
| 432 |
+
"Epoch 108 Loss: 0.0300\n",
|
| 433 |
+
"Epoch 109 Loss: 0.0215\n",
|
| 434 |
+
"Epoch 110 Loss: 0.0262\n",
|
| 435 |
+
"Epoch 111 Loss: 0.0485\n",
|
| 436 |
+
"Epoch 112 Loss: 0.0253\n",
|
| 437 |
+
"Epoch 113 Loss: 0.0202\n",
|
| 438 |
+
"Epoch 114 Loss: 0.0226\n",
|
| 439 |
+
"Epoch 115 Loss: 0.0355\n",
|
| 440 |
+
"Epoch 116 Loss: 0.0534\n",
|
| 441 |
+
"Epoch 117 Loss: 0.0210\n",
|
| 442 |
+
"Epoch 118 Loss: 0.0173\n",
|
| 443 |
+
"Epoch 119 Loss: 0.0315\n",
|
| 444 |
+
"Epoch 120 Loss: 0.0457\n",
|
| 445 |
+
"Epoch 121 Loss: 0.0209\n",
|
| 446 |
+
"Epoch 122 Loss: 0.0226\n",
|
| 447 |
+
"Epoch 123 Loss: 0.0325\n",
|
| 448 |
+
"Epoch 124 Loss: 0.0320\n",
|
| 449 |
+
"Epoch 125 Loss: 0.0269\n",
|
| 450 |
+
"Epoch 126 Loss: 0.0212\n",
|
| 451 |
+
"Epoch 127 Loss: 0.0213\n",
|
| 452 |
+
"Epoch 128 Loss: 0.0313\n",
|
| 453 |
+
"Epoch 129 Loss: 0.0376\n",
|
| 454 |
+
"Epoch 130 Loss: 0.0284\n",
|
| 455 |
+
"Epoch 131 Loss: 0.0177\n",
|
| 456 |
+
"Epoch 132 Loss: 0.0172\n",
|
| 457 |
+
"Epoch 133 Loss: 0.0234\n",
|
| 458 |
+
"Epoch 134 Loss: 0.0442\n",
|
| 459 |
+
"Epoch 135 Loss: 0.0222\n",
|
| 460 |
+
"Epoch 136 Loss: 0.0293\n",
|
| 461 |
+
"Epoch 137 Loss: 0.0258\n",
|
| 462 |
+
"Epoch 138 Loss: 0.0260\n",
|
| 463 |
+
"Epoch 139 Loss: 0.0220\n",
|
| 464 |
+
"Epoch 140 Loss: 0.0167\n",
|
| 465 |
+
"Epoch 141 Loss: 0.0395\n",
|
| 466 |
+
"Epoch 142 Loss: 0.0265\n",
|
| 467 |
+
"Epoch 143 Loss: 0.0179\n",
|
| 468 |
+
"Epoch 144 Loss: 0.0195\n",
|
| 469 |
+
"Epoch 145 Loss: 0.0318\n",
|
| 470 |
+
"Epoch 146 Loss: 0.0224\n",
|
| 471 |
+
"Epoch 147 Loss: 0.0160\n",
|
| 472 |
+
"Epoch 148 Loss: 0.0215\n",
|
| 473 |
+
"Epoch 149 Loss: 0.0491\n",
|
| 474 |
+
"Epoch 150 Loss: 0.0197\n",
|
| 475 |
+
"Epoch 151 Loss: 0.0203\n",
|
| 476 |
+
"Epoch 152 Loss: 0.0238\n",
|
| 477 |
+
"Epoch 153 Loss: 0.0260\n",
|
| 478 |
+
"Epoch 154 Loss: 0.0178\n",
|
| 479 |
+
"Epoch 155 Loss: 0.0156\n",
|
| 480 |
+
"Epoch 156 Loss: 0.0171\n",
|
| 481 |
+
"Epoch 157 Loss: 0.0243\n",
|
| 482 |
+
"Epoch 158 Loss: 0.0403\n",
|
| 483 |
+
"Epoch 159 Loss: 0.0180\n",
|
| 484 |
+
"Epoch 160 Loss: 0.0172\n",
|
| 485 |
+
"Epoch 161 Loss: 0.0198\n",
|
| 486 |
+
"Epoch 162 Loss: 0.0336\n",
|
| 487 |
+
"Epoch 163 Loss: 0.0222\n",
|
| 488 |
+
"Epoch 164 Loss: 0.0155\n",
|
| 489 |
+
"Epoch 165 Loss: 0.0193\n",
|
| 490 |
+
"Epoch 166 Loss: 0.0239\n",
|
| 491 |
+
"Epoch 167 Loss: 0.0183\n",
|
| 492 |
+
"Epoch 168 Loss: 0.0160\n",
|
| 493 |
+
"Epoch 169 Loss: 0.0182\n",
|
| 494 |
+
"Epoch 170 Loss: 0.0389\n",
|
| 495 |
+
"Epoch 171 Loss: 0.0229\n",
|
| 496 |
+
"Epoch 172 Loss: 0.0171\n",
|
| 497 |
+
"Epoch 173 Loss: 0.0162\n",
|
| 498 |
+
"Epoch 174 Loss: 0.0206\n",
|
| 499 |
+
"Epoch 175 Loss: 0.0159\n",
|
| 500 |
+
"Epoch 176 Loss: 0.0158\n",
|
| 501 |
+
"Epoch 177 Loss: 0.0361\n",
|
| 502 |
+
"Epoch 178 Loss: 0.0346\n",
|
| 503 |
+
"Epoch 179 Loss: 0.0183\n",
|
| 504 |
+
"Epoch 180 Loss: 0.0163\n",
|
| 505 |
+
"Epoch 181 Loss: 0.0132\n",
|
| 506 |
+
"Epoch 182 Loss: 0.0162\n",
|
| 507 |
+
"Epoch 183 Loss: 0.0160\n",
|
| 508 |
+
"Epoch 184 Loss: 0.0348\n",
|
| 509 |
+
"Epoch 185 Loss: 0.0271\n",
|
| 510 |
+
"Epoch 186 Loss: 0.0168\n",
|
| 511 |
+
"Epoch 187 Loss: 0.0129\n",
|
| 512 |
+
"Epoch 188 Loss: 0.0138\n",
|
| 513 |
+
"Epoch 189 Loss: 0.0161\n",
|
| 514 |
+
"Epoch 190 Loss: 0.0218\n",
|
| 515 |
+
"Epoch 191 Loss: 0.0254\n",
|
| 516 |
+
"Epoch 192 Loss: 0.0254\n",
|
| 517 |
+
"Epoch 193 Loss: 0.0127\n",
|
| 518 |
+
"Epoch 194 Loss: 0.0126\n",
|
| 519 |
+
"Epoch 195 Loss: 0.0151\n",
|
| 520 |
+
"Epoch 196 Loss: 0.0197\n",
|
| 521 |
+
"Epoch 197 Loss: 0.0283\n",
|
| 522 |
+
"Epoch 198 Loss: 0.0157\n",
|
| 523 |
+
"Epoch 199 Loss: 0.0134\n",
|
| 524 |
+
"Epoch 200 Loss: 0.0165\n",
|
| 525 |
+
"Done!\n"
|
| 526 |
+
]
|
| 527 |
+
}
|
| 528 |
+
],
|
| 529 |
+
"source": [
|
| 530 |
+
"# Train Loop\n",
|
| 531 |
+
"print(\"Starting Training...\")\n",
|
| 532 |
+
"epochs = 200\n",
|
| 533 |
+
"for epoch in range(epochs):\n",
|
| 534 |
+
" sbert.train()\n",
|
| 535 |
+
" total_loss = 0\n",
|
| 536 |
+
" for batch in train_loader:\n",
|
| 537 |
+
" p_ids = batch['premise_input_ids'].to(device)\n",
|
| 538 |
+
" p_seg = batch['premise_segment_ids'].to(device)\n",
|
| 539 |
+
" h_ids = batch['hypothesis_input_ids'].to(device)\n",
|
| 540 |
+
" h_seg = batch['hypothesis_segment_ids'].to(device)\n",
|
| 541 |
+
" labels = batch['label'].to(device)\n",
|
| 542 |
+
"\n",
|
| 543 |
+
" optimizer.zero_grad()\n",
|
| 544 |
+
" logits = sbert(p_ids, p_seg, h_ids, h_seg)\n",
|
| 545 |
+
" loss = criterion(logits, labels)\n",
|
| 546 |
+
" loss.backward()\n",
|
| 547 |
+
" optimizer.step()\n",
|
| 548 |
+
" total_loss += loss.item()\n",
|
| 549 |
+
" print(f\"Epoch {epoch+1} Loss: {total_loss/len(train_loader):.4f}\")\n",
|
| 550 |
+
"\n",
|
| 551 |
+
"print(\"Done!\")\n",
|
| 552 |
+
"torch.save(sbert.state_dict(), f'./models/sbert_{DATASET_NAME}.pt')"
|
| 553 |
+
]
|
| 554 |
+
},
|
| 555 |
+
{
|
| 556 |
+
"cell_type": "markdown",
|
| 557 |
+
"metadata": {},
|
| 558 |
+
"source": [
|
| 559 |
+
"## 4. Evaluation\n",
|
| 560 |
+
"\n",
|
| 561 |
+
"Evaluate on specific test set."
|
| 562 |
+
]
|
| 563 |
+
},
|
| 564 |
+
{
|
| 565 |
+
"cell_type": "code",
|
| 566 |
+
"execution_count": 15,
|
| 567 |
+
"metadata": {},
|
| 568 |
+
"outputs": [
|
| 569 |
+
{
|
| 570 |
+
"name": "stdout",
|
| 571 |
+
"output_type": "stream",
|
| 572 |
+
"text": [
|
| 573 |
+
"Evaluating...\n",
|
| 574 |
+
"Classification Report:\n",
|
| 575 |
+
" precision recall f1-score support\n",
|
| 576 |
+
"\n",
|
| 577 |
+
" Entailment 0.53 0.56 0.54 339\n",
|
| 578 |
+
" Neutral 0.51 0.44 0.47 324\n",
|
| 579 |
+
"Contradiction 0.45 0.49 0.47 325\n",
|
| 580 |
+
"\n",
|
| 581 |
+
" accuracy 0.49 988\n",
|
| 582 |
+
" macro avg 0.50 0.49 0.49 988\n",
|
| 583 |
+
" weighted avg 0.50 0.49 0.49 988\n",
|
| 584 |
+
"\n",
|
| 585 |
+
"Accuracy: 0.4949\n"
|
| 586 |
+
]
|
| 587 |
+
}
|
| 588 |
+
],
|
| 589 |
+
"source": [
|
| 590 |
+
"sbert.eval()\n",
|
| 591 |
+
"all_preds = []\n",
|
| 592 |
+
"all_labels = []\n",
|
| 593 |
+
"\n",
|
| 594 |
+
"print(\"Evaluating...\")\n",
|
| 595 |
+
"with torch.no_grad():\n",
|
| 596 |
+
" for batch in test_loader:\n",
|
| 597 |
+
" p_ids = batch['premise_input_ids'].to(device)\n",
|
| 598 |
+
" p_seg = batch['premise_segment_ids'].to(device)\n",
|
| 599 |
+
" h_ids = batch['hypothesis_input_ids'].to(device)\n",
|
| 600 |
+
" h_seg = batch['hypothesis_segment_ids'].to(device)\n",
|
| 601 |
+
" labels = batch['label'].to(device)\n",
|
| 602 |
+
"\n",
|
| 603 |
+
" logits = sbert(p_ids, p_seg, h_ids, h_seg)\n",
|
| 604 |
+
" preds = torch.argmax(logits, dim=1)\n",
|
| 605 |
+
" \n",
|
| 606 |
+
" all_preds.extend(preds.cpu().numpy())\n",
|
| 607 |
+
" all_labels.extend(labels.cpu().numpy())\n",
|
| 608 |
+
"\n",
|
| 609 |
+
"target_names = ['Entailment', 'Neutral', 'Contradiction']\n",
|
| 610 |
+
"print(\"Classification Report:\")\n",
|
| 611 |
+
"print(classification_report(all_labels, all_preds, labels=[0, 1, 2], target_names=target_names))\n",
|
| 612 |
+
"print(f\"Accuracy: {accuracy_score(all_labels, all_preds):.4f}\")"
|
| 613 |
+
]
|
| 614 |
+
}
|
| 615 |
+
],
|
| 616 |
+
"metadata": {
|
| 617 |
+
"kernelspec": {
|
| 618 |
+
"display_name": "Python 3",
|
| 619 |
+
"language": "python",
|
| 620 |
+
"name": "python3"
|
| 621 |
+
},
|
| 622 |
+
"language_info": {
|
| 623 |
+
"codemirror_mode": {
|
| 624 |
+
"name": "ipython",
|
| 625 |
+
"version": 3
|
| 626 |
+
},
|
| 627 |
+
"file_extension": ".py",
|
| 628 |
+
"mimetype": "text/x-python",
|
| 629 |
+
"name": "python",
|
| 630 |
+
"nbconvert_exporter": "python",
|
| 631 |
+
"pygments_lexer": "ipython3",
|
| 632 |
+
"version": "3.8.5"
|
| 633 |
+
}
|
| 634 |
+
},
|
| 635 |
+
"nbformat": 4,
|
| 636 |
+
"nbformat_minor": 4
|
| 637 |
+
}
|
Dockerfile
ADDED
|
@@ -0,0 +1,11 @@
|
| 1 |
+
FROM python:3.9
|
| 2 |
+
|
| 3 |
+
WORKDIR /code
|
| 4 |
+
|
| 5 |
+
COPY ./requirements.txt /code/requirements.txt
|
| 6 |
+
|
| 7 |
+
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
|
| 8 |
+
|
| 9 |
+
COPY . .
|
| 10 |
+
|
| 11 |
+
CMD ["gunicorn", "-b", "0.0.0.0:7860", "app.app:app"]
|
README.md
CHANGED
|
@@ -1,10 +1,87 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
-
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
| 1 |
---
|
| 2 |
+
title: NLI Text Similarity App
|
| 3 |
+
emoji: 🧠
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
sdk: docker
|
| 7 |
+
app_port: 8000
|
| 8 |
---
|
| 9 |
|
| 10 |
+
# NLI Text Similarity App (A4 Assignment)
|
| 11 |
+
|
| 12 |
+
**Name:** HTUT KO KO
|
| 13 |
+
**ID:** st126010
|
| 14 |
+
|
| 15 |
+
I implemented a Mini-BERT model from scratch and fine-tuned it as a Sentence-BERT (S-BERT) model for Natural Language Inference (NLI) tasks. This project includes a modern web application for real-time similarity analysis.
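
The classification head follows the standard S-BERT recipe: mean-pool each sentence into embeddings `u` and `v`, then classify the concatenation of `u`, `v`, and `|u - v|`. A minimal sketch of that feature construction (the full `SBERT` class lives in the notebooks and in `app/app.py`):

```python
import torch

def sbert_features(u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # u, v: [batch, d_model] mean-pooled sentence embeddings.
    # Returned features: [batch, 3 * d_model], fed to a linear layer with
    # 3 outputs (Entailment / Neutral / Contradiction).
    return torch.cat([u, v, torch.abs(u - v)], dim=-1)
```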
|
| 16 |
+
|
| 17 |
+
## Project Structure
|
| 18 |
+
|
| 19 |
+
- `A4_BERT.ipynb`: Task 1 - I pre-trained BERT from scratch on WikiText-103.
|
| 20 |
+
- `A4_Climate_FEVER.ipynb`: Task 2 - I fine-tuned S-BERT on the Climate-FEVER dataset.
|
| 21 |
+
- `A4_Option_SNLI.ipynb`: Alternative training notebook where I trained on the SNLI dataset.
|
| 22 |
+
- `A4_Option_MNLI.ipynb`: Focused notebook where I trained on the MNLI dataset.
|
| 23 |
+
- `app/`: My Flask web application components.
|
| 24 |
+
- `models/`: Saved model weights (`bert_trained.pt`, `sbert_climate_fever.pt`, `sbert_snli.pt`, `sbert_mnli.pt`).
|
| 25 |
+
|
| 26 |
+
## Final Results
|
| 27 |
+
|
| 28 |
+
I trained the models on three different datasets. Here are the results I achieved:
|
| 29 |
+
|
| 30 |
+
| Dataset | Epochs | Accuracy | Loss |
|
| 31 |
+
| :--- | :--- | :--- | :--- |
|
| 32 |
+
| **Climate-FEVER** | 200 | **50.5%** | 0.0000 |
|
| 33 |
+
| **SNLI** | 200 | **50.8%** | ~0.59 |
|
| 34 |
+
| **MNLI** | 100 | **41.5%** | 0.0000 |
|
| 35 |
+
|
| 36 |
+
### Detailed Climate-FEVER Metrics
|
| 37 |
+
|
| 38 |
+
| Class | Precision | Recall | F1-Score |
|
| 39 |
+
| :--- | :--- | :--- | :--- |
|
| 40 |
+
| Entailment | 0.33 | 0.28 | 0.30 |
|
| 41 |
+
| Neutral | 0.62 | 0.68 | 0.65 |
|
| 42 |
+
| Contradiction | 0.40 | 0.55 | 0.46 |
|
| 43 |
+
|
| 44 |
+
## Limitations & Analysis
|
| 45 |
+
|
| 46 |
+
### 1. Vocabulary Size
|
| 47 |
+
I limited the vocabulary size to **5004** (compared to standard BERT's 30,522) so the model could be trained effectively on the smaller WikiText-103 subset. While this improved convergence for this assignment, it restricts the model's ability to handle rare words outside this vocabulary.
|
| 48 |
+
|
| 49 |
+
### 2. Tokenizer Mismatch
|
| 50 |
+
A challenge I encountered was using the standard `BertTokenizer` with my custom Mini-BERT. The tokenizer produces IDs > 5004, which caused `IndexError` in the web app. I resolved this by implementing a clamping mechanism in `app.py` to map unknown tokens to the `[UNK]` ID.
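
The fix is conceptually simple: any ID at or above the Mini-BERT vocabulary size is replaced before the embedding lookup. A rough sketch of the idea (not the exact helper in `app.py`; `unk_id=100` is the standard `bert-base-uncased` `[UNK]` ID and is an assumption here):

```python
import torch

def clamp_ids(input_ids: torch.Tensor, vocab_size: int = 5004, unk_id: int = 100) -> torch.Tensor:
    # Map every out-of-vocabulary token ID to [UNK] so the 5004-row
    # embedding table is never indexed out of range.
    return torch.where(input_ids < vocab_size,
                       input_ids,
                       torch.full_like(input_ids, unk_id))
```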
|
| 51 |
+
|
| 52 |
+
### 3. Model Depth
|
| 53 |
+
I used a "Mini-BERT" configuration (`n_layers=2`, `d_model=256`) instead of the base (`n_layers=12`, `d_model=768`). This trade-off significantly reduced training time but naturally limits the model's capacity to capture complex linguistic nuances compared to the full BERT-Base.
|
| 54 |
+
|
| 55 |
+
## Demonstration
|
| 56 |
+
|
| 57 |
+

|
| 58 |
+
|
| 59 |
+
## How to Run
|
| 60 |
+
|
| 61 |
+
### 1. Setup Environment
|
| 62 |
+
|
| 63 |
+
```bash
|
| 64 |
+
pip install -r requirements.txt
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
### 2. Run the Web App
|
| 68 |
+
|
| 69 |
+
```bash
|
| 70 |
+
python app/app.py
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
Access the app at `http://127.0.0.1:8000`.
|
| 74 |
+
|
| 75 |
+
## Features
|
| 76 |
+
|
| 77 |
+
- **Modern UI**: I designed a Glassmorphism theme with a dynamic background.
|
| 78 |
+
- **Multi-Model Support**: Users can select among the Climate-FEVER, SNLI, and MNLI trained models.
|
| 79 |
+
- **Explainable AI**: The app displays the probability distribution for each prediction.
|
| 80 |
+
|
| 81 |
+
## References
|
| 82 |
+
|
| 83 |
+
1. **BERT / WikiText-103**: Merity, S., Xiong, C., Bradbury, J., & Socher, R. (2016). *Pointer Sentinel Mixture Models*.
|
| 84 |
+
2. **S-BERT (Sentence-BERT)**: Reimers, N., & Gurevych, I. (2019). *Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks*.
|
| 85 |
+
3. **Climate-FEVER**: Diggelmann, T., Boyd-Graber, J., Bulian, J., Ciaramita, M., & Leippold, M. (2020). *CLIMATE-FEVER: A Dataset for Verification of Real-World Climate Claims*.
|
| 86 |
+
4. **SNLI**: Bowman, S. R., Angeli, G., Potts, C., & Manning, C. D. (2015). *A large annotated corpus for learning natural language inference*.
|
| 87 |
+
5. **MNLI**: Williams, A., Nangia, N., & Bowman, S. R. (2018). *A Broad-Coverage Challenge Corpus for Sentence Understanding through Inference*.
|
app/app.py
ADDED
|
@@ -0,0 +1,316 @@
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from flask import Flask, render_template, request, jsonify
|
| 5 |
+
from transformers import BertTokenizer
|
| 6 |
+
import os
|
| 7 |
+
import math
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
app = Flask(__name__)
|
| 11 |
+
|
| 12 |
+
# --- Configuration ---
|
| 13 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 14 |
+
MODEL_PATH = "../models/sbert_climate_fever.pt" # Path relative to app/ directory execution usually
|
| 15 |
+
# But we will run from project root or handle paths carefully
|
| 16 |
+
# Let's assume running from project root: python app/app.py
|
| 17 |
+
# Then path is models/sbert_climate_fever.pt
|
| 18 |
+
MODEL_PATH_REL = "models/sbert_climate_fever.pt"
|
| 19 |
+
|
| 20 |
+
# --- Model Definitions (Must match training Code) ---
|
| 21 |
+
# Copied from A4_Solution.ipynb
|
| 22 |
+
|
| 23 |
+
n_layers = 2
|
| 24 |
+
n_heads = 4
|
| 25 |
+
d_model = 256
|
| 26 |
+
d_ff = 256 * 4
|
| 27 |
+
d_k = d_v = 64
|
| 28 |
+
n_segments = 2
|
| 29 |
+
max_len = 128
|
| 30 |
+
vocab_size = 5004 # Custom vocab size from training
|
| 31 |
+
|
| 32 |
+
class Embedding(nn.Module):
|
| 33 |
+
def __init__(self):
|
| 34 |
+
super(Embedding, self).__init__()
|
| 35 |
+
self.tok_embed = nn.Embedding(vocab_size, d_model) # token embedding
|
| 36 |
+
self.pos_embed = nn.Embedding(max_len, d_model) # position embedding
|
| 37 |
+
self.seg_embed = nn.Embedding(n_segments, d_model) # segment(token type) embedding
|
| 38 |
+
self.norm = nn.LayerNorm(d_model)
|
| 39 |
+
# Initialize weights to avoid large initial loss
|
| 40 |
+
self.tok_embed.weight.data.normal_(0, 0.1)
|
| 41 |
+
self.pos_embed.weight.data.normal_(0, 0.1)
|
| 42 |
+
self.seg_embed.weight.data.normal_(0, 0.1)
|
| 43 |
+
|
| 44 |
+
def forward(self, x, seg):
|
| 45 |
+
seq_len = x.size(1)
|
| 46 |
+
pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
|
| 47 |
+
pos = pos.unsqueeze(0).expand_as(x) # (len,) -> (bs, len)
|
| 48 |
+
embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
|
| 49 |
+
return self.norm(embedding)
|
| 50 |
+
|
| 51 |
+
def get_attn_pad_mask(seq_q, seq_k):
|
| 52 |
+
batch_size, len_q = seq_q.size()
|
| 53 |
+
batch_size, len_k = seq_k.size()
|
| 54 |
+
# eq(zero) is PAD token
|
| 55 |
+
pad_attn_mask = seq_k.data.eq(0).unsqueeze(1) # batch_size x 1 x len_k(=len_q), one is masking
|
| 56 |
+
return pad_attn_mask.expand(batch_size, len_q, len_k) # batch_size x len_q x len_k
|
| 57 |
+
|
| 58 |
+
class ScaledDotProductAttention(nn.Module):
|
| 59 |
+
def __init__(self):
|
| 60 |
+
super(ScaledDotProductAttention, self).__init__()
|
| 61 |
+
|
| 62 |
+
def forward(self, Q, K, V, attn_mask):
|
| 63 |
+
scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k) # scores : [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
|
| 64 |
+
scores.masked_fill_(attn_mask, -1e9) # Fills elements of self tensor with value where mask is one.
|
| 65 |
+
attn = nn.Softmax(dim=-1)(scores)
|
| 66 |
+
context = torch.matmul(attn, V)
|
| 67 |
+
return context, attn
|
| 68 |
+
|
| 69 |
+
class MultiHeadAttention(nn.Module):
|
| 70 |
+
def __init__(self):
|
| 71 |
+
super(MultiHeadAttention, self).__init__()
|
| 72 |
+
self.W_Q = nn.Linear(d_model, d_k * n_heads)
|
| 73 |
+
self.W_K = nn.Linear(d_model, d_k * n_heads)
|
| 74 |
+
self.W_V = nn.Linear(d_model, d_v * n_heads)
|
| 75 |
+
self.linear = nn.Linear(n_heads * d_v, d_model) # Defined in init
|
| 76 |
+
self.layer_norm = nn.LayerNorm(d_model) # Defined in init
|
| 77 |
+
|
| 78 |
+
def forward(self, Q, K, V, attn_mask):
|
| 79 |
+
# q: [batch_size x len_q x d_model], k: [batch_size x len_k x d_model], v: [batch_size x len_k x d_model]
|
| 80 |
+
residual, batch_size = Q, Q.size(0)
|
| 81 |
+
# (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
|
| 82 |
+
q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1,2) # q_s: [batch_size x n_heads x len_q x d_k]
|
| 83 |
+
k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1,2) # k_s: [batch_size x n_heads x len_k x d_k]
|
| 84 |
+
v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1,2) # v_s: [batch_size x n_heads x len_k x d_v]
|
| 85 |
+
|
| 86 |
+
attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1) # attn_mask : [batch_size x n_heads x len_q x len_k]
|
| 87 |
+
|
| 88 |
+
# context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
|
| 89 |
+
context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)
|
| 90 |
+
context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v) # context: [batch_size x len_q x n_heads * d_v]
|
| 91 |
+
output = self.linear(context)
|
| 92 |
+
return self.layer_norm(output + residual), attn # output: [batch_size x len_q x d_model]
|
| 93 |
+
|
| 94 |
+
class PoswiseFeedForwardNet(nn.Module):
|
| 95 |
+
def __init__(self):
|
| 96 |
+
super(PoswiseFeedForwardNet, self).__init__()
|
| 97 |
+
self.fc1 = nn.Linear(d_model, d_ff)
|
| 98 |
+
self.fc2 = nn.Linear(d_ff, d_model)
|
| 99 |
+
|
| 100 |
+
def forward(self, x):
|
| 101 |
+
# (batch_size, len_seq, d_model) -> (batch_size, len_seq, d_ff) -> (batch_size, len_seq, d_model)
|
| 102 |
+
return self.fc2(F.gelu(self.fc1(x)))
|
| 103 |
+
|
| 104 |
+
class EncoderLayer(nn.Module):
|
| 105 |
+
def __init__(self):
|
| 106 |
+
super(EncoderLayer, self).__init__()
|
| 107 |
+
self.enc_self_attn = MultiHeadAttention()
|
| 108 |
+
self.pos_ffn = PoswiseFeedForwardNet()
|
| 109 |
+
|
| 110 |
+
def forward(self, enc_inputs, enc_self_attn_mask):
|
| 111 |
+
enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_inputs to same Q,K,V
|
| 112 |
+
enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size x len_q x d_model]
|
| 113 |
+
return enc_outputs, attn
|
| 114 |
+
|
| 115 |
+
class BERT(nn.Module):
    def __init__(self):
        super(BERT, self).__init__()
        self.embedding = Embedding()
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
        self.fc = nn.Linear(d_model, d_model)
        self.activ = nn.Tanh()
        self.linear = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)

        self.classifier = nn.Linear(d_model, 2)
        # decoder is shared with the embedding layer (weight tying)
        embed_weight = self.embedding.tok_embed.weight
        n_vocab, n_dim = embed_weight.size()
        self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder.weight = embed_weight
        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))

    def forward(self, input_ids, segment_ids, masked_pos=None):
        # NOTE: masked_pos is optional. S-BERT only needs the encoder 'output';
        # the argument is kept for consistency with the pre-training notebook's
        # forward pass and is only used when provided.

        output = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids)
        for layer in self.layers:
            output, enc_self_attn = layer(output, enc_self_attn_mask)
        # output: [batch_size, len, d_model], attn: [batch_size, n_heads, len, len]

        # 1. next-sentence prediction, decided by the first token (CLS)
        h_pooled = self.activ(self.fc(output[:, 0]))  # [batch_size, d_model]
        logits_nsp = self.classifier(h_pooled)        # [batch_size, 2]

        # 2. masked-token prediction
        if masked_pos is not None:
            masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1))  # [batch_size, max_pred, d_model]
            h_masked = torch.gather(output, 1, masked_pos)  # hidden states at the masked positions [batch_size, max_pred, d_model]
            h_masked = self.norm(F.gelu(self.linear(h_masked)))
            logits_lm = self.decoder(h_masked) + self.decoder_bias  # [batch_size, max_pred, n_vocab]
            return logits_lm, logits_nsp, output
        else:
            return None, logits_nsp, output  # S-BERT inference only needs output

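The torch.gather call in the masked-token branch is the least obvious step; a small standalone illustration with made-up shapes (hypothetical numbers, not tied to the trained model):

    import torch

    output = torch.arange(2 * 4 * 3, dtype=torch.float).view(2, 4, 3)  # [batch=2, len=4, d_model=3]
    masked_pos = torch.tensor([[1, 3], [0, 2]])                        # positions that were [MASK]ed
    idx = masked_pos[:, :, None].expand(-1, -1, output.size(-1))       # [2, max_pred=2, 3]
    h_masked = torch.gather(output, 1, idx)                            # picks output[b, masked_pos[b, j], :]
    print(h_masked.shape)  # torch.Size([2, 2, 3])
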
class SBERT(nn.Module):
    def __init__(self, bert_model):
        super(SBERT, self).__init__()
        self.bert = bert_model
        # 3 * d_model because we concatenate u, v and |u - v|
        self.classifier = nn.Linear(d_model * 3, 3)

    def forward(self, premise_ids, premise_seg, hypothesis_ids, hypothesis_seg):
        # Dummy masked_pos for the BERT forward pass: it is not used for encoding,
        # but we pass a [batch_size, 1] tensor of zeros to keep the call uniform.
        dummy_masked_pos = torch.zeros((premise_ids.size(0), 1), dtype=torch.long).to(premise_ids.device)

        # Encode premise (u) with mean pooling over non-padding tokens
        _, _, output_u = self.bert(premise_ids, premise_seg, dummy_masked_pos)
        mask_u = (premise_ids != 0).unsqueeze(-1).float()  # [batch, len, 1]
        u = torch.sum(output_u * mask_u, dim=1) / torch.clamp(mask_u.sum(dim=1), min=1e-9)

        # Encode hypothesis (v)
        _, _, output_v = self.bert(hypothesis_ids, hypothesis_seg, dummy_masked_pos)
        mask_v = (hypothesis_ids != 0).unsqueeze(-1).float()
        v = torch.sum(output_v * mask_v, dim=1) / torch.clamp(mask_v.sum(dim=1), min=1e-9)

        # Classifier head: concatenate u, v, |u - v|
        uv_abs = torch.abs(u - v)
        features = torch.cat([u, v, uv_abs], dim=-1)

        logits = self.classifier(features)
        return logits, u, v  # u, v are returned for cosine similarity later if needed

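Because u and v are returned alongside the logits, the sentence embeddings can also be compared directly. A minimal sketch, assuming an SBERT instance `sbert_model` and input tensors prepared the same way as in the predict() route below:

    import torch.nn.functional as F

    logits, u, v = sbert_model(p_ids, p_seg, h_ids, h_seg)
    cos_sim = F.cosine_similarity(u, v, dim=-1)  # [batch], values in [-1, 1]
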
# --- Model Management ---
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
models = {}
MODEL_FILES = {
    'climate_fever': 'models/sbert_climate_fever.pt',
    'snli': 'models/sbert_snli.pt',
    'mnli': 'models/sbert_mnli.pt'
}

def get_model(model_name):
    # 'custom' input is scored with the Climate-FEVER model
    if model_name == 'custom':
        model_name = 'climate_fever'

    # Return the cached model if it is already loaded
    if model_name in models:
        return models[model_name]

    # Check that the model name is known
    if model_name not in MODEL_FILES:
        print(f"Warning: Unknown model name '{model_name}'.")
        return None

    rel_path = MODEL_FILES[model_name]
    path = f"../{rel_path}"

    if not os.path.exists(path):
        # Fall back to the local path when running from the app folder
        path = rel_path

    if not os.path.exists(path):
        print(f"Model file not found at {path}")
        return None

    print(f"Loading {model_name} from {path}...")
    try:
        bert = BERT()
        model = SBERT(bert)
        state_dict = torch.load(path, map_location=DEVICE)
        model.load_state_dict(state_dict, strict=False)
        model.to(DEVICE)
        model.eval()
        models[model_name] = model
        return model
    except Exception as e:
        print(f"Failed to load {model_name}: {e}")
        return None

# Pre-load the default model
get_model('climate_fever')

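A quick usage sketch of the lazy loading above: the first call for a name reads the checkpoint from disk, later calls return the cached instance, and a missing file yields None.

    m1 = get_model('snli')   # loads models/sbert_snli.pt on first call (None if the file is missing)
    m2 = get_model('snli')   # served from the `models` cache
    assert m1 is m2 or m1 is None
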
@app.route('/')
def home():
    return render_template('index.html')

@app.route('/predict', methods=['POST'])
def predict():
    data = request.json
    sentence1 = data.get('sentence1', '')
    sentence2 = data.get('sentence2', '')
    model_type = data.get('model_type', 'climate_fever')  # default

    if not sentence1 or not sentence2:
        return jsonify({'error': 'Both sentences are required'}), 400

    # Get the requested model
    model = get_model(model_type)

    if model is None:
        # Fall back to whatever is already loaded, otherwise report an error
        if models:
            model = list(models.values())[0]
            print(f"Warning: Requested {model_type} not found, using fallback.")
        else:
            return jsonify({'error': f'Model {model_type} not trained/found. Please train it first!'}), 404

    # Tokenize
    inputs_a = tokenizer(sentence1, max_length=128, truncation=True, padding='max_length')
    inputs_b = tokenizer(sentence2, max_length=128, truncation=True, padding='max_length')

    p_ids = torch.tensor(inputs_a['input_ids']).unsqueeze(0).to(DEVICE)
    p_seg = torch.tensor(inputs_a['token_type_ids']).unsqueeze(0).to(DEVICE)
    h_ids = torch.tensor(inputs_b['input_ids']).unsqueeze(0).to(DEVICE)
    h_seg = torch.tensor(inputs_b['token_type_ids']).unsqueeze(0).to(DEVICE)

    # Clamp inputs to vocab size (handle OOV ids from the standard tokenizer)
    p_ids[p_ids >= vocab_size] = 1  # [UNK]
    h_ids[h_ids >= vocab_size] = 1  # [UNK]

    with torch.no_grad():
        logits, u, v = model(p_ids, p_seg, h_ids, h_seg)
        probs = F.softmax(logits, dim=1).cpu().numpy()[0]

    # Label order: all three models output class indices 0=Entailment, 1=Neutral, 2=Contradiction.
    # SNLI/MNLI use this order natively. Climate-FEVER's original labels
    # (0=Supports, 1=Refutes, 2=NEI) were remapped in A4_Climate_FEVER.ipynb with
    # label_map = {0: 0, 1: 2, 2: 1}, i.e. Supports -> 0 (Entailment),
    # Refutes -> 2 (Contradiction), NEI -> 1 (Neutral), so the same order applies.
    labels = ['Entailment', 'Neutral', 'Contradiction']

    # Result dict
    result = {label: float(prob) for label, prob in zip(labels, probs)}
    prediction = labels[np.argmax(probs)]

    return jsonify({
        'prediction': prediction,
        'probabilities': result,
        'used_model': model_type
    })

if __name__ == '__main__':
    app.run(debug=True, port=8000)
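With the app running, the endpoint can be exercised from Python. A hedged sketch (assumes the server is running locally on port 8000 as configured above, and that the `requests` package is installed; it is not listed in requirements.txt):

    import requests

    resp = requests.post(
        "http://127.0.0.1:8000/predict",
        json={
            "sentence1": "Global warming is caused by human activities.",
            "sentence2": "The IPCC report confirms that human influence has warmed the atmosphere, ocean and land.",
            "model_type": "climate_fever",
        },
    )
    print(resp.json())  # {'prediction': ..., 'probabilities': {...}, 'used_model': 'climate_fever'}
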
app/static/style.css
ADDED
|
@@ -0,0 +1,234 @@
:root {
    --primary: #6366f1;
    --primary-hover: #4f46e5;
    --bg-color: #0f172a;
    --card-bg: rgba(30, 41, 59, 0.7);
    --text-color: #f8fafc;
    --text-muted: #94a3b8;
    --border-color: rgba(255, 255, 255, 0.1);
}

* {
    box-sizing: border-box;
    margin: 0;
    padding: 0;
}

body {
    font-family: 'Inter', sans-serif;
    background-color: var(--bg-color);
    color: var(--text-color);
    min-height: 100vh;
    display: flex;
    justify-content: center;
    align-items: center;
    overflow-x: hidden;
    position: relative;
}

/* Ambient Background Effect */
.background-orb {
    position: fixed;
    top: -20%;
    left: -10%;
    width: 50vw;
    height: 50vw;
    background: radial-gradient(circle, rgba(99, 102, 241, 0.3) 0%, rgba(15, 23, 42, 0) 70%);
    border-radius: 50%;
    z-index: -1;
    animation: float 10s infinite ease-in-out;
}

@keyframes float {
    0%, 100% { transform: translate(0, 0); }
    50% { transform: translate(20px, 30px); }
}

.container {
    width: 100%;
    max-width: 800px;
    padding: 2rem;
}

header {
    text-align: center;
    margin-bottom: 3rem;
}

header h1 {
    font-size: 2.5rem;
    font-weight: 700;
    background: linear-gradient(135deg, #818cf8, #c084fc);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    margin-bottom: 0.5rem;
}

.subtitle {
    color: var(--text-muted);
    font-size: 1.1rem;
}

.explanation-box {
    margin-top: 1.5rem;
    background-color: rgba(51, 65, 85, 0.5);
    border: 1px solid var(--border-color);
    padding: 1rem;
    border-radius: 0.75rem;
    text-align: left;
    font-size: 0.9rem;
    color: var(--text-muted);
}
.explanation-box h3 {
    color: var(--text-color);
    margin-bottom: 0.5rem;
    font-size: 1rem;
}

main {
    background: var(--card-bg);
    backdrop-filter: blur(12px);
    -webkit-backdrop-filter: blur(12px);
    border: 1px solid var(--border-color);
    border-radius: 1.5rem;
    padding: 2rem;
    box-shadow: 0 25px 50px -12px rgba(0, 0, 0, 0.5);
}

.control-panel {
    margin-bottom: 1.5rem;
}
.control-panel label {
    display: block;
    margin-bottom: 0.5rem;
    color: var(--text-muted);
    font-size: 0.9rem;
}
.control-panel select {
    width: 100%;
    padding: 0.75rem;
    border-radius: 0.5rem;
    border: 1px solid var(--border-color);
    background-color: rgba(15, 23, 42, 0.5);
    color: var(--text-color);
    font-size: 1rem;
    outline: none;
    cursor: pointer;
    transition: border-color 0.2s;
}
.control-panel select:focus {
    border-color: var(--primary);
}

.input-group {
    display: grid;
    gap: 1.5rem;
    margin-bottom: 2rem;
}

.input-card label {
    display: block;
    margin-bottom: 0.5rem;
    font-weight: 600;
    color: var(--text-muted);
}

.input-wrapper input {
    width: 100%;
    padding: 1rem;
    border-radius: 0.75rem;
    border: 1px solid var(--border-color);
    background-color: rgba(15, 23, 42, 0.5);
    color: var(--text-color);
    font-size: 1rem;
    transition: all 0.2s;
}

.input-wrapper input:focus {
    outline: none;
    border-color: var(--primary);
    box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.2);
}

button#analyze-btn {
    width: 100%;
    padding: 1rem;
    border: none;
    border-radius: 0.75rem;
    background-color: var(--primary);
    color: white;
    font-size: 1.1rem;
    font-weight: 600;
    cursor: pointer;
    transition: background-color 0.2s, transform 0.1s;
}

button#analyze-btn:hover {
    background-color: var(--primary-hover);
}
button#analyze-btn:active {
    transform: scale(0.98);
}

.result-card {
    margin-top: 2rem;
    padding-top: 2rem;
    border-top: 1px solid var(--border-color);
    animation: slideUp 0.3s ease-out;
}
.hidden {
    display: none;
}
@keyframes slideUp {
    from { opacity: 0; transform: translateY(10px); }
    to { opacity: 1; transform: translateY(0); }
}

.prediction-header {
    text-align: center;
    font-size: 1.5rem;
    font-weight: 700;
    margin-bottom: 1.5rem;
}

.prob-bar {
    display: flex;
    align-items: center;
    margin-bottom: 0.75rem;
    gap: 1rem;
}

.prob-label {
    width: 100px;
    font-size: 0.9rem;
    text-align: right;
    color: var(--text-muted);
}

.bar-container {
    flex-grow: 1;
    height: 10px;
    background-color: rgba(255, 255, 255, 0.1);
    border-radius: 10px;
    overflow: hidden;
}

.bar {
    height: 100%;
    background-color: var(--primary);
    border-radius: 10px;
    transition: width 0.5s ease-out;
}

.prob-val {
    width: 50px;
    font-size: 0.9rem;
    font-weight: 600;
}

footer {
    text-align: center;
    margin-top: 3rem;
    color: var(--text-muted);
    font-size: 0.9rem;
}
app/templates/index.html
ADDED
|
@@ -0,0 +1,237 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>S-BERT Semantic Similarity Analysis</title>
    <link rel="stylesheet" href="{{ url_for('static', filename='style.css') }}">
    <link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;600&display=swap" rel="stylesheet">
</head>
<body>
    <div class="background-orb"></div>
    <div class="container">
        <header>
            <h1>Semantic Textual Similarity</h1>
            <p class="subtitle">Analyze the relationship between Scientific Claims and Evidence</p>

            <div class="explanation-box">
                <h3>What is this?</h3>
                <p>
                    This tool uses a <strong>Sentence-BERT (S-BERT)</strong> model trained on the <em>Climate-FEVER</em> dataset
                    to determine if a piece of <strong>Evidence</strong> supports, contradicts, or is neutral towards a specific <strong>Claim</strong>.
                    Select a dataset scenario below to populate examples, or type your own sentences to test the model's understanding of semantic logic.
                </p>
            </div>
        </header>

        <main>
            <div class="control-panel">
                <label for="dataset-select">Choose a Dataset / Scenario:</label>
                <select id="dataset-select" onchange="loadScenario()">
                    <option value="climate_fever" selected>Climate-FEVER (Science/Climate Only)</option>
                    <option value="snli">SNLI (General Knowledge / Logic)</option>
                    <option value="mnli">MNLI (Multi-Genre)</option>
                    <option value="custom">Custom Input</option>
                </select>
                <p id="scenario-hint" style="font-size: 0.8rem; color: #94a3b8; margin-top: 0.2rem;">
                    <em>Tip: 'The Earth is flat' is general logic -> Use <strong>SNLI</strong>. Climate-FEVER is for climate-specific claims.</em>
                </p>
            </div>

            <div class="input-group">
                <div class="input-card">
                    <label for="sentence1">Claim / Sentence 1</label>
                    <div class="input-wrapper">
                        <input type="text" id="sentence1" list="claims-list" placeholder="Enter a claim or sentence...">
                        <datalist id="claims-list">
                            <!-- Populated by JS -->
                        </datalist>
                    </div>
                </div>

                <div class="input-card">
                    <label for="sentence2">Evidence / Sentence 2</label>
                    <div class="input-wrapper">
                        <input type="text" id="sentence2" list="evidence-list" placeholder="Enter evidence or sentence...">
                        <datalist id="evidence-list">
                            <!-- Populated by JS -->
                        </datalist>
                    </div>
                </div>
            </div>

            <button id="analyze-btn" onclick="predict()">Analyze Similarity</button>

            <div id="result" class="result-card hidden">
                <div class="prediction-header">
                    <span id="prediction-label">Entailment</span>
                </div>
                <div class="probabilities">
                    <div class="prob-bar">
                        <span class="prob-label">Entailment</span>
                        <div class="bar-container"><div class="bar" id="bar-entailment" style="width: 0%"></div></div>
                        <span class="prob-val" id="val-entailment">0%</span>
                    </div>
                    <div class="prob-bar">
                        <span class="prob-label">Neutral</span>
                        <div class="bar-container"><div class="bar" id="bar-neutral" style="width: 0%"></div></div>
                        <span class="prob-val" id="val-neutral">0%</span>
                    </div>
                    <div class="prob-bar">
                        <span class="prob-label">Contradiction</span>
                        <div class="bar-container"><div class="bar" id="bar-contradiction" style="width: 0%"></div></div>
                        <span class="prob-val" id="val-contradiction">0%</span>
                    </div>
                </div>
            </div>
        </main>

        <footer>
            <p>Developed by <strong>Htut Ko Ko (st126010)</strong> | A4 Assignment</p>
        </footer>
    </div>

    <script>
        const scenarios = {
            'climate_fever': {
                claims: [
                    "Global warming is caused by human activities.",
                    "Sea levels are rising due to melting ice caps.",
                    "The sun is the primary driver of recent climate change."
                ],
                evidence: [
                    "The IPCC report confirms that human influence has warmed the atmosphere, ocean and land.",
                    "Satellite data shows a steady increase in global sea levels over the past century.",
                    "Solar irradiance has remained relatively stable while temperatures have soared."
                ]
            },
            'snli': {
                claims: [
                    "A soccer player is running across the field.",
                    "A person is inspecting the tires of a bicycle.",
                    "Two men are playing basketball."
                ],
                evidence: [
                    "A person is moving fast on a grass surface.",
                    "A mechanic is fixing a car.",
                    "The men are playing a sport."
                ]
            },
            'mnli': {
                claims: [
                    "The government announced a new tax policy.",
                    "He turned and looked at the woman.",
                    "The concert was cancelled due to rain."
                ],
                evidence: [
                    "New financial regulations were introduced by the state.",
                    "He ignored the person standing next to him.",
                    "The outdoor event proceeded despite the bad weather."
                ]
            }
        };

        function loadScenario() {
            const select = document.getElementById('dataset-select');
            const scenarioKey = select.value;
            const claimsList = document.getElementById('claims-list');
            const evidenceList = document.getElementById('evidence-list');
            const s1Input = document.getElementById('sentence1');
            const s2Input = document.getElementById('sentence2');

            // Clear lists
            claimsList.innerHTML = '';
            evidenceList.innerHTML = '';

            if (scenarioKey === 'custom') {
                s1Input.value = '';
                s2Input.value = '';
                return;
            }

            const data = scenarios[scenarioKey];

            // Populate datalists
            data.claims.forEach(item => {
                const opt = document.createElement('option');
                opt.value = item;
                claimsList.appendChild(opt);
            });
            data.evidence.forEach(item => {
                const opt = document.createElement('option');
                opt.value = item;
                evidenceList.appendChild(opt);
            });

            // Auto-fill first example for convenience
            s1Input.value = data.claims[0];
            s2Input.value = data.evidence[0];
        }

        async function predict() {
            const s1 = document.getElementById('sentence1').value;
            const s2 = document.getElementById('sentence2').value;
            const modelType = document.getElementById('dataset-select').value; // Get selected model
            const resultDiv = document.getElementById('result');
            const btn = document.getElementById('analyze-btn');

            if (!s1 || !s2) {
                alert("Please enter both sentences.");
                return;
            }

            btn.textContent = "Analyzing...";
            resultDiv.classList.add('hidden');

            try {
                const response = await fetch('/predict', {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json' },
                    body: JSON.stringify({
                        sentence1: s1,
                        sentence2: s2,
                        model_type: modelType // Send model type
                    })
                });

                const data = await response.json();

                if (data.error) {
                    alert(data.error);
                    return;
                }

                // Update UI
                const label = document.getElementById('prediction-label');
                label.textContent = data.prediction;

                // Color coding
                if (data.prediction === 'Entailment') label.style.color = '#10B981'; // Green
                else if (data.prediction === 'Contradiction') label.style.color = '#EF4444'; // Red
                else label.style.color = '#F59E0B'; // Yellow/Orange

                // Update bars
                document.getElementById('bar-entailment').style.width = (data.probabilities.Entailment * 100) + '%';
                document.getElementById('val-entailment').textContent = (data.probabilities.Entailment * 100).toFixed(1) + '%';

                document.getElementById('bar-neutral').style.width = (data.probabilities.Neutral * 100) + '%';
                document.getElementById('val-neutral').textContent = (data.probabilities.Neutral * 100).toFixed(1) + '%';

                document.getElementById('bar-contradiction').style.width = (data.probabilities.Contradiction * 100) + '%';
                document.getElementById('val-contradiction').textContent = (data.probabilities.Contradiction * 100).toFixed(1) + '%';

                resultDiv.classList.remove('hidden');

            } catch (e) {
                console.error(e);
                alert("Error connecting to the server.");
            } finally {
                btn.textContent = "Analyze Similarity";
            }
        }

        // Initialize on load
        window.onload = loadScenario;
    </script>
</body>
</html>
demo.gif
ADDED
Git LFS Details

models/bert_trained.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4bd87012116d7955bc8b4fddbe591810bc70692ed1df20d3394e58d08bdbc58a
size 12138238

models/sbert_climate_fever.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bfd76cef3670cbd791bc83a0bf7292cb6975ae8b190ffb02e8feb92b350e33bb
size 12148818

models/sbert_mnli.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1a5397be33fe31a270905838e93ddf94a05c2f1e54f09c13c5a453b45f580fc5
size 12148386

models/sbert_snli.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a7c3557c784c924a487f55a6de15d8c127599aca0cfa74a1e6655b9af7b063c2
size 12148386

requirements.txt
ADDED
@@ -0,0 +1,9 @@
torch>=2.0.0
transformers
datasets
scikit-learn
pandas
numpy
flask
gunicorn
tqdm