IsmatS committed on
Commit
a119f50
·
1 Parent(s): 5117cea

building NER model from scratch

Browse files
Files changed (1) hide show
  1. models/NER_from_scratch.ipynb +438 -0
models/NER_from_scratch.ipynb ADDED
@@ -0,0 +1,438 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "machine_shape": "hm",
8
+ "gpuType": "A100"
9
+ },
10
+ "kernelspec": {
11
+ "name": "python3",
12
+ "display_name": "Python 3"
13
+ },
14
+ "language_info": {
15
+ "name": "python"
16
+ },
17
+ "accelerator": "GPU"
18
+ },
19
+ "cells": [
20
+ {
21
+ "cell_type": "code",
22
+ "source": [],
23
+ "metadata": {
24
+ "id": "62TB1_OCUVfz"
25
+ },
26
+ "execution_count": null,
27
+ "outputs": []
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "source": [],
32
+ "metadata": {
33
+ "id": "i0hQIwu8UVc0"
34
+ },
35
+ "execution_count": null,
36
+ "outputs": []
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "source": [
41
+ "# Sample Azerbaijani sentences with entity labels (PERSON, LOCATION, ORGANIZATION)\n",
42
+ "sentences = [\n",
43
+ " [\"İlham\", \"Əliyev\", \"Bakıda\", \"BMT-nin\", \"konfransında\", \"iştirak\", \"etdi\"],\n",
44
+ " [\"Leyla\", \"Gəncə\", \"şəhərində\", \"Azərsun\", \"şirkətində\", \"işləyir\"],\n",
45
+ " [\"Rəşad\", \"Sumqayıt\", \"şəhərinə\", \"səyahət\", \"etdi\"],\n",
46
+ " [\"Nigar\", \"və\", \"Zaur\", \"İstanbulda\", \"Türk Hava Yolları\", \"ofisində\", \"görüşdülər\"],\n",
47
+ " [\"Samir\", \"Bakıda\", \"BP\", \"şirkətinə\", \"işə\", \"daxil\", \"oldu\"]\n",
48
+ "]\n",
49
+ "\n",
50
+ "labels = [\n",
51
+ " [\"B-PERSON\", \"I-PERSON\", \"B-LOCATION\", \"B-ORGANIZATION\", \"O\", \"O\", \"O\"],\n",
52
+ " [\"B-PERSON\", \"B-LOCATION\", \"O\", \"B-ORGANIZATION\", \"O\", \"O\"],\n",
53
+ " [\"B-PERSON\", \"B-LOCATION\", \"O\", \"O\", \"O\"],\n",
54
+ " [\"B-PERSON\", \"O\", \"B-PERSON\", \"B-LOCATION\", \"B-ORGANIZATION\", \"O\", \"O\"],\n",
55
+ " [\"B-PERSON\", \"B-LOCATION\", \"B-ORGANIZATION\", \"O\", \"O\", \"O\", \"O\"]\n",
56
+ "]\n",
57
+ "\n",
58
+ "# Create vocabulary and label mappings\n",
59
+ "all_words = [word for sentence in sentences for word in sentence]\n",
60
+ "unique_words = set(all_words)\n",
61
+ "word_to_idx = {word: idx for idx, word in enumerate(unique_words, 1)}\n",
62
+ "word_to_idx[\"<UNK>\"] = 0 # Unknown token\n",
63
+ "\n",
64
+ "# Map labels to integers\n",
65
+ "label_to_idx = {\"B-PERSON\": 0, \"I-PERSON\": 1, \"B-LOCATION\": 2, \"B-ORGANIZATION\": 3, \"O\": 4}\n",
66
+ "idx_to_label = {idx: label for label, idx in label_to_idx.items()}\n"
67
+ ],
68
+ "metadata": {
69
+ "id": "RoZCdhnaTryk"
70
+ },
71
+ "execution_count": 1,
72
+ "outputs": []
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "source": [
77
+ "from sklearn.model_selection import train_test_split\n",
78
+ "\n",
79
+ "# Split data into training and validation sets (80% train, 20% validation)\n",
80
+ "train_sentences, val_sentences, train_labels, val_labels = train_test_split(\n",
81
+ " sentences, labels, test_size=0.2, random_state=42\n",
82
+ ")\n"
83
+ ],
84
+ "metadata": {
85
+ "id": "WrpBPRFvTrvs"
86
+ },
87
+ "execution_count": 2,
88
+ "outputs": []
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "source": [
93
+ "import torch\n",
94
+ "from torch.utils.data import Dataset, DataLoader\n",
95
+ "from torch.nn.utils.rnn import pad_sequence\n",
96
+ "\n",
97
+ "class NERDataset(Dataset):\n",
98
+ " def __init__(self, sentences, labels, word_to_idx, label_to_idx):\n",
99
+ " self.sentences = sentences\n",
100
+ " self.labels = labels\n",
101
+ " self.word_to_idx = word_to_idx\n",
102
+ " self.label_to_idx = label_to_idx\n",
103
+ "\n",
104
+ " def __len__(self):\n",
105
+ " return len(self.sentences)\n",
106
+ "\n",
107
+ " def __getitem__(self, idx):\n",
108
+ " words = self.sentences[idx]\n",
109
+ " tags = self.labels[idx]\n",
110
+ "\n",
111
+ " word_idxs = [self.word_to_idx.get(word, self.word_to_idx[\"<UNK>\"]) for word in words]\n",
112
+ " tag_idxs = [self.label_to_idx[tag] for tag in tags]\n",
113
+ "\n",
114
+ " return torch.tensor(word_idxs, dtype=torch.long), torch.tensor(tag_idxs, dtype=torch.long)\n",
115
+ "\n",
116
+ "def pad_collate(batch):\n",
117
+ " (sentences, labels) = zip(*batch)\n",
118
+ " sentences_padded = pad_sequence(sentences, batch_first=True, padding_value=word_to_idx[\"<UNK>\"])\n",
119
+ " labels_padded = pad_sequence(labels, batch_first=True, padding_value=-100) # -100 for ignored tokens\n",
120
+ " return sentences_padded, labels_padded\n",
121
+ "\n",
122
+ "# Create DataLoader instances for train and validation\n",
123
+ "train_dataset = NERDataset(train_sentences, train_labels, word_to_idx, label_to_idx)\n",
124
+ "val_dataset = NERDataset(val_sentences, val_labels, word_to_idx, label_to_idx)\n",
125
+ "\n",
126
+ "train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=pad_collate)\n",
127
+ "val_loader = DataLoader(val_dataset, batch_size=1, collate_fn=pad_collate)\n"
128
+ ],
129
+ "metadata": {
130
+ "id": "KFbd0e77gpEh"
131
+ },
132
+ "execution_count": 3,
133
+ "outputs": []
134
+ },
135
+ {
136
+ "cell_type": "code",
137
+ "source": [
138
+ "import torch.nn as nn\n",
139
+ "\n",
140
+ "class BiLSTM_NER(nn.Module):\n",
141
+ " def __init__(self, vocab_size, tagset_size, embedding_dim=64, hidden_dim=128):\n",
142
+ " super(BiLSTM_NER, self).__init__()\n",
143
+ " self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
144
+ " self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, num_layers=1, bidirectional=True, batch_first=True)\n",
145
+ " self.hidden2tag = nn.Linear(hidden_dim, tagset_size)\n",
146
+ "\n",
147
+ " def forward(self, sentence):\n",
148
+ " embeds = self.embedding(sentence)\n",
149
+ " lstm_out, _ = self.lstm(embeds)\n",
150
+ " tag_space = self.hidden2tag(lstm_out)\n",
151
+ " tag_scores = torch.log_softmax(tag_space, dim=2)\n",
152
+ " return tag_scores\n"
153
+ ],
154
+ "metadata": {
155
+ "id": "i096tTXPgpB5"
156
+ },
157
+ "execution_count": 4,
158
+ "outputs": []
159
+ },
160
+ {
161
+ "cell_type": "code",
162
+ "source": [
163
+ "import pandas as pd\n",
164
+ "from sklearn.metrics import classification_report\n",
165
+ "\n",
166
+ "def train_model(model, train_loader, val_loader, num_epochs=10):\n",
167
+ " # Initialize lists to collect metrics for each epoch\n",
168
+ " epoch_list, loss_list, precision_list, recall_list, f1_list = [], [], [], [], []\n",
169
+ "\n",
170
+ " loss_function = nn.CrossEntropyLoss(ignore_index=-100) # Ignore padding label (-100)\n",
171
+ " optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)\n",
172
+ "\n",
173
+ " # Training loop with metric tracking\n",
174
+ " for epoch in range(1, num_epochs + 1):\n",
175
+ " model.train() # Set model to training mode\n",
176
+ " total_loss = 0\n",
177
+ "\n",
178
+ " # Training phase\n",
179
+ " for sentence, tags in train_loader:\n",
180
+ " model.zero_grad()\n",
181
+ " tag_scores = model(sentence)\n",
182
+ "\n",
183
+ " # Reshape to match dimensions required by CrossEntropyLoss\n",
184
+ " tag_scores = tag_scores.view(-1, tag_scores.shape[-1])\n",
185
+ " tags = tags.view(-1)\n",
186
+ "\n",
187
+ " loss = loss_function(tag_scores, tags)\n",
188
+ " loss.backward()\n",
189
+ " optimizer.step()\n",
190
+ " total_loss += loss.item()\n",
191
+ "\n",
192
+ " avg_loss = total_loss / len(train_loader)\n",
193
+ "\n",
194
+ " # Evaluation phase\n",
195
+ " true_labels, predicted_labels = evaluate_model(model, val_loader, idx_to_label)\n",
196
+ " report = classification_report(true_labels, predicted_labels, labels=list(label_to_idx.keys()), zero_division=0, output_dict=True)\n",
197
+ "\n",
198
+ " # Retrieve metrics\n",
199
+ " precision = report['weighted avg']['precision']\n",
200
+ " recall = report['weighted avg']['recall']\n",
201
+ " f1_score = report['weighted avg']['f1-score']\n",
202
+ "\n",
203
+ " # Append metrics to lists\n",
204
+ " epoch_list.append(f\"Epoch {epoch}/{num_epochs}\")\n",
205
+ " loss_list.append(avg_loss)\n",
206
+ " precision_list.append(precision)\n",
207
+ " recall_list.append(recall)\n",
208
+ " f1_list.append(f1_score)\n",
209
+ "\n",
210
+ " # Create a DataFrame with the collected metrics\n",
211
+ " df = pd.DataFrame({\n",
212
+ " \"Epoch\": epoch_list,\n",
213
+ " \"Loss\": loss_list,\n",
214
+ " \"Precision\": precision_list,\n",
215
+ " \"Recall\": recall_list,\n",
216
+ " \"F1-score\": f1_list\n",
217
+ " })\n",
218
+ "\n",
219
+ " # Display the DataFrame\n",
220
+ " print(\"\\nTraining Progress\")\n",
221
+ " print(df.to_string(index=False))\n",
222
+ " return df\n"
223
+ ],
224
+ "metadata": {
225
+ "id": "cB2Qsvv0go-9"
226
+ },
227
+ "execution_count": 5,
228
+ "outputs": []
229
+ },
230
+ {
231
+ "cell_type": "code",
232
+ "source": [
233
+ "def evaluate_model(model, data_loader, idx_to_label):\n",
234
+ " all_predictions = []\n",
235
+ " all_true_labels = []\n",
236
+ "\n",
237
+ " model.eval() # Set the model to evaluation mode\n",
238
+ " with torch.no_grad():\n",
239
+ " for sentences, labels in data_loader:\n",
240
+ " # Make predictions\n",
241
+ " tag_scores = model(sentences)\n",
242
+ " predictions = torch.argmax(tag_scores, dim=2)\n",
243
+ "\n",
244
+ " for pred, true in zip(predictions, labels):\n",
245
+ " pred = pred.cpu().numpy()\n",
246
+ " true = true.cpu().numpy()\n",
247
+ "\n",
248
+ " # Remove padding (-100) for accurate evaluation\n",
249
+ " true = [t for t in true if t != -100]\n",
250
+ " pred = pred[:len(true)]\n",
251
+ "\n",
252
+ " all_predictions.extend([idx_to_label[p] for p in pred])\n",
253
+ " all_true_labels.extend([idx_to_label[t] for t in true])\n",
254
+ "\n",
255
+ " return all_true_labels, all_predictions\n"
256
+ ],
257
+ "metadata": {
258
+ "id": "lh4HFt20go8T"
259
+ },
260
+ "execution_count": 6,
261
+ "outputs": []
262
+ },
263
+ {
264
+ "cell_type": "code",
265
+ "source": [
266
+ "# Initialize model and DataLoader instances\n",
267
+ "vocab_size = len(word_to_idx)\n",
268
+ "tagset_size = len(label_to_idx)\n",
269
+ "model = BiLSTM_NER(vocab_size, tagset_size)\n",
270
+ "\n",
271
+ "# Train the model and display training progress\n",
272
+ "training_progress_df = train_model(model, train_loader, val_loader)\n",
273
+ "\n",
274
+ "# Evaluate on test data\n",
275
+ "true_labels, predicted_labels = evaluate_model(model, val_loader, idx_to_label)\n",
276
+ "print(classification_report(true_labels, predicted_labels, labels=list(label_to_idx.keys()), zero_division=0))\n"
277
+ ],
278
+ "metadata": {
279
+ "colab": {
280
+ "base_uri": "https://localhost:8080/"
281
+ },
282
+ "id": "YH6j-0n7go5a",
283
+ "outputId": "25936497-94b0-4691-86e1-b44162c89005"
284
+ },
285
+ "execution_count": 7,
286
+ "outputs": [
287
+ {
288
+ "output_type": "stream",
289
+ "name": "stdout",
290
+ "text": [
291
+ "\n",
292
+ "Training Progress\n",
293
+ " Epoch Loss Precision Recall F1-score\n",
294
+ " Epoch 1/10 1.616464 0.333333 0.5 0.396825\n",
295
+ " Epoch 2/10 1.577114 0.250000 0.5 0.333333\n",
296
+ " Epoch 3/10 1.519056 0.250000 0.5 0.333333\n",
297
+ " Epoch 4/10 1.438615 0.250000 0.5 0.333333\n",
298
+ " Epoch 5/10 1.365465 0.250000 0.5 0.333333\n",
299
+ " Epoch 6/10 1.290568 0.250000 0.5 0.333333\n",
300
+ " Epoch 7/10 1.226007 0.250000 0.5 0.333333\n",
301
+ " Epoch 8/10 1.162358 0.250000 0.5 0.333333\n",
302
+ " Epoch 9/10 1.107923 0.250000 0.5 0.333333\n",
303
+ "Epoch 10/10 1.051664 0.250000 0.5 0.333333\n",
304
+ " precision recall f1-score support\n",
305
+ "\n",
306
+ " B-PERSON 0.00 0.00 0.00 1\n",
307
+ " I-PERSON 0.00 0.00 0.00 0\n",
308
+ " B-LOCATION 0.00 0.00 0.00 1\n",
309
+ "B-ORGANIZATION 0.00 0.00 0.00 1\n",
310
+ " O 0.50 1.00 0.67 3\n",
311
+ "\n",
312
+ " accuracy 0.50 6\n",
313
+ " macro avg 0.10 0.20 0.13 6\n",
314
+ " weighted avg 0.25 0.50 0.33 6\n",
315
+ "\n"
316
+ ]
317
+ }
318
+ ]
319
+ },
320
+ {
321
+ "cell_type": "code",
322
+ "source": [],
323
+ "metadata": {
324
+ "id": "VVneEo1Ygo2v"
325
+ },
326
+ "execution_count": 7,
327
+ "outputs": []
328
+ },
329
+ {
330
+ "cell_type": "code",
331
+ "source": [],
332
+ "metadata": {
333
+ "id": "9CU9Qp5ugoz-"
334
+ },
335
+ "execution_count": 7,
336
+ "outputs": []
337
+ },
338
+ {
339
+ "cell_type": "code",
340
+ "source": [],
341
+ "metadata": {
342
+ "id": "m9bsMovcgox8"
343
+ },
344
+ "execution_count": 7,
345
+ "outputs": []
346
+ },
347
+ {
348
+ "cell_type": "code",
349
+ "source": [],
350
+ "metadata": {
351
+ "id": "-5-ErtI0gou6"
352
+ },
353
+ "execution_count": 7,
354
+ "outputs": []
355
+ },
356
+ {
357
+ "cell_type": "code",
358
+ "source": [],
359
+ "metadata": {
360
+ "id": "qAZWIPZ9gosZ"
361
+ },
362
+ "execution_count": 7,
363
+ "outputs": []
364
+ },
365
+ {
366
+ "cell_type": "code",
367
+ "source": [],
368
+ "metadata": {
369
+ "id": "rpenB8bDgopn"
370
+ },
371
+ "execution_count": 7,
372
+ "outputs": []
373
+ },
374
+ {
375
+ "cell_type": "code",
376
+ "source": [],
377
+ "metadata": {
378
+ "id": "c4j_rWc9gom9"
379
+ },
380
+ "execution_count": 7,
381
+ "outputs": []
382
+ },
383
+ {
384
+ "cell_type": "code",
385
+ "source": [],
386
+ "metadata": {
387
+ "id": "Mg1R4n2Ygoke"
388
+ },
389
+ "execution_count": 7,
390
+ "outputs": []
391
+ },
392
+ {
393
+ "cell_type": "code",
394
+ "source": [],
395
+ "metadata": {
396
+ "id": "LemxYPend6X1"
397
+ },
398
+ "execution_count": 7,
399
+ "outputs": []
400
+ },
401
+ {
402
+ "cell_type": "code",
403
+ "source": [],
404
+ "metadata": {
405
+ "id": "LZXLa4KWd6U7"
406
+ },
407
+ "execution_count": 7,
408
+ "outputs": []
409
+ },
410
+ {
411
+ "cell_type": "code",
412
+ "source": [],
413
+ "metadata": {
414
+ "id": "pT2qxBR9d6SR"
415
+ },
416
+ "execution_count": 7,
417
+ "outputs": []
418
+ },
419
+ {
420
+ "cell_type": "code",
421
+ "source": [],
422
+ "metadata": {
423
+ "id": "1UvYkxq1d6O5"
424
+ },
425
+ "execution_count": 7,
426
+ "outputs": []
427
+ },
428
+ {
429
+ "cell_type": "code",
430
+ "source": [],
431
+ "metadata": {
432
+ "id": "5BEpFEOiTF-a"
433
+ },
434
+ "execution_count": 7,
435
+ "outputs": []
436
+ }
437
+ ]
438
+ }