BecomeAllan
commited on
Commit
·
8bf76cf
1
Parent(s):
9c0c4aa
update funs
Browse files- .vscode/settings.json +7 -0
- ML_SLRC.py +382 -44
- Util_funs.py +305 -418
.vscode/settings.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"workbench.colorCustomizations": {
|
| 3 |
+
"activityBar.background": "#093518",
|
| 4 |
+
"titleBar.activeBackground": "#0D4A21",
|
| 5 |
+
"titleBar.activeForeground": "#F3FDF6"
|
| 6 |
+
}
|
| 7 |
+
}
|
ML_SLRC.py
CHANGED
|
@@ -1,33 +1,18 @@
|
|
| 1 |
-
|
| 2 |
-
import torch.nn as nn
|
| 3 |
-
import math
|
| 4 |
import torch
|
| 5 |
import numpy as np
|
| 6 |
-
|
| 7 |
-
import time
|
| 8 |
-
import transformers
|
| 9 |
-
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 10 |
-
from sklearn.manifold import TSNE
|
| 11 |
-
from copy import deepcopy, copy
|
| 12 |
-
import seaborn as sns
|
| 13 |
-
import matplotlib.pylab as plt
|
| 14 |
-
from pprint import pprint
|
| 15 |
-
import shutil
|
| 16 |
-
import datetime
|
| 17 |
import re
|
| 18 |
-
import json
|
| 19 |
-
from pathlib import Path
|
| 20 |
-
|
| 21 |
-
import torch
|
| 22 |
-
import torch.nn as nn
|
| 23 |
-
from torch.utils.data import Dataset, DataLoader
|
| 24 |
import unicodedata
|
| 25 |
-
import
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
| 27 |
import torch
|
| 28 |
-
import
|
| 29 |
-
from
|
| 30 |
-
|
| 31 |
|
| 32 |
|
| 33 |
# Pre-trained model
|
|
@@ -117,7 +102,6 @@ class SLR_Classifier(nn.Module):
|
|
| 117 |
|
| 118 |
return [loss, [feature, logit], predict]
|
| 119 |
|
| 120 |
-
|
| 121 |
# Undesirable patterns within texts
|
| 122 |
patterns = {
|
| 123 |
'CONCLUSIONS AND IMPLICATIONS':'',
|
|
@@ -157,27 +141,50 @@ patterns = {
|
|
| 157 |
'</p>':'',
|
| 158 |
'<<ETX>>':'',
|
| 159 |
'+/-':'',
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
}
|
| 161 |
|
| 162 |
patterns = {x.lower():y for x,y in patterns.items()}
|
| 163 |
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
class SLR_DataSet(Dataset):
|
| 167 |
-
def __init__(self, **args):
|
| 168 |
self.tokenizer = args.get('tokenizer')
|
| 169 |
self.data = args.get('data')
|
| 170 |
self.max_seq_length = args.get("max_seq_length", 512)
|
| 171 |
self.INPUT_NAME = args.get("input", 'x')
|
| 172 |
self.LABEL_NAME = args.get("output", 'y')
|
|
|
|
| 173 |
|
| 174 |
# Tokenizing and processing text
|
| 175 |
def encode_text(self, example):
|
| 176 |
comment_text = example[self.INPUT_NAME]
|
| 177 |
-
|
|
|
|
| 178 |
|
| 179 |
try:
|
| 180 |
-
labels = LABEL_MAP[example[self.LABEL_NAME]]
|
| 181 |
except:
|
| 182 |
labels = -1
|
| 183 |
|
|
@@ -200,15 +207,6 @@ class SLR_DataSet(Dataset):
|
|
| 200 |
torch.tensor([torch.tensor(labels).to(int)])
|
| 201 |
))
|
| 202 |
|
| 203 |
-
# Text processing function
|
| 204 |
-
def treat_text(self, text):
|
| 205 |
-
text = unicodedata.normalize("NFKD",str(text))
|
| 206 |
-
text = multiple_replace(patterns,text.lower())
|
| 207 |
-
text = re.sub('(\(.+\))|(\[.+\])|( \d )|(<)|(>)|(- )','', text)
|
| 208 |
-
text = re.sub('( +)',' ', text)
|
| 209 |
-
text = re.sub('(, ,)|(,,)',',', text)
|
| 210 |
-
text = re.sub('(%)|(per cent)',' percent', text)
|
| 211 |
-
return text
|
| 212 |
|
| 213 |
def __len__(self):
|
| 214 |
return len(self.data)
|
|
@@ -221,6 +219,350 @@ class SLR_DataSet(Dataset):
|
|
| 221 |
return temp_data
|
| 222 |
|
| 223 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
|
| 225 |
# Regex multiple replace function
|
| 226 |
def multiple_replace(dict, text):
|
|
@@ -229,8 +571,4 @@ def multiple_replace(dict, text):
|
|
| 229 |
regex = re.compile("(%s)" % "|".join(map(re.escape, dict.keys())))
|
| 230 |
|
| 231 |
# Substitution
|
| 232 |
-
return regex.sub(lambda mo: dict[mo.string[mo.start():mo.end()]], text)
|
| 233 |
-
|
| 234 |
-
# Undesirable patterns within texts
|
| 235 |
-
|
| 236 |
-
|
|
|
|
| 1 |
+
from torch import nn
|
|
|
|
|
|
|
| 2 |
import torch
|
| 3 |
import numpy as np
|
| 4 |
+
from copy import deepcopy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
import re
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
import unicodedata
|
| 7 |
+
from torch.utils.data import Dataset, DataLoader,TensorDataset, RandomSampler
|
| 8 |
+
from sklearn.model_selection import train_test_split
|
| 9 |
+
from torch.optim import Adam
|
| 10 |
+
from copy import deepcopy
|
| 11 |
+
import gc
|
| 12 |
import torch
|
| 13 |
+
import numpy as np
|
| 14 |
+
from torchmetrics import functional as fn
|
| 15 |
+
import random
|
| 16 |
|
| 17 |
|
| 18 |
# Pre-trained model
|
|
|
|
| 102 |
|
| 103 |
return [loss, [feature, logit], predict]
|
| 104 |
|
|
|
|
| 105 |
# Undesirable patterns within texts
|
| 106 |
patterns = {
|
| 107 |
'CONCLUSIONS AND IMPLICATIONS':'',
|
|
|
|
| 141 |
'</p>':'',
|
| 142 |
'<<ETX>>':'',
|
| 143 |
'+/-':'',
|
| 144 |
+
'\(.+\)':'',
|
| 145 |
+
'\[.+\]':'',
|
| 146 |
+
' \d ':'',
|
| 147 |
+
'<':'',
|
| 148 |
+
'>':'',
|
| 149 |
+
'- ':'',
|
| 150 |
+
' +':' ',
|
| 151 |
+
', ,':',',
|
| 152 |
+
',,':',',
|
| 153 |
+
'%':' percent',
|
| 154 |
+
'per cent':' percent'
|
| 155 |
}
|
| 156 |
|
| 157 |
patterns = {x.lower():y for x,y in patterns.items()}
|
| 158 |
|
| 159 |
+
|
| 160 |
+
LABEL_MAP = {'negative': 0,
|
| 161 |
+
'not included':0,
|
| 162 |
+
'0':0,
|
| 163 |
+
0:0,
|
| 164 |
+
'excluded':0,
|
| 165 |
+
'positive': 1,
|
| 166 |
+
'included':1,
|
| 167 |
+
'1':1,
|
| 168 |
+
1:1,
|
| 169 |
+
}
|
| 170 |
|
| 171 |
class SLR_DataSet(Dataset):
|
| 172 |
+
def __init__(self,treat_text =None, **args):
|
| 173 |
self.tokenizer = args.get('tokenizer')
|
| 174 |
self.data = args.get('data')
|
| 175 |
self.max_seq_length = args.get("max_seq_length", 512)
|
| 176 |
self.INPUT_NAME = args.get("input", 'x')
|
| 177 |
self.LABEL_NAME = args.get("output", 'y')
|
| 178 |
+
self.treat_text = treat_text
|
| 179 |
|
| 180 |
# Tokenizing and processing text
|
| 181 |
def encode_text(self, example):
|
| 182 |
comment_text = example[self.INPUT_NAME]
|
| 183 |
+
if self.treat_text:
|
| 184 |
+
comment_text = self.treat_text(comment_text)
|
| 185 |
|
| 186 |
try:
|
| 187 |
+
labels = LABEL_MAP[example[self.LABEL_NAME].lower()]
|
| 188 |
except:
|
| 189 |
labels = -1
|
| 190 |
|
|
|
|
| 207 |
torch.tensor([torch.tensor(labels).to(int)])
|
| 208 |
))
|
| 209 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
|
| 211 |
def __len__(self):
|
| 212 |
return len(self.data)
|
|
|
|
| 219 |
return temp_data
|
| 220 |
|
| 221 |
|
| 222 |
+
class Learner(nn.Module):
|
| 223 |
+
|
| 224 |
+
def __init__(self, **args):
|
| 225 |
+
"""
|
| 226 |
+
:param args:
|
| 227 |
+
"""
|
| 228 |
+
super(Learner, self).__init__()
|
| 229 |
+
|
| 230 |
+
self.inner_print = args.get('inner_print')
|
| 231 |
+
self.inner_batch_size = args.get('inner_batch_size')
|
| 232 |
+
self.outer_update_lr = args.get('outer_update_lr')
|
| 233 |
+
self.inner_update_lr = args.get('inner_update_lr')
|
| 234 |
+
self.inner_update_step = args.get('inner_update_step')
|
| 235 |
+
self.inner_update_step_eval = args.get('inner_update_step_eval')
|
| 236 |
+
self.model = args.get('model')
|
| 237 |
+
self.device = args.get('device')
|
| 238 |
+
|
| 239 |
+
# Outer optimizer
|
| 240 |
+
self.outer_optimizer = Adam(self.model.parameters(), lr=self.outer_update_lr)
|
| 241 |
+
self.model.train()
|
| 242 |
+
|
| 243 |
+
def forward(self, batch_tasks, training = True, valid_train = True):
|
| 244 |
+
"""
|
| 245 |
+
batch = [(support TensorDataset, query TensorDataset),
|
| 246 |
+
(support TensorDataset, query TensorDataset),
|
| 247 |
+
(support TensorDataset, query TensorDataset),
|
| 248 |
+
(support TensorDataset, query TensorDataset)]
|
| 249 |
+
|
| 250 |
+
# support = TensorDataset(all_input_ids, all_attention_mask, all_segment_ids, all_label_ids)
|
| 251 |
+
"""
|
| 252 |
+
task_accs = []
|
| 253 |
+
task_f1 = []
|
| 254 |
+
task_recall = []
|
| 255 |
+
sum_gradients = []
|
| 256 |
+
num_task = len(batch_tasks)
|
| 257 |
+
num_inner_update_step = self.inner_update_step if training else self.inner_update_step_eval
|
| 258 |
+
|
| 259 |
+
# Outer loop tasks
|
| 260 |
+
for task_id, task in enumerate(batch_tasks):
|
| 261 |
+
support = task[0]
|
| 262 |
+
query = task[1]
|
| 263 |
+
name = task[2]
|
| 264 |
+
|
| 265 |
+
# Copying model
|
| 266 |
+
fast_model = deepcopy(self.model)
|
| 267 |
+
fast_model.to(self.device)
|
| 268 |
+
|
| 269 |
+
# Inner trainer optimizer
|
| 270 |
+
inner_optimizer = Adam(fast_model.parameters(), lr=self.inner_update_lr)
|
| 271 |
+
|
| 272 |
+
# Creating training data loaders
|
| 273 |
+
if len(support) % self.inner_batch_size == 1 :
|
| 274 |
+
support_dataloader = DataLoader(support, sampler=RandomSampler(support),
|
| 275 |
+
batch_size=self.inner_batch_size,
|
| 276 |
+
drop_last=True)
|
| 277 |
+
else:
|
| 278 |
+
support_dataloader = DataLoader(support, sampler=RandomSampler(support),
|
| 279 |
+
batch_size=self.inner_batch_size,
|
| 280 |
+
drop_last=False)
|
| 281 |
+
|
| 282 |
+
# steps_per_epoch=len(support) // self.inner_batch_size
|
| 283 |
+
# total_training_steps = steps_per_epoch * 5
|
| 284 |
+
# warmup_steps = total_training_steps // 3
|
| 285 |
+
#
|
| 286 |
+
|
| 287 |
+
# scheduler = get_linear_schedule_with_warmup(
|
| 288 |
+
# inner_optimizer,
|
| 289 |
+
# num_warmup_steps=warmup_steps,
|
| 290 |
+
# num_training_steps=total_training_steps
|
| 291 |
+
# )
|
| 292 |
+
|
| 293 |
+
fast_model.train()
|
| 294 |
+
|
| 295 |
+
# Inner loop training epoch (support set)
|
| 296 |
+
if valid_train:
|
| 297 |
+
print('----Task',task_id,":", name, '----')
|
| 298 |
+
|
| 299 |
+
for i in range(0, num_inner_update_step):
|
| 300 |
+
all_loss = []
|
| 301 |
+
|
| 302 |
+
# Inner loop training batch (support set)
|
| 303 |
+
for inner_step, batch in enumerate(support_dataloader):
|
| 304 |
+
batch = tuple(t.to(self.device) for t in batch)
|
| 305 |
+
input_ids, attention_mask, token_type_ids, label_id = batch
|
| 306 |
+
|
| 307 |
+
# Feed Foward
|
| 308 |
+
loss, _, _ = fast_model(input_ids, attention_mask, token_type_ids=token_type_ids, labels = label_id)
|
| 309 |
+
|
| 310 |
+
# Computing gradients
|
| 311 |
+
loss.backward()
|
| 312 |
+
# torch.nn.utils.clip_grad_norm_(fast_model.parameters(), max_norm=1)
|
| 313 |
+
|
| 314 |
+
# Updating inner training parameters
|
| 315 |
+
inner_optimizer.step()
|
| 316 |
+
inner_optimizer.zero_grad()
|
| 317 |
+
|
| 318 |
+
# Appending losses
|
| 319 |
+
all_loss.append(loss.item())
|
| 320 |
+
|
| 321 |
+
del batch, input_ids, attention_mask, label_id
|
| 322 |
+
torch.cuda.empty_cache()
|
| 323 |
+
|
| 324 |
+
if valid_train:
|
| 325 |
+
if (i+1) % self.inner_print == 0:
|
| 326 |
+
print("Inner Loss: ", np.mean(all_loss))
|
| 327 |
+
|
| 328 |
+
fast_model.to(torch.device('cpu'))
|
| 329 |
+
|
| 330 |
+
# Inner training phase weights
|
| 331 |
+
if training:
|
| 332 |
+
meta_weights = list(self.model.parameters())
|
| 333 |
+
fast_weights = list(fast_model.parameters())
|
| 334 |
+
|
| 335 |
+
# Appending gradients
|
| 336 |
+
gradients = []
|
| 337 |
+
for i, (meta_params, fast_params) in enumerate(zip(meta_weights, fast_weights)):
|
| 338 |
+
gradient = meta_params - fast_params
|
| 339 |
+
if task_id == 0:
|
| 340 |
+
sum_gradients.append(gradient)
|
| 341 |
+
else:
|
| 342 |
+
sum_gradients[i] += gradient
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
# Inner test (query set)
|
| 346 |
+
fast_model.to(self.device)
|
| 347 |
+
fast_model.eval()
|
| 348 |
+
|
| 349 |
+
if valid_train:
|
| 350 |
+
# Inner test (query set)
|
| 351 |
+
fast_model.to(self.device)
|
| 352 |
+
fast_model.eval()
|
| 353 |
+
|
| 354 |
+
with torch.no_grad():
|
| 355 |
+
# Data loader
|
| 356 |
+
query_dataloader = DataLoader(query, sampler=None, batch_size=len(query))
|
| 357 |
+
query_batch = iter(query_dataloader).next()
|
| 358 |
+
query_batch = tuple(t.to(self.device) for t in query_batch)
|
| 359 |
+
q_input_ids, q_attention_mask, q_token_type_ids, q_label_id = query_batch
|
| 360 |
+
|
| 361 |
+
# Feedfoward
|
| 362 |
+
_, _, pre_label_id = fast_model(q_input_ids, q_attention_mask, q_token_type_ids, labels = q_label_id)
|
| 363 |
+
|
| 364 |
+
# Predictions
|
| 365 |
+
pre_label_id = pre_label_id.detach().cpu().squeeze()
|
| 366 |
+
# Labels
|
| 367 |
+
q_label_id = q_label_id.detach().cpu()
|
| 368 |
+
|
| 369 |
+
# Calculating metrics
|
| 370 |
+
acc = fn.accuracy(pre_label_id, q_label_id).item()
|
| 371 |
+
recall = fn.recall(pre_label_id, q_label_id).item(),
|
| 372 |
+
f1 = fn.f1_score(pre_label_id, q_label_id).item()
|
| 373 |
+
|
| 374 |
+
# appending metrics
|
| 375 |
+
task_accs.append(acc)
|
| 376 |
+
task_f1.append(f1)
|
| 377 |
+
task_recall.append(recall)
|
| 378 |
+
|
| 379 |
+
fast_model.to(torch.device('cpu'))
|
| 380 |
+
|
| 381 |
+
del fast_model, inner_optimizer
|
| 382 |
+
torch.cuda.empty_cache()
|
| 383 |
+
|
| 384 |
+
print("\n")
|
| 385 |
+
print("f1:",np.mean(task_f1))
|
| 386 |
+
print("recall:",np.mean(task_recall))
|
| 387 |
+
|
| 388 |
+
# Updating outer training parameters
|
| 389 |
+
if training:
|
| 390 |
+
# Mean of gradients
|
| 391 |
+
for i in range(0,len(sum_gradients)):
|
| 392 |
+
sum_gradients[i] = sum_gradients[i] / float(num_task)
|
| 393 |
+
|
| 394 |
+
# Indexing parameters to model
|
| 395 |
+
for i, params in enumerate(self.model.parameters()):
|
| 396 |
+
params.grad = sum_gradients[i]
|
| 397 |
+
|
| 398 |
+
# Updating parameters
|
| 399 |
+
self.outer_optimizer.step()
|
| 400 |
+
self.outer_optimizer.zero_grad()
|
| 401 |
+
|
| 402 |
+
del sum_gradients
|
| 403 |
+
gc.collect()
|
| 404 |
+
torch.cuda.empty_cache()
|
| 405 |
+
|
| 406 |
+
if valid_train:
|
| 407 |
+
return np.mean(task_accs)
|
| 408 |
+
else:
|
| 409 |
+
return np.array(0)
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
# Creating Meta Tasks
|
| 414 |
+
class MetaTask(Dataset):
|
| 415 |
+
def __init__(self, examples, num_task, k_support, k_query,
|
| 416 |
+
tokenizer, training=True, max_seq_length=512,
|
| 417 |
+
treat_text =None, **args):
|
| 418 |
+
"""
|
| 419 |
+
:param samples: list of samples
|
| 420 |
+
:param num_task: number of training tasks.
|
| 421 |
+
:param k_support: number of classes support samples per task
|
| 422 |
+
:param k_query: number of classes query sample per task
|
| 423 |
+
"""
|
| 424 |
+
self.examples = examples
|
| 425 |
+
|
| 426 |
+
self.num_task = num_task
|
| 427 |
+
self.k_support = k_support
|
| 428 |
+
self.k_query = k_query
|
| 429 |
+
self.tokenizer = tokenizer
|
| 430 |
+
self.max_seq_length = max_seq_length
|
| 431 |
+
self.treat_text = treat_text
|
| 432 |
+
|
| 433 |
+
# Randomly generating tasks
|
| 434 |
+
self.create_batch(self.num_task, training)
|
| 435 |
+
|
| 436 |
+
# Creating batch
|
| 437 |
+
def create_batch(self, num_task, training):
|
| 438 |
+
self.supports = [] # support set
|
| 439 |
+
self.queries = [] # query set
|
| 440 |
+
self.task_names = [] # Name of task
|
| 441 |
+
self.supports_indexs = [] # index of supports
|
| 442 |
+
self.queries_indexs = [] # index of queries
|
| 443 |
+
self.num_task=num_task
|
| 444 |
+
|
| 445 |
+
# Available tasks
|
| 446 |
+
domains = self.examples['domain'].unique()
|
| 447 |
+
|
| 448 |
+
# If not training, create all tasks
|
| 449 |
+
if not(training):
|
| 450 |
+
self.task_names = domains
|
| 451 |
+
num_task = len(self.task_names)
|
| 452 |
+
self.num_task=num_task
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
for b in range(num_task): # For each task,
|
| 456 |
+
total_per_class = self.k_support + self.k_query
|
| 457 |
+
task_size = 2*self.k_support + 2*self.k_query
|
| 458 |
+
|
| 459 |
+
# Select a task at random
|
| 460 |
+
if training:
|
| 461 |
+
domain = random.choice(domains)
|
| 462 |
+
self.task_names.append(domain)
|
| 463 |
+
else:
|
| 464 |
+
domain = self.task_names[b]
|
| 465 |
+
|
| 466 |
+
# Task data
|
| 467 |
+
domainExamples = self.examples[self.examples['domain'] == domain]
|
| 468 |
+
|
| 469 |
+
# Minimal label quantity
|
| 470 |
+
min_per_class = min(domainExamples['label'].value_counts())
|
| 471 |
+
|
| 472 |
+
if total_per_class > min_per_class:
|
| 473 |
+
total_per_class = min_per_class
|
| 474 |
+
|
| 475 |
+
# Select k_support + k_query task examples
|
| 476 |
+
# Sample (n) from each label(class)
|
| 477 |
+
selected_examples = domainExamples.groupby("label").sample(total_per_class, replace = False)
|
| 478 |
+
|
| 479 |
+
# Split data into support (training) and query (testing) sets
|
| 480 |
+
s, q = train_test_split(selected_examples,
|
| 481 |
+
stratify= selected_examples["label"],
|
| 482 |
+
test_size= 2*self.k_query/task_size,
|
| 483 |
+
shuffle=True)
|
| 484 |
+
|
| 485 |
+
# Permutating data
|
| 486 |
+
s = s.sample(frac=1)
|
| 487 |
+
q = q.sample(frac=1)
|
| 488 |
+
|
| 489 |
+
# Appending indexes
|
| 490 |
+
if not(training):
|
| 491 |
+
self.supports_indexs.append(s.index)
|
| 492 |
+
self.queries_indexs.append(q.index)
|
| 493 |
+
|
| 494 |
+
# Creating list of support (training) and query (testing) tasks
|
| 495 |
+
self.supports.append(s.to_dict('records'))
|
| 496 |
+
self.queries.append(q.to_dict('records'))
|
| 497 |
+
|
| 498 |
+
# Creating task tensors
|
| 499 |
+
def create_feature_set(self, examples):
|
| 500 |
+
all_input_ids = torch.empty(len(examples), self.max_seq_length, dtype = torch.long)
|
| 501 |
+
all_attention_mask = torch.empty(len(examples), self.max_seq_length, dtype = torch.long)
|
| 502 |
+
all_token_type_ids = torch.empty(len(examples), self.max_seq_length, dtype = torch.long)
|
| 503 |
+
all_label_ids = torch.empty(len(examples), dtype = torch.long)
|
| 504 |
+
|
| 505 |
+
for _id, e in enumerate(examples):
|
| 506 |
+
all_input_ids[_id], all_attention_mask[_id], all_token_type_ids[_id], all_label_ids[_id] = self.encode_text(e)
|
| 507 |
+
|
| 508 |
+
return TensorDataset(
|
| 509 |
+
all_input_ids,
|
| 510 |
+
all_attention_mask,
|
| 511 |
+
all_token_type_ids,
|
| 512 |
+
all_label_ids
|
| 513 |
+
)
|
| 514 |
+
|
| 515 |
+
# Data encoding
|
| 516 |
+
def encode_text(self, example):
|
| 517 |
+
comment_text = example["text"]
|
| 518 |
+
|
| 519 |
+
if self.treat_text:
|
| 520 |
+
comment_text = self.treat_text(comment_text)
|
| 521 |
+
|
| 522 |
+
labels = LABEL_MAP[example["label"]]
|
| 523 |
+
|
| 524 |
+
encoding = self.tokenizer.encode_plus(
|
| 525 |
+
(comment_text, "It is a great text."),
|
| 526 |
+
add_special_tokens=True,
|
| 527 |
+
max_length=self.max_seq_length,
|
| 528 |
+
return_token_type_ids=True,
|
| 529 |
+
padding="max_length",
|
| 530 |
+
truncation=True,
|
| 531 |
+
return_attention_mask=True,
|
| 532 |
+
return_tensors='pt',
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
return tuple((
|
| 536 |
+
encoding["input_ids"].flatten(),
|
| 537 |
+
encoding["attention_mask"].flatten(),
|
| 538 |
+
encoding["token_type_ids"].flatten(),
|
| 539 |
+
torch.tensor([torch.tensor(labels).to(int)])
|
| 540 |
+
))
|
| 541 |
+
|
| 542 |
+
# Returns data upon calling
|
| 543 |
+
def __getitem__(self, index):
|
| 544 |
+
support_set = self.create_feature_set(self.supports[index])
|
| 545 |
+
query_set = self.create_feature_set(self.queries[index])
|
| 546 |
+
name = self.task_names[index]
|
| 547 |
+
return support_set, query_set, name
|
| 548 |
+
|
| 549 |
+
def __len__(self):
|
| 550 |
+
return self.num_task
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
class treat_text:
|
| 554 |
+
def __init__(self, patterns):
|
| 555 |
+
self.patterns = patterns
|
| 556 |
+
|
| 557 |
+
def __call__(self,text):
|
| 558 |
+
text = unicodedata.normalize("NFKD",str(text))
|
| 559 |
+
text = multiple_replace(self.patterns,text.lower())
|
| 560 |
+
text = re.sub('(\(.+\))|(\[.+\])|( \d )|(<)|(>)|(- )','', text)
|
| 561 |
+
text = re.sub('( +)',' ', text)
|
| 562 |
+
text = re.sub('(, ,)|(,,)',',', text)
|
| 563 |
+
text = re.sub('(%)|(per cent)',' percent', text)
|
| 564 |
+
return text
|
| 565 |
+
|
| 566 |
|
| 567 |
# Regex multiple replace function
|
| 568 |
def multiple_replace(dict, text):
|
|
|
|
| 571 |
regex = re.compile("(%s)" % "|".join(map(re.escape, dict.keys())))
|
| 572 |
|
| 573 |
# Substitution
|
| 574 |
+
return regex.sub(lambda mo: dict[mo.string[mo.start():mo.end()]], text)
|
|
|
|
|
|
|
|
|
|
|
|
Util_funs.py
CHANGED
|
@@ -1,49 +1,49 @@
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
-
import torch
|
| 3 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
import random
|
| 5 |
-
import json, pickle
|
| 6 |
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
import torch
|
| 11 |
-
|
| 12 |
-
import pandas as pd
|
| 13 |
import time
|
| 14 |
-
import transformers
|
| 15 |
-
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 16 |
from sklearn.manifold import TSNE
|
| 17 |
-
from copy import deepcopy
|
| 18 |
import seaborn as sns
|
| 19 |
import matplotlib.pylab as plt
|
| 20 |
-
from pprint import pprint
|
| 21 |
-
import shutil
|
| 22 |
-
import datetime
|
| 23 |
-
import re
|
| 24 |
import json
|
| 25 |
from pathlib import Path
|
| 26 |
-
|
| 27 |
-
import
|
| 28 |
-
from
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
from transformers import BertForSequenceClassification
|
| 35 |
-
from copy import deepcopy
|
| 36 |
-
import gc
|
| 37 |
-
from sklearn.metrics import accuracy_score
|
| 38 |
-
import torch
|
| 39 |
-
import numpy as np
|
| 40 |
-
import torchmetrics
|
| 41 |
-
from torchmetrics import functional as fn
|
| 42 |
|
| 43 |
|
| 44 |
-
SEED = 2222
|
| 45 |
|
| 46 |
-
gen_seed = torch.Generator().manual_seed(SEED)
|
| 47 |
|
| 48 |
|
| 49 |
# Random seed function
|
|
@@ -54,7 +54,7 @@ def random_seed(value):
|
|
| 54 |
np.random.seed(value)
|
| 55 |
random.seed(value)
|
| 56 |
|
| 57 |
-
#
|
| 58 |
def create_batch_of_tasks(taskset, is_shuffle = True, batch_size = 4):
|
| 59 |
idxs = list(range(0,len(taskset)))
|
| 60 |
if is_shuffle:
|
|
@@ -63,48 +63,51 @@ def create_batch_of_tasks(taskset, is_shuffle = True, batch_size = 4):
|
|
| 63 |
yield [taskset[idxs[i]] for i in range(i, min(i + batch_size,len(taskset)))]
|
| 64 |
|
| 65 |
|
| 66 |
-
|
| 67 |
-
def prepare_data(data, batch_size,tokenizer,max_seq_length,
|
| 68 |
input = 'text', output = 'label',
|
| 69 |
-
train_size_per_class = 5
|
|
|
|
| 70 |
data = data.reset_index().drop("index", axis=1)
|
| 71 |
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
data_train = labaled_data.groupby('label').sample(train_size_per_class)
|
| 75 |
|
| 76 |
-
|
| 77 |
-
|
|
|
|
| 78 |
|
| 79 |
-
|
|
|
|
| 80 |
|
| 81 |
|
| 82 |
-
#
|
| 83 |
-
##
|
| 84 |
dataset_train = SLR_DataSet(
|
| 85 |
data = data_train.sample(frac=1),
|
| 86 |
input = input,
|
| 87 |
output = output,
|
| 88 |
tokenizer=tokenizer,
|
| 89 |
-
max_seq_length =max_seq_length
|
|
|
|
| 90 |
|
| 91 |
-
|
| 92 |
-
# Dataloaders
|
| 93 |
-
## Transforma em dataset
|
| 94 |
dataset_test = SLR_DataSet(
|
| 95 |
data = data_test,
|
| 96 |
input = input,
|
| 97 |
output = output,
|
| 98 |
tokenizer=tokenizer,
|
| 99 |
-
max_seq_length =max_seq_length
|
|
|
|
| 100 |
|
| 101 |
# Dataloaders
|
| 102 |
-
##
|
| 103 |
data_train_loader = DataLoader(dataset_train,
|
| 104 |
shuffle=True,
|
| 105 |
batch_size=batch_size['train']
|
| 106 |
)
|
| 107 |
|
|
|
|
| 108 |
if len(dataset_test) % batch_size['test'] == 1 :
|
| 109 |
data_test_loader = DataLoader(dataset_test,
|
| 110 |
batch_size=batch_size['test'],
|
|
@@ -117,50 +120,54 @@ def prepare_data(data, batch_size,tokenizer,max_seq_length,
|
|
| 117 |
return data_train_loader, data_test_loader, data_train, data_test
|
| 118 |
|
| 119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
from tqdm import tqdm
|
| 124 |
-
|
| 125 |
-
def meta_train(data, model, device, Info, print_epoch =True, size_layer=0, Test_resource =None):
|
| 126 |
-
|
| 127 |
learner = Learner(model = model, device = device, **Info)
|
| 128 |
|
| 129 |
# Testing tasks
|
| 130 |
if isinstance(Test_resource, pd.DataFrame):
|
| 131 |
test = MetaTask(Test_resource, num_task = 0, k_support=10, k_query=10,
|
| 132 |
-
training=False, **Info)
|
| 133 |
|
| 134 |
|
| 135 |
torch.clear_autocast_cache()
|
| 136 |
gc.collect()
|
| 137 |
torch.cuda.empty_cache()
|
| 138 |
|
| 139 |
-
# Meta
|
| 140 |
for epoch in tqdm(range(Info['meta_epoch']), desc= "Meta epoch ", ncols=80):
|
| 141 |
-
# print("Meta Epoca:", epoch)
|
| 142 |
|
| 143 |
-
#
|
| 144 |
train = MetaTask(data,
|
| 145 |
num_task = Info['num_task_train'],
|
| 146 |
k_support=Info['k_qry'],
|
| 147 |
-
k_query=Info['k_spt'],
|
|
|
|
| 148 |
|
| 149 |
-
#
|
| 150 |
db = create_batch_of_tasks(train, is_shuffle = True, batch_size = Info["outer_batch_size"])
|
| 151 |
|
| 152 |
if print_epoch:
|
| 153 |
# Outer loop bach training
|
| 154 |
for step, task_batch in enumerate(db):
|
| 155 |
print("\n-----------------Training Mode","Meta_epoch:", epoch ,"-----------------\n")
|
| 156 |
-
|
|
|
|
| 157 |
acc = learner(task_batch, valid_train= print_epoch)
|
| 158 |
print('Step:', step, '\ttraining Acc:', acc)
|
|
|
|
| 159 |
if isinstance(Test_resource, pd.DataFrame):
|
| 160 |
-
# Validating Model
|
| 161 |
if ((epoch+1) % 4) + step == 0:
|
| 162 |
random_seed(123)
|
| 163 |
print("\n-----------------Testing Mode-----------------\n")
|
|
|
|
|
|
|
| 164 |
db_test = create_batch_of_tasks(test, is_shuffle = False, batch_size = 1)
|
| 165 |
acc_all_test = []
|
| 166 |
|
|
@@ -174,10 +181,10 @@ def meta_train(data, model, device, Info, print_epoch =True, size_layer=0, Test_
|
|
| 174 |
|
| 175 |
# Restarting training randomly
|
| 176 |
random_seed(int(time.time() % 10))
|
| 177 |
-
|
| 178 |
-
|
| 179 |
else:
|
| 180 |
for step, task_batch in enumerate(db):
|
|
|
|
| 181 |
acc = learner(task_batch, print_epoch, valid_train= print_epoch)
|
| 182 |
|
| 183 |
torch.clear_autocast_cache()
|
|
@@ -187,14 +194,14 @@ def meta_train(data, model, device, Info, print_epoch =True, size_layer=0, Test_
|
|
| 187 |
|
| 188 |
|
| 189 |
def train_loop(data_train_loader, data_test_loader, model, device, epoch = 4, lr = 1, print_info = True, name = 'name'):
|
| 190 |
-
#
|
| 191 |
model_meta = deepcopy(model)
|
| 192 |
optimizer = Adam(model_meta.parameters(), lr=lr)
|
| 193 |
|
| 194 |
model_meta.to(device)
|
| 195 |
model_meta.train()
|
| 196 |
|
| 197 |
-
#
|
| 198 |
for i in range(0, epoch):
|
| 199 |
all_loss = []
|
| 200 |
|
|
@@ -203,13 +210,13 @@ def train_loop(data_train_loader, data_test_loader, model, device, epoch = 4, lr
|
|
| 203 |
batch = tuple(t.to(device) for t in batch)
|
| 204 |
input_ids, attention_mask,q_token_type_ids, label_id = batch
|
| 205 |
|
| 206 |
-
# Feedfoward
|
| 207 |
loss, _, _ = model_meta(input_ids, attention_mask,q_token_type_ids, labels = label_id.squeeze())
|
| 208 |
|
| 209 |
-
#
|
| 210 |
loss.backward()
|
| 211 |
|
| 212 |
-
#
|
| 213 |
optimizer.step()
|
| 214 |
optimizer.zero_grad()
|
| 215 |
|
|
@@ -220,39 +227,43 @@ def train_loop(data_train_loader, data_test_loader, model, device, epoch = 4, lr
|
|
| 220 |
print("Loss: ", np.mean(all_loss))
|
| 221 |
|
| 222 |
|
| 223 |
-
#
|
| 224 |
model_meta.eval()
|
| 225 |
all_loss = []
|
| 226 |
-
|
| 227 |
features = []
|
| 228 |
labels = []
|
| 229 |
predi_logit = []
|
| 230 |
|
| 231 |
with torch.no_grad():
|
|
|
|
| 232 |
for inner_step, batch in enumerate(tqdm(data_test_loader,
|
| 233 |
desc="Test validation | " + name,
|
| 234 |
ncols=80)) :
|
| 235 |
batch = tuple(t.to(device) for t in batch)
|
| 236 |
input_ids, attention_mask,q_token_type_ids, label_id = batch
|
| 237 |
|
| 238 |
-
#
|
| 239 |
_, feature, prediction = model_meta(input_ids, attention_mask,q_token_type_ids, labels = label_id.squeeze())
|
| 240 |
|
|
|
|
| 241 |
prediction = prediction.detach().cpu().squeeze()
|
| 242 |
label_id = label_id.detach().cpu()
|
|
|
|
|
|
|
| 243 |
logit = feature[1].detach().cpu()
|
| 244 |
-
|
| 245 |
|
| 246 |
-
|
| 247 |
features.append(feature_lat.numpy())
|
| 248 |
-
predi_logit.append(logit.numpy())
|
| 249 |
|
| 250 |
-
#
|
| 251 |
-
|
|
|
|
| 252 |
del input_ids, attention_mask, label_id, batch
|
| 253 |
|
| 254 |
-
|
| 255 |
-
|
| 256 |
|
| 257 |
model_meta.to('cpu')
|
| 258 |
gc.collect()
|
|
@@ -260,26 +271,32 @@ def train_loop(data_train_loader, data_test_loader, model, device, epoch = 4, lr
|
|
| 260 |
|
| 261 |
del model_meta, optimizer
|
| 262 |
|
|
|
|
| 263 |
|
|
|
|
|
|
|
|
|
|
| 264 |
features = np.concatenate(np.array(features,dtype=object))
|
| 265 |
-
labels = np.concatenate(np.array(labels,dtype=object))
|
| 266 |
-
logits = np.concatenate(np.array(predi_logit,dtype=object))
|
| 267 |
-
|
| 268 |
features = torch.tensor(features.astype(np.float32)).detach().clone()
|
|
|
|
|
|
|
| 269 |
labels = torch.tensor(labels.astype(int)).detach().clone()
|
|
|
|
|
|
|
| 270 |
logits = torch.tensor(logits.astype(np.float32)).detach().clone()
|
| 271 |
|
| 272 |
-
#
|
| 273 |
X_embedded = TSNE(n_components=2, learning_rate='auto',
|
| 274 |
init='random').fit_transform(features.detach().clone())
|
| 275 |
|
| 276 |
return logits.detach().clone(), X_embedded, labels.detach().clone(), features.detach().clone()
|
| 277 |
-
|
| 278 |
-
|
| 279 |
def wss_calc(logit, labels, trsh = 0.5):
|
| 280 |
|
| 281 |
-
#
|
| 282 |
predict_trash = torch.sigmoid(logit).squeeze() >= trsh
|
|
|
|
|
|
|
| 283 |
CM = confusion_matrix(labels, predict_trash.to(int) )
|
| 284 |
tn, fp, fne, tp = CM.ravel()
|
| 285 |
|
|
@@ -287,36 +304,22 @@ def wss_calc(logit, labels, trsh = 0.5):
|
|
| 287 |
N = (tn + fp)
|
| 288 |
recall = tp/(tp+fne)
|
| 289 |
|
| 290 |
-
#
|
| 291 |
-
|
| 292 |
|
| 293 |
-
#
|
| 294 |
-
|
| 295 |
|
| 296 |
return {
|
| 297 |
-
"wss": round(
|
| 298 |
-
"awss": round(
|
| 299 |
"R": round(recall,4),
|
| 300 |
"CM": CM
|
| 301 |
}
|
| 302 |
|
| 303 |
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
from sklearn.metrics import confusion_matrix
|
| 307 |
-
from torchmetrics import functional as fn
|
| 308 |
-
import matplotlib.pyplot as plt
|
| 309 |
-
from sklearn.metrics import roc_curve, auc
|
| 310 |
-
from sklearn.metrics import roc_auc_score
|
| 311 |
-
import ipywidgets as widgets
|
| 312 |
-
from IPython.display import HTML, display, clear_output
|
| 313 |
-
import matplotlib.pyplot as plt
|
| 314 |
-
import seaborn as sns
|
| 315 |
-
import warnings
|
| 316 |
-
|
| 317 |
-
warnings.simplefilter(action='ignore', category=FutureWarning)
|
| 318 |
-
|
| 319 |
-
def plot(logits, X_embedded, labels, tresh, show = True,
|
| 320 |
namefig = "plot", make_plot = True, print_stats = True, save = True):
|
| 321 |
col = pd.MultiIndex.from_tuples([
|
| 322 |
("Predict", "0"),
|
|
@@ -329,30 +332,27 @@ def plot(logits, X_embedded, labels, tresh, show = True,
|
|
| 329 |
|
| 330 |
predict = torch.sigmoid(logits).detach().clone()
|
| 331 |
|
| 332 |
-
|
| 333 |
-
|
| 334 |
fpr, tpr, thresholds = roc_curve(labels, predict.squeeze())
|
| 335 |
|
| 336 |
-
#
|
| 337 |
-
|
| 338 |
-
|
| 339 |
idx_wss95 = sum(tpr < 0.95)
|
|
|
|
| 340 |
thresholds95 = thresholds[idx_wss95]
|
| 341 |
|
|
|
|
| 342 |
wss95_info = wss_calc(logits,labels, thresholds95 )
|
| 343 |
acc_wss95 = fn.accuracy(predict, labels, threshold=thresholds95)
|
| 344 |
f1_wss95 = fn.f1_score(predict, labels, threshold=thresholds95)
|
| 345 |
|
| 346 |
|
| 347 |
-
#
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
wss_info = wss_calc(logits,labels, tresh )
|
| 353 |
-
# Accuraci
|
| 354 |
-
acc_wssR = fn.accuracy(predict, labels, threshold=tresh)
|
| 355 |
-
f1_wssR = fn.f1_score(predict, labels, threshold=tresh)
|
| 356 |
|
| 357 |
|
| 358 |
metrics= {
|
|
@@ -370,12 +370,11 @@ def plot(logits, X_embedded, labels, tresh, show = True,
|
|
| 370 |
# f1
|
| 371 |
"f1@95": f1_wss95.item(),
|
| 372 |
"f1@R": f1_wssR.item(),
|
| 373 |
-
#
|
| 374 |
-
"
|
| 375 |
}
|
| 376 |
|
| 377 |
-
#
|
| 378 |
-
|
| 379 |
if print_stats:
|
| 380 |
wss95= f"WSS@95:{wss95_info['wss']}, R: {wss95_info['R']}"
|
| 381 |
wss95_adj= f"ASSWSS@95:{wss95_info['awss']}"
|
|
@@ -383,14 +382,14 @@ def plot(logits, X_embedded, labels, tresh, show = True,
|
|
| 383 |
print(wss95_adj)
|
| 384 |
print('Acc.:', round(acc_wss95.item(), 4))
|
| 385 |
print('F1-score:', round(f1_wss95.item(), 4))
|
| 386 |
-
print(f"
|
| 387 |
cm = pd.DataFrame(wss95_info['CM'],
|
| 388 |
index=index,
|
| 389 |
columns=col)
|
| 390 |
|
| 391 |
print("\nConfusion matrix:")
|
| 392 |
print(cm)
|
| 393 |
-
print("\n---Metrics with threshold:",
|
| 394 |
wss= f"WSS@R:{wss_info['wss']}, R: {wss_info['R']}"
|
| 395 |
print(wss)
|
| 396 |
wss_adj= f"AWSS@R:{wss_info['awss']}"
|
|
@@ -405,51 +404,53 @@ def plot(logits, X_embedded, labels, tresh, show = True,
|
|
| 405 |
print(cm)
|
| 406 |
|
| 407 |
|
| 408 |
-
#
|
| 409 |
|
| 410 |
if make_plot:
|
| 411 |
|
| 412 |
fig, axes = plt.subplots(1, 4, figsize=(25,10))
|
| 413 |
alpha = torch.squeeze(predict).numpy()
|
| 414 |
|
| 415 |
-
#
|
| 416 |
-
|
| 417 |
p1 = sns.scatterplot(x=X_embedded[:, 0],
|
| 418 |
y=X_embedded[:, 1],
|
| 419 |
hue=labels,
|
| 420 |
-
alpha=alpha, ax = axes[0]).set_title('Predictions-TSNE')
|
| 421 |
|
|
|
|
|
|
|
| 422 |
t_wss = predict >= thresholds95
|
| 423 |
t_wss = t_wss.squeeze().numpy()
|
| 424 |
-
|
| 425 |
p2 = sns.scatterplot(x=X_embedded[t_wss, 0],
|
| 426 |
y=X_embedded[t_wss, 1],
|
| 427 |
hue=labels[t_wss],
|
| 428 |
-
alpha=alpha[t_wss], ax = axes[1]).set_title('WSS@95')
|
| 429 |
|
| 430 |
-
|
|
|
|
| 431 |
t = t.squeeze().numpy()
|
| 432 |
-
|
| 433 |
p3 = sns.scatterplot(x=X_embedded[t, 0],
|
| 434 |
y=X_embedded[t, 1],
|
| 435 |
hue=labels[t],
|
| 436 |
-
alpha=alpha[t], ax = axes[2]).set_title(f'Predictions-
|
| 437 |
-
|
| 438 |
|
|
|
|
| 439 |
roc_auc = auc(fpr, tpr)
|
| 440 |
lw = 2
|
| 441 |
-
|
| 442 |
axes[3].plot(
|
| 443 |
fpr,
|
| 444 |
tpr,
|
| 445 |
color="darkorange",
|
| 446 |
lw=lw,
|
| 447 |
label="ROC curve (area = %0.2f)" % roc_auc)
|
| 448 |
-
|
| 449 |
axes[3].plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
|
| 450 |
axes[3].axhline(y=0.95, color='r', linestyle='-')
|
| 451 |
-
axes[3].set(xlabel="False Positive Rate", ylabel="True Positive Rate"
|
| 452 |
axes[3].legend(loc="lower right")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 453 |
|
| 454 |
if show:
|
| 455 |
plt.show()
|
|
@@ -459,6 +460,7 @@ def plot(logits, X_embedded, labels, tresh, show = True,
|
|
| 459 |
|
| 460 |
return metrics
|
| 461 |
|
|
|
|
| 462 |
def auc_plot(logits,labels, color = "darkorange", label = "test"):
|
| 463 |
predict = torch.sigmoid(logits).detach().clone()
|
| 464 |
fpr, tpr, thresholds = roc_curve(labels, predict.squeeze())
|
|
@@ -478,45 +480,40 @@ def auc_plot(logits,labels, color = "darkorange", label = "test"):
|
|
| 478 |
plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
|
| 479 |
plt.axhline(y=0.95, color='r', linestyle='-')
|
| 480 |
|
| 481 |
-
|
| 482 |
-
from sklearn.metrics import confusion_matrix
|
| 483 |
-
from torchmetrics import functional as fn
|
| 484 |
-
import matplotlib.pyplot as plt
|
| 485 |
-
from sklearn.metrics import roc_curve, auc
|
| 486 |
-
from sklearn.metrics import roc_auc_score
|
| 487 |
-
import ipywidgets as widgets
|
| 488 |
-
from IPython.display import HTML, display, clear_output
|
| 489 |
-
import matplotlib.pyplot as plt
|
| 490 |
-
import seaborn as sns
|
| 491 |
-
import warnings
|
| 492 |
-
|
| 493 |
-
|
| 494 |
class diagnosis():
|
| 495 |
-
def __init__(self, names, Valid_resource, batch_size_test,
|
|
|
|
| 496 |
self.names=names
|
| 497 |
self.Valid_resource=Valid_resource
|
| 498 |
self.batch_size_test=batch_size_test
|
| 499 |
self.model=model
|
| 500 |
-
self.start=start
|
|
|
|
|
|
|
|
|
|
|
|
|
| 501 |
|
|
|
|
| 502 |
self.value_trash = widgets.FloatText(
|
| 503 |
value=0.95,
|
| 504 |
-
description='
|
| 505 |
disabled=False
|
| 506 |
)
|
| 507 |
-
|
| 508 |
self.valueb = widgets.IntText(
|
| 509 |
value=10,
|
| 510 |
description='size',
|
| 511 |
disabled=False
|
| 512 |
)
|
| 513 |
|
|
|
|
| 514 |
self.train_b = widgets.Button(description="Train")
|
| 515 |
self.next_b = widgets.Button(description="Next")
|
| 516 |
self.eval_b = widgets.Button(description="Evaluation")
|
| 517 |
|
| 518 |
self.hbox = widgets.HBox([self.train_b, self.valueb])
|
| 519 |
|
|
|
|
| 520 |
self.next_b.on_click(self.Next_button)
|
| 521 |
self.train_b.on_click(self.Train_button)
|
| 522 |
self.eval_b.on_click(self.Evaluation_button)
|
|
@@ -527,36 +524,37 @@ class diagnosis():
|
|
| 527 |
clear_output()
|
| 528 |
self.i=self.i+1
|
| 529 |
|
| 530 |
-
#
|
| 531 |
-
self.domain = names[self.i]
|
| 532 |
-
print("Name:", self.domain)
|
| 533 |
-
|
| 534 |
-
# global data
|
| 535 |
self.data = self.Valid_resource[self.Valid_resource['domain'] == self.domain]
|
|
|
|
|
|
|
| 536 |
print(self.data['label'].value_counts())
|
| 537 |
-
|
| 538 |
display(self.hbox)
|
| 539 |
display(self.next_b)
|
| 540 |
|
|
|
|
| 541 |
# Train button
|
| 542 |
def Train_button(self, y):
|
| 543 |
clear_output()
|
| 544 |
print(self.domain)
|
| 545 |
|
| 546 |
-
#
|
| 547 |
self.data_train_loader, self.data_test_loader, self.data_train, self.data_test = prepare_data(self.data,
|
| 548 |
train_size_per_class = self.valueb.value,
|
| 549 |
-
batch_size = {'train': Info['inner_batch_size'],
|
| 550 |
-
'test': batch_size_test},
|
| 551 |
-
max_seq_length = Info['max_seq_length'],
|
| 552 |
-
tokenizer = Info['tokenizer'],
|
| 553 |
input = "text",
|
| 554 |
-
output = "label"
|
|
|
|
| 555 |
|
|
|
|
| 556 |
self.logits, self.X_embedded, self.labels, self.features = train_loop(self.data_train_loader, self.data_test_loader,
|
| 557 |
-
model, device,
|
| 558 |
-
epoch = Info['inner_update_step'],
|
| 559 |
-
lr=Info['inner_update_lr'],
|
| 560 |
print_info=True,
|
| 561 |
name = self.domain)
|
| 562 |
|
|
@@ -565,6 +563,7 @@ class diagnosis():
|
|
| 565 |
display(tresh_box)
|
| 566 |
display(self.next_b)
|
| 567 |
|
|
|
|
| 568 |
# Evaluation button
|
| 569 |
def Evaluation_button(self, te):
|
| 570 |
clear_output()
|
|
@@ -573,19 +572,18 @@ class diagnosis():
|
|
| 573 |
print(self.domain)
|
| 574 |
# print("\n")
|
| 575 |
print("-------Train data-------")
|
| 576 |
-
print(
|
| 577 |
print("-------Test data-------")
|
| 578 |
-
print(
|
| 579 |
# print("\n")
|
| 580 |
|
| 581 |
display(self.next_b)
|
| 582 |
display(tresh_box)
|
| 583 |
display(self.hbox)
|
| 584 |
|
| 585 |
-
|
| 586 |
metrics = plot(self.logits, self.X_embedded, self.labels,
|
| 587 |
-
|
| 588 |
-
# namefig= "./"+base_path +"/"+"Results/size_layer/"+ name_domain+'/' +str(n_layers) + '/img/' + str(attempt) + 'plots',
|
| 589 |
namefig= 'test',
|
| 590 |
make_plot = True,
|
| 591 |
print_stats = True,
|
|
@@ -593,261 +591,150 @@ class diagnosis():
|
|
| 593 |
|
| 594 |
def __call__(self):
|
| 595 |
self.i= self.start-1
|
| 596 |
-
|
| 597 |
clear_output()
|
| 598 |
display(self.next_b)
|
| 599 |
|
| 600 |
|
| 601 |
|
| 602 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 603 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 604 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 605 |
|
| 606 |
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
import torch.nn.functional as F
|
| 610 |
-
import torch.nn as nn
|
| 611 |
-
import math
|
| 612 |
-
import torch
|
| 613 |
-
import numpy as np
|
| 614 |
-
import pandas as pd
|
| 615 |
-
import time
|
| 616 |
-
import transformers
|
| 617 |
-
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 618 |
-
from sklearn.manifold import TSNE
|
| 619 |
-
from copy import deepcopy, copy
|
| 620 |
-
import seaborn as sns
|
| 621 |
-
import matplotlib.pylab as plt
|
| 622 |
-
from pprint import pprint
|
| 623 |
-
import shutil
|
| 624 |
-
import datetime
|
| 625 |
-
import re
|
| 626 |
-
import json
|
| 627 |
-
from pathlib import Path
|
| 628 |
-
|
| 629 |
-
import torch
|
| 630 |
-
import torch.nn as nn
|
| 631 |
-
from torch.utils.data import Dataset, DataLoader
|
| 632 |
-
import unicodedata
|
| 633 |
-
import re
|
| 634 |
-
|
| 635 |
-
import torch
|
| 636 |
-
import torch.nn as nn
|
| 637 |
-
from torch.utils.data import Dataset, DataLoader
|
| 638 |
-
|
| 639 |
-
|
| 640 |
-
|
| 641 |
-
# Pre-trained model
|
| 642 |
-
class Encoder(nn.Module):
|
| 643 |
-
def __init__(self, layers, freeze_bert, model):
|
| 644 |
-
super(Encoder, self).__init__()
|
| 645 |
-
|
| 646 |
-
# Dummy Parameter
|
| 647 |
-
self.dummy_param = nn.Parameter(torch.empty(0))
|
| 648 |
-
|
| 649 |
-
# Pre-trained model
|
| 650 |
-
self.model = deepcopy(model)
|
| 651 |
-
|
| 652 |
-
# Freezing bert parameters
|
| 653 |
-
if freeze_bert:
|
| 654 |
-
for param in self.model.parameters():
|
| 655 |
-
param.requires_grad = freeze_bert
|
| 656 |
-
|
| 657 |
-
# Selecting hidden layers of the pre-trained model
|
| 658 |
-
old_model_encoder = self.model.encoder.layer
|
| 659 |
-
new_model_encoder = nn.ModuleList()
|
| 660 |
-
|
| 661 |
-
for i in layers:
|
| 662 |
-
new_model_encoder.append(old_model_encoder[i])
|
| 663 |
-
|
| 664 |
-
self.model.encoder.layer = new_model_encoder
|
| 665 |
|
| 666 |
-
# Feed forward
|
| 667 |
-
def forward(self, **x):
|
| 668 |
-
return self.model(**x)['pooler_output']
|
| 669 |
-
|
| 670 |
-
# Complete model
|
| 671 |
-
class SLR_Classifier(nn.Module):
|
| 672 |
-
def __init__(self, **data):
|
| 673 |
-
super(SLR_Classifier, self).__init__()
|
| 674 |
-
|
| 675 |
-
# Dummy Parameter
|
| 676 |
-
self.dummy_param = nn.Parameter(torch.empty(0))
|
| 677 |
-
|
| 678 |
-
# Loss function
|
| 679 |
-
# Binary Cross Entropy with logits reduced to mean
|
| 680 |
-
self.loss_fn = nn.BCEWithLogitsLoss(reduction = 'mean',
|
| 681 |
-
pos_weight=torch.FloatTensor([data.get("pos_weight", 2.5)]))
|
| 682 |
-
|
| 683 |
-
# Pre-trained model
|
| 684 |
-
self.Encoder = Encoder(layers = data.get("bert_layers", range(12)),
|
| 685 |
-
freeze_bert = data.get("freeze_bert", False),
|
| 686 |
-
model = data.get("model"),
|
| 687 |
-
)
|
| 688 |
-
|
| 689 |
-
# Feature Map Layer
|
| 690 |
-
self.feature_map = nn.Sequential(
|
| 691 |
-
# nn.LayerNorm(self.Encoder.model.config.hidden_size),
|
| 692 |
-
nn.BatchNorm1d(self.Encoder.model.config.hidden_size),
|
| 693 |
-
# nn.Dropout(data.get("drop", 0.5)),
|
| 694 |
-
nn.Linear(self.Encoder.model.config.hidden_size, 200),
|
| 695 |
-
nn.Dropout(data.get("drop", 0.5)),
|
| 696 |
-
)
|
| 697 |
-
|
| 698 |
-
# Classifier Layer
|
| 699 |
-
self.classifier = nn.Sequential(
|
| 700 |
-
# nn.LayerNorm(self.Encoder.model.config.hidden_size),
|
| 701 |
-
# nn.Dropout(data.get("drop", 0.5)),
|
| 702 |
-
# nn.BatchNorm1d(self.Encoder.model.config.hidden_size),
|
| 703 |
-
# nn.Dropout(data.get("drop", 0.5)),
|
| 704 |
-
nn.Tanh(),
|
| 705 |
-
nn.Linear(200, 1)
|
| 706 |
-
)
|
| 707 |
|
| 708 |
-
|
| 709 |
-
|
| 710 |
-
|
| 711 |
-
|
| 712 |
-
# Feed forward
|
| 713 |
-
def forward(self, input_ids, attention_mask, token_type_ids, labels):
|
| 714 |
-
|
| 715 |
-
predict = self.Encoder(**{"input_ids":input_ids,
|
| 716 |
-
"attention_mask":attention_mask,
|
| 717 |
-
"token_type_ids":token_type_ids})
|
| 718 |
-
feature = self.feature_map(predict)
|
| 719 |
-
logit = self.classifier(feature)
|
| 720 |
-
|
| 721 |
-
predict = torch.sigmoid(logit)
|
| 722 |
|
| 723 |
-
#
|
| 724 |
-
|
| 725 |
-
|
| 726 |
-
return [loss, [feature, logit], predict]
|
| 727 |
-
|
| 728 |
-
|
| 729 |
-
# Undesirable patterns within texts
|
| 730 |
-
patterns = {
|
| 731 |
-
'CONCLUSIONS AND IMPLICATIONS':'',
|
| 732 |
-
'BACKGROUND AND PURPOSE':'',
|
| 733 |
-
'EXPERIMENTAL APPROACH':'',
|
| 734 |
-
'KEY RESULTS AEA':'',
|
| 735 |
-
'©':'',
|
| 736 |
-
'®':'',
|
| 737 |
-
'μ':'',
|
| 738 |
-
'(C)':'',
|
| 739 |
-
'OBJECTIVE:':'',
|
| 740 |
-
'MATERIALS AND METHODS:':'',
|
| 741 |
-
'SIGNIFICANCE:':'',
|
| 742 |
-
'BACKGROUND:':'',
|
| 743 |
-
'RESULTS:':'',
|
| 744 |
-
'METHODS:':'',
|
| 745 |
-
'CONCLUSIONS:':'',
|
| 746 |
-
'AIM:':'',
|
| 747 |
-
'STUDY DESIGN:':'',
|
| 748 |
-
'CLINICAL RELEVANCE:':'',
|
| 749 |
-
'CONCLUSION:':'',
|
| 750 |
-
'HYPOTHESIS:':'',
|
| 751 |
-
'CLINICAL RELEVANCE:':'',
|
| 752 |
-
'Questions/Purposes:':'',
|
| 753 |
-
'Introduction:':'',
|
| 754 |
-
'PURPOSE:':'',
|
| 755 |
-
'PATIENTS AND METHODS:':'',
|
| 756 |
-
'FINDINGS:':'',
|
| 757 |
-
'INTERPRETATIONS:':'',
|
| 758 |
-
'FUNDING:':'',
|
| 759 |
-
'PROGRESS:':'',
|
| 760 |
-
'CONTEXT:':'',
|
| 761 |
-
'MEASURES:':'',
|
| 762 |
-
'DESIGN:':'',
|
| 763 |
-
'BACKGROUND AND OBJECTIVES:':'',
|
| 764 |
-
'<p>':'',
|
| 765 |
-
'</p>':'',
|
| 766 |
-
'<<ETX>>':'',
|
| 767 |
-
'+/-':'',
|
| 768 |
-
}
|
| 769 |
-
|
| 770 |
-
patterns = {x.lower():y for x,y in patterns.items()}
|
| 771 |
-
|
| 772 |
-
LABEL_MAP = {'negative': 0,
|
| 773 |
-
'not included':0,
|
| 774 |
-
'0':0,
|
| 775 |
-
0:0,
|
| 776 |
-
'excluded':0,
|
| 777 |
-
'positive': 1,
|
| 778 |
-
'included':1,
|
| 779 |
-
'1':1,
|
| 780 |
-
1:1,
|
| 781 |
-
}
|
| 782 |
-
|
| 783 |
-
class SLR_DataSet(Dataset):
|
| 784 |
-
def __init__(self, **args):
|
| 785 |
-
self.tokenizer = args.get('tokenizer')
|
| 786 |
-
self.data = args.get('data')
|
| 787 |
-
self.max_seq_length = args.get("max_seq_length", 512)
|
| 788 |
-
self.INPUT_NAME = args.get("input", 'x')
|
| 789 |
-
self.LABEL_NAME = args.get("output", 'y')
|
| 790 |
-
|
| 791 |
-
# Tokenizing and processing text
|
| 792 |
-
def encode_text(self, example):
|
| 793 |
-
comment_text = example[self.INPUT_NAME]
|
| 794 |
-
comment_text = self.treat_text(comment_text)
|
| 795 |
-
|
| 796 |
-
try:
|
| 797 |
-
labels = LABEL_MAP[example[self.LABEL_NAME].lower()]
|
| 798 |
-
except:
|
| 799 |
-
labels = -1
|
| 800 |
-
|
| 801 |
-
encoding = self.tokenizer.encode_plus(
|
| 802 |
-
(comment_text, "It is great text"),
|
| 803 |
-
add_special_tokens=True,
|
| 804 |
-
max_length=self.max_seq_length,
|
| 805 |
-
return_token_type_ids=True,
|
| 806 |
-
padding="max_length",
|
| 807 |
-
truncation=True,
|
| 808 |
-
return_attention_mask=True,
|
| 809 |
-
return_tensors='pt',
|
| 810 |
-
)
|
| 811 |
-
|
| 812 |
-
|
| 813 |
-
return tuple((
|
| 814 |
-
encoding["input_ids"].flatten(),
|
| 815 |
-
encoding["attention_mask"].flatten(),
|
| 816 |
-
encoding["token_type_ids"].flatten(),
|
| 817 |
-
torch.tensor([torch.tensor(labels).to(int)])
|
| 818 |
-
))
|
| 819 |
-
|
| 820 |
-
# Text processing function
|
| 821 |
-
def treat_text(self, text):
|
| 822 |
-
text = unicodedata.normalize("NFKD",str(text))
|
| 823 |
-
text = multiple_replace(patterns,text.lower())
|
| 824 |
-
text = re.sub('(\(.+\))|(\[.+\])|( \d )|(<)|(>)|(- )','', text)
|
| 825 |
-
text = re.sub('( +)',' ', text)
|
| 826 |
-
text = re.sub('(, ,)|(,,)',',', text)
|
| 827 |
-
text = re.sub('(%)|(per cent)',' percent', text)
|
| 828 |
-
return text
|
| 829 |
-
|
| 830 |
-
def __len__(self):
|
| 831 |
-
return len(self.data)
|
| 832 |
-
|
| 833 |
-
# Returning data
|
| 834 |
-
def __getitem__(self, index: int):
|
| 835 |
-
# print(index)
|
| 836 |
-
data_row = self.data.reset_index().iloc[index]
|
| 837 |
-
temp_data = self.encode_text(data_row)
|
| 838 |
-
return temp_data
|
| 839 |
-
|
| 840 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 841 |
|
| 842 |
-
#
|
| 843 |
-
|
|
|
|
|
|
|
|
|
|
| 844 |
|
| 845 |
-
|
| 846 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 847 |
|
| 848 |
-
|
| 849 |
-
|
|
|
|
|
|
|
|
|
|
| 850 |
|
| 851 |
-
# Undesirable patterns within texts
|
| 852 |
|
| 853 |
|
|
|
|
| 1 |
+
from ML_SLRC import *
|
| 2 |
+
|
| 3 |
import os
|
|
|
|
| 4 |
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
from torch.utils.data import DataLoader
|
| 9 |
+
from torch.optim import Adam
|
| 10 |
+
|
| 11 |
+
import gc
|
| 12 |
+
from torchmetrics import functional as fn
|
| 13 |
+
|
| 14 |
import random
|
|
|
|
| 15 |
|
| 16 |
+
|
| 17 |
+
warnings.simplefilter(action='ignore', category=FutureWarning)
|
| 18 |
+
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
|
| 21 |
+
from sklearn.metrics import confusion_matrix
|
| 22 |
+
from sklearn.metrics import roc_curve, auc
|
| 23 |
+
import ipywidgets as widgets
|
| 24 |
+
from IPython.display import display, clear_output
|
| 25 |
+
import matplotlib.pyplot as plt
|
| 26 |
+
import warnings
|
| 27 |
import torch
|
| 28 |
+
|
|
|
|
| 29 |
import time
|
|
|
|
|
|
|
| 30 |
from sklearn.manifold import TSNE
|
| 31 |
+
from copy import deepcopy
|
| 32 |
import seaborn as sns
|
| 33 |
import matplotlib.pylab as plt
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
import json
|
| 35 |
from pathlib import Path
|
| 36 |
+
|
| 37 |
+
import re
|
| 38 |
+
from collections import defaultdict
|
| 39 |
+
|
| 40 |
+
# SEED = 2222
|
| 41 |
+
|
| 42 |
+
# gen_seed = torch.Generator().manual_seed(SEED)
|
| 43 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
|
|
|
|
| 46 |
|
|
|
|
| 47 |
|
| 48 |
|
| 49 |
# Random seed function
|
|
|
|
| 54 |
np.random.seed(value)
|
| 55 |
random.seed(value)
|
| 56 |
|
| 57 |
+
# Tasks for meta-learner
|
| 58 |
def create_batch_of_tasks(taskset, is_shuffle = True, batch_size = 4):
|
| 59 |
idxs = list(range(0,len(taskset)))
|
| 60 |
if is_shuffle:
|
|
|
|
| 63 |
yield [taskset[idxs[i]] for i in range(i, min(i + batch_size,len(taskset)))]
|
| 64 |
|
| 65 |
|
| 66 |
+
# Prepare data to process by Domain-learner
|
| 67 |
+
def prepare_data(data, batch_size, tokenizer,max_seq_length,
|
| 68 |
input = 'text', output = 'label',
|
| 69 |
+
train_size_per_class = 5, global_datasets = False,
|
| 70 |
+
treat_text_fun =None):
|
| 71 |
data = data.reset_index().drop("index", axis=1)
|
| 72 |
|
| 73 |
+
if global_datasets:
|
| 74 |
+
global data_train, data_test
|
|
|
|
| 75 |
|
| 76 |
+
# Sample task for training
|
| 77 |
+
data_train = data.groupby('label').sample(train_size_per_class, replace=False)
|
| 78 |
+
idex = data.index.isin(data_train.index)
|
| 79 |
|
| 80 |
+
# The Test set to label by the model
|
| 81 |
+
data_test = data[~idex].reset_index()
|
| 82 |
|
| 83 |
|
| 84 |
+
# Transform in dataset to model
|
| 85 |
+
## Train
|
| 86 |
dataset_train = SLR_DataSet(
|
| 87 |
data = data_train.sample(frac=1),
|
| 88 |
input = input,
|
| 89 |
output = output,
|
| 90 |
tokenizer=tokenizer,
|
| 91 |
+
max_seq_length =max_seq_length,
|
| 92 |
+
treat_text =treat_text_fun)
|
| 93 |
|
| 94 |
+
## Test
|
|
|
|
|
|
|
| 95 |
dataset_test = SLR_DataSet(
|
| 96 |
data = data_test,
|
| 97 |
input = input,
|
| 98 |
output = output,
|
| 99 |
tokenizer=tokenizer,
|
| 100 |
+
max_seq_length =max_seq_length,
|
| 101 |
+
treat_text =treat_text_fun)
|
| 102 |
|
| 103 |
# Dataloaders
|
| 104 |
+
## Train
|
| 105 |
data_train_loader = DataLoader(dataset_train,
|
| 106 |
shuffle=True,
|
| 107 |
batch_size=batch_size['train']
|
| 108 |
)
|
| 109 |
|
| 110 |
+
## Test
|
| 111 |
if len(dataset_test) % batch_size['test'] == 1 :
|
| 112 |
data_test_loader = DataLoader(dataset_test,
|
| 113 |
batch_size=batch_size['test'],
|
|
|
|
| 120 |
return data_train_loader, data_test_loader, data_train, data_test
|
| 121 |
|
| 122 |
|
| 123 |
+
# Meta trainer
|
| 124 |
+
def meta_train(data, model, device, Info,
|
| 125 |
+
print_epoch =True,
|
| 126 |
+
Test_resource =None,
|
| 127 |
+
treat_text_fun =None):
|
| 128 |
|
| 129 |
+
# Meta-learner model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
learner = Learner(model = model, device = device, **Info)
|
| 131 |
|
| 132 |
# Testing tasks
|
| 133 |
if isinstance(Test_resource, pd.DataFrame):
|
| 134 |
test = MetaTask(Test_resource, num_task = 0, k_support=10, k_query=10,
|
| 135 |
+
training=False,treat_text =treat_text_fun, **Info)
|
| 136 |
|
| 137 |
|
| 138 |
torch.clear_autocast_cache()
|
| 139 |
gc.collect()
|
| 140 |
torch.cuda.empty_cache()
|
| 141 |
|
| 142 |
+
# Meta epoch (Outer epoch)
|
| 143 |
for epoch in tqdm(range(Info['meta_epoch']), desc= "Meta epoch ", ncols=80):
|
|
|
|
| 144 |
|
| 145 |
+
# Train tasks
|
| 146 |
train = MetaTask(data,
|
| 147 |
num_task = Info['num_task_train'],
|
| 148 |
k_support=Info['k_qry'],
|
| 149 |
+
k_query=Info['k_spt'],
|
| 150 |
+
treat_text =treat_text_fun, **Info)
|
| 151 |
|
| 152 |
+
# Batch of train tasks
|
| 153 |
db = create_batch_of_tasks(train, is_shuffle = True, batch_size = Info["outer_batch_size"])
|
| 154 |
|
| 155 |
if print_epoch:
|
| 156 |
# Outer loop bach training
|
| 157 |
for step, task_batch in enumerate(db):
|
| 158 |
print("\n-----------------Training Mode","Meta_epoch:", epoch ,"-----------------\n")
|
| 159 |
+
|
| 160 |
+
# meta-feedfoward (outer-feedfoward)
|
| 161 |
acc = learner(task_batch, valid_train= print_epoch)
|
| 162 |
print('Step:', step, '\ttraining Acc:', acc)
|
| 163 |
+
|
| 164 |
if isinstance(Test_resource, pd.DataFrame):
|
| 165 |
+
# Validating Model
|
| 166 |
if ((epoch+1) % 4) + step == 0:
|
| 167 |
random_seed(123)
|
| 168 |
print("\n-----------------Testing Mode-----------------\n")
|
| 169 |
+
|
| 170 |
+
# Batch of test tasks
|
| 171 |
db_test = create_batch_of_tasks(test, is_shuffle = False, batch_size = 1)
|
| 172 |
acc_all_test = []
|
| 173 |
|
|
|
|
| 181 |
|
| 182 |
# Restarting training randomly
|
| 183 |
random_seed(int(time.time() % 10))
|
| 184 |
+
|
|
|
|
| 185 |
else:
|
| 186 |
for step, task_batch in enumerate(db):
|
| 187 |
+
# meta-feedfoward (outer-feedfoward)
|
| 188 |
acc = learner(task_batch, print_epoch, valid_train= print_epoch)
|
| 189 |
|
| 190 |
torch.clear_autocast_cache()
|
|
|
|
| 194 |
|
| 195 |
|
| 196 |
def train_loop(data_train_loader, data_test_loader, model, device, epoch = 4, lr = 1, print_info = True, name = 'name'):
|
| 197 |
+
# Start the model's parameters
|
| 198 |
model_meta = deepcopy(model)
|
| 199 |
optimizer = Adam(model_meta.parameters(), lr=lr)
|
| 200 |
|
| 201 |
model_meta.to(device)
|
| 202 |
model_meta.train()
|
| 203 |
|
| 204 |
+
# Task epoch (Inner epoch)
|
| 205 |
for i in range(0, epoch):
|
| 206 |
all_loss = []
|
| 207 |
|
|
|
|
| 210 |
batch = tuple(t.to(device) for t in batch)
|
| 211 |
input_ids, attention_mask,q_token_type_ids, label_id = batch
|
| 212 |
|
| 213 |
+
# Inner Feedfoward
|
| 214 |
loss, _, _ = model_meta(input_ids, attention_mask,q_token_type_ids, labels = label_id.squeeze())
|
| 215 |
|
| 216 |
+
# compute grads
|
| 217 |
loss.backward()
|
| 218 |
|
| 219 |
+
# update parameters
|
| 220 |
optimizer.step()
|
| 221 |
optimizer.zero_grad()
|
| 222 |
|
|
|
|
| 227 |
print("Loss: ", np.mean(all_loss))
|
| 228 |
|
| 229 |
|
| 230 |
+
# Test evaluation
|
| 231 |
model_meta.eval()
|
| 232 |
all_loss = []
|
| 233 |
+
all_acc = []
|
| 234 |
features = []
|
| 235 |
labels = []
|
| 236 |
predi_logit = []
|
| 237 |
|
| 238 |
with torch.no_grad():
|
| 239 |
+
# Test's Batch loop
|
| 240 |
for inner_step, batch in enumerate(tqdm(data_test_loader,
|
| 241 |
desc="Test validation | " + name,
|
| 242 |
ncols=80)) :
|
| 243 |
batch = tuple(t.to(device) for t in batch)
|
| 244 |
input_ids, attention_mask,q_token_type_ids, label_id = batch
|
| 245 |
|
| 246 |
+
# Predictions
|
| 247 |
_, feature, prediction = model_meta(input_ids, attention_mask,q_token_type_ids, labels = label_id.squeeze())
|
| 248 |
|
| 249 |
+
# Save batch's predictions
|
| 250 |
prediction = prediction.detach().cpu().squeeze()
|
| 251 |
label_id = label_id.detach().cpu()
|
| 252 |
+
labels.append(label_id.numpy().squeeze())
|
| 253 |
+
|
| 254 |
logit = feature[1].detach().cpu()
|
| 255 |
+
predi_logit.append(logit.numpy())
|
| 256 |
|
| 257 |
+
feature_lat = feature[0].detach().cpu()
|
| 258 |
features.append(feature_lat.numpy())
|
|
|
|
| 259 |
|
| 260 |
+
# Accuracy over the test's bach
|
| 261 |
+
acc = fn.accuracy(prediction, label_id).item()
|
| 262 |
+
all_acc.append(acc)
|
| 263 |
del input_ids, attention_mask, label_id, batch
|
| 264 |
|
| 265 |
+
if print_info:
|
| 266 |
+
print("acc:", np.mean(all_acc))
|
| 267 |
|
| 268 |
model_meta.to('cpu')
|
| 269 |
gc.collect()
|
|
|
|
| 271 |
|
| 272 |
del model_meta, optimizer
|
| 273 |
|
| 274 |
+
return map_feature_tsne(features, labels, predi_logit)
|
| 275 |
|
| 276 |
+
# Process predictions and map the feature_map in tsne
|
| 277 |
+
def map_feature_tsne(features, labels, predi_logit):
|
| 278 |
+
|
| 279 |
features = np.concatenate(np.array(features,dtype=object))
|
|
|
|
|
|
|
|
|
|
| 280 |
features = torch.tensor(features.astype(np.float32)).detach().clone()
|
| 281 |
+
|
| 282 |
+
labels = np.concatenate(np.array(labels,dtype=object))
|
| 283 |
labels = torch.tensor(labels.astype(int)).detach().clone()
|
| 284 |
+
|
| 285 |
+
logits = np.concatenate(np.array(predi_logit,dtype=object))
|
| 286 |
logits = torch.tensor(logits.astype(np.float32)).detach().clone()
|
| 287 |
|
| 288 |
+
# Dimention reduction
|
| 289 |
X_embedded = TSNE(n_components=2, learning_rate='auto',
|
| 290 |
init='random').fit_transform(features.detach().clone())
|
| 291 |
|
| 292 |
return logits.detach().clone(), X_embedded, labels.detach().clone(), features.detach().clone()
|
| 293 |
+
|
|
|
|
| 294 |
def wss_calc(logit, labels, trsh = 0.5):
|
| 295 |
|
| 296 |
+
# Prediction label given the threshold
|
| 297 |
predict_trash = torch.sigmoid(logit).squeeze() >= trsh
|
| 298 |
+
|
| 299 |
+
# Compute confusion matrix values
|
| 300 |
CM = confusion_matrix(labels, predict_trash.to(int) )
|
| 301 |
tn, fp, fne, tp = CM.ravel()
|
| 302 |
|
|
|
|
| 304 |
N = (tn + fp)
|
| 305 |
recall = tp/(tp+fne)
|
| 306 |
|
| 307 |
+
# WSS
|
| 308 |
+
wss = (tn + fne)/len(labels) -(1- recall)
|
| 309 |
|
| 310 |
+
# AWSS
|
| 311 |
+
awss = (tn/N - fne/P)
|
| 312 |
|
| 313 |
return {
|
| 314 |
+
"wss": round(wss,4),
|
| 315 |
+
"awss": round(awss,4),
|
| 316 |
"R": round(recall,4),
|
| 317 |
"CM": CM
|
| 318 |
}
|
| 319 |
|
| 320 |
|
| 321 |
+
# Compute the metrics
|
| 322 |
+
def plot(logits, X_embedded, labels, threshold, show = True,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
namefig = "plot", make_plot = True, print_stats = True, save = True):
|
| 324 |
col = pd.MultiIndex.from_tuples([
|
| 325 |
("Predict", "0"),
|
|
|
|
| 332 |
|
| 333 |
predict = torch.sigmoid(logits).detach().clone()
|
| 334 |
|
| 335 |
+
# Roc curve
|
|
|
|
| 336 |
fpr, tpr, thresholds = roc_curve(labels, predict.squeeze())
|
| 337 |
|
| 338 |
+
# Given by a Recall of 95% (threshold avaliation)
|
| 339 |
+
## WSS
|
| 340 |
+
### Index to recall
|
| 341 |
idx_wss95 = sum(tpr < 0.95)
|
| 342 |
+
### threshold
|
| 343 |
thresholds95 = thresholds[idx_wss95]
|
| 344 |
|
| 345 |
+
### Compute the metrics
|
| 346 |
wss95_info = wss_calc(logits,labels, thresholds95 )
|
| 347 |
acc_wss95 = fn.accuracy(predict, labels, threshold=thresholds95)
|
| 348 |
f1_wss95 = fn.f1_score(predict, labels, threshold=thresholds95)
|
| 349 |
|
| 350 |
|
| 351 |
+
# Given by a threshold (recall avaliation)
|
| 352 |
+
### Compute the metrics
|
| 353 |
+
wss_info = wss_calc(logits,labels, threshold )
|
| 354 |
+
acc_wssR = fn.accuracy(predict, labels, threshold=threshold)
|
| 355 |
+
f1_wssR = fn.f1_score(predict, labels, threshold=threshold)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 356 |
|
| 357 |
|
| 358 |
metrics= {
|
|
|
|
| 370 |
# f1
|
| 371 |
"f1@95": f1_wss95.item(),
|
| 372 |
"f1@R": f1_wssR.item(),
|
| 373 |
+
# threshold 95
|
| 374 |
+
"threshold@95": thresholds95
|
| 375 |
}
|
| 376 |
|
| 377 |
+
# Print stats
|
|
|
|
| 378 |
if print_stats:
|
| 379 |
wss95= f"WSS@95:{wss95_info['wss']}, R: {wss95_info['R']}"
|
| 380 |
wss95_adj= f"ASSWSS@95:{wss95_info['awss']}"
|
|
|
|
| 382 |
print(wss95_adj)
|
| 383 |
print('Acc.:', round(acc_wss95.item(), 4))
|
| 384 |
print('F1-score:', round(f1_wss95.item(), 4))
|
| 385 |
+
print(f"threshold to wss95: {round(thresholds95, 4)}")
|
| 386 |
cm = pd.DataFrame(wss95_info['CM'],
|
| 387 |
index=index,
|
| 388 |
columns=col)
|
| 389 |
|
| 390 |
print("\nConfusion matrix:")
|
| 391 |
print(cm)
|
| 392 |
+
print("\n---Metrics with threshold:", threshold, "----\n")
|
| 393 |
wss= f"WSS@R:{wss_info['wss']}, R: {wss_info['R']}"
|
| 394 |
print(wss)
|
| 395 |
wss_adj= f"AWSS@R:{wss_info['awss']}"
|
|
|
|
| 404 |
print(cm)
|
| 405 |
|
| 406 |
|
| 407 |
+
# Plots
|
| 408 |
|
| 409 |
if make_plot:
|
| 410 |
|
| 411 |
fig, axes = plt.subplots(1, 4, figsize=(25,10))
|
| 412 |
alpha = torch.squeeze(predict).numpy()
|
| 413 |
|
| 414 |
+
# TSNE
|
|
|
|
| 415 |
p1 = sns.scatterplot(x=X_embedded[:, 0],
|
| 416 |
y=X_embedded[:, 1],
|
| 417 |
hue=labels,
|
| 418 |
+
alpha=alpha, ax = axes[0]).set_title('Predictions-TSNE', size=20)
|
| 419 |
|
| 420 |
+
|
| 421 |
+
# WSS@95
|
| 422 |
t_wss = predict >= thresholds95
|
| 423 |
t_wss = t_wss.squeeze().numpy()
|
|
|
|
| 424 |
p2 = sns.scatterplot(x=X_embedded[t_wss, 0],
|
| 425 |
y=X_embedded[t_wss, 1],
|
| 426 |
hue=labels[t_wss],
|
| 427 |
+
alpha=alpha[t_wss], ax = axes[1]).set_title('WSS@95', size=20)
|
| 428 |
|
| 429 |
+
# WSS@R
|
| 430 |
+
t = predict >= threshold
|
| 431 |
t = t.squeeze().numpy()
|
|
|
|
| 432 |
p3 = sns.scatterplot(x=X_embedded[t, 0],
|
| 433 |
y=X_embedded[t, 1],
|
| 434 |
hue=labels[t],
|
| 435 |
+
alpha=alpha[t], ax = axes[2]).set_title(f'Predictions-threshold {threshold}', size=20)
|
|
|
|
| 436 |
|
| 437 |
+
# ROC-Curve
|
| 438 |
roc_auc = auc(fpr, tpr)
|
| 439 |
lw = 2
|
|
|
|
| 440 |
axes[3].plot(
|
| 441 |
fpr,
|
| 442 |
tpr,
|
| 443 |
color="darkorange",
|
| 444 |
lw=lw,
|
| 445 |
label="ROC curve (area = %0.2f)" % roc_auc)
|
|
|
|
| 446 |
axes[3].plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
|
| 447 |
axes[3].axhline(y=0.95, color='r', linestyle='-')
|
| 448 |
+
# axes[3].set(xlabel="False Positive Rate", ylabel="True Positive Rate")
|
| 449 |
axes[3].legend(loc="lower right")
|
| 450 |
+
axes[3].set_title(label= "ROC", size = 20)
|
| 451 |
+
axes[3].set_ylabel("True Positive Rate", fontsize = 15)
|
| 452 |
+
axes[3].set_xlabel("False Positive Rate", fontsize = 15)
|
| 453 |
+
|
| 454 |
|
| 455 |
if show:
|
| 456 |
plt.show()
|
|
|
|
| 460 |
|
| 461 |
return metrics
|
| 462 |
|
| 463 |
+
|
| 464 |
def auc_plot(logits,labels, color = "darkorange", label = "test"):
|
| 465 |
predict = torch.sigmoid(logits).detach().clone()
|
| 466 |
fpr, tpr, thresholds = roc_curve(labels, predict.squeeze())
|
|
|
|
| 480 |
plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
|
| 481 |
plt.axhline(y=0.95, color='r', linestyle='-')
|
| 482 |
|
| 483 |
+
# Interface to evaluation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 484 |
class diagnosis():
|
| 485 |
+
def __init__(self, names, Valid_resource, batch_size_test,
|
| 486 |
+
model,Info, device,treat_text_fun=None,start = 0):
|
| 487 |
self.names=names
|
| 488 |
self.Valid_resource=Valid_resource
|
| 489 |
self.batch_size_test=batch_size_test
|
| 490 |
self.model=model
|
| 491 |
+
self.start=start
|
| 492 |
+
self.Info = Info
|
| 493 |
+
self.device = device
|
| 494 |
+
self.treat_text_fun = treat_text_fun
|
| 495 |
+
|
| 496 |
|
| 497 |
+
# BOX INPUT
|
| 498 |
self.value_trash = widgets.FloatText(
|
| 499 |
value=0.95,
|
| 500 |
+
description='threshold',
|
| 501 |
disabled=False
|
| 502 |
)
|
|
|
|
| 503 |
self.valueb = widgets.IntText(
|
| 504 |
value=10,
|
| 505 |
description='size',
|
| 506 |
disabled=False
|
| 507 |
)
|
| 508 |
|
| 509 |
+
# Buttons
|
| 510 |
self.train_b = widgets.Button(description="Train")
|
| 511 |
self.next_b = widgets.Button(description="Next")
|
| 512 |
self.eval_b = widgets.Button(description="Evaluation")
|
| 513 |
|
| 514 |
self.hbox = widgets.HBox([self.train_b, self.valueb])
|
| 515 |
|
| 516 |
+
# Click buttons functions
|
| 517 |
self.next_b.on_click(self.Next_button)
|
| 518 |
self.train_b.on_click(self.Train_button)
|
| 519 |
self.eval_b.on_click(self.Evaluation_button)
|
|
|
|
| 524 |
clear_output()
|
| 525 |
self.i=self.i+1
|
| 526 |
|
| 527 |
+
# Select the domain data
|
| 528 |
+
self.domain = self.names[self.i]
|
|
|
|
|
|
|
|
|
|
| 529 |
self.data = self.Valid_resource[self.Valid_resource['domain'] == self.domain]
|
| 530 |
+
|
| 531 |
+
print("Name:", self.domain)
|
| 532 |
print(self.data['label'].value_counts())
|
|
|
|
| 533 |
display(self.hbox)
|
| 534 |
display(self.next_b)
|
| 535 |
|
| 536 |
+
|
| 537 |
# Train button
|
| 538 |
def Train_button(self, y):
|
| 539 |
clear_output()
|
| 540 |
print(self.domain)
|
| 541 |
|
| 542 |
+
# Prepare data for training (domain-learner)
|
| 543 |
self.data_train_loader, self.data_test_loader, self.data_train, self.data_test = prepare_data(self.data,
|
| 544 |
train_size_per_class = self.valueb.value,
|
| 545 |
+
batch_size = {'train': self.Info['inner_batch_size'],
|
| 546 |
+
'test': self.batch_size_test},
|
| 547 |
+
max_seq_length = self.Info['max_seq_length'],
|
| 548 |
+
tokenizer = self.Info['tokenizer'],
|
| 549 |
input = "text",
|
| 550 |
+
output = "label",
|
| 551 |
+
treat_text_fun=self.treat_text_fun)
|
| 552 |
|
| 553 |
+
# Train the model and predict in the test set
|
| 554 |
self.logits, self.X_embedded, self.labels, self.features = train_loop(self.data_train_loader, self.data_test_loader,
|
| 555 |
+
self.model, self.device,
|
| 556 |
+
epoch = self.Info['inner_update_step'],
|
| 557 |
+
lr=self.Info['inner_update_lr'],
|
| 558 |
print_info=True,
|
| 559 |
name = self.domain)
|
| 560 |
|
|
|
|
| 563 |
display(tresh_box)
|
| 564 |
display(self.next_b)
|
| 565 |
|
| 566 |
+
|
| 567 |
# Evaluation button
|
| 568 |
def Evaluation_button(self, te):
|
| 569 |
clear_output()
|
|
|
|
| 572 |
print(self.domain)
|
| 573 |
# print("\n")
|
| 574 |
print("-------Train data-------")
|
| 575 |
+
print(data_train['label'].value_counts())
|
| 576 |
print("-------Test data-------")
|
| 577 |
+
print(data_test['label'].value_counts())
|
| 578 |
# print("\n")
|
| 579 |
|
| 580 |
display(self.next_b)
|
| 581 |
display(tresh_box)
|
| 582 |
display(self.hbox)
|
| 583 |
|
| 584 |
+
# Compute metrics
|
| 585 |
metrics = plot(self.logits, self.X_embedded, self.labels,
|
| 586 |
+
threshold=self.Info['threshold'], show = True,
|
|
|
|
| 587 |
namefig= 'test',
|
| 588 |
make_plot = True,
|
| 589 |
print_stats = True,
|
|
|
|
| 591 |
|
| 592 |
def __call__(self):
|
| 593 |
self.i= self.start-1
|
|
|
|
| 594 |
clear_output()
|
| 595 |
display(self.next_b)
|
| 596 |
|
| 597 |
|
| 598 |
|
| 599 |
|
| 600 |
+
# Simulation attemps of domain learner
|
| 601 |
+
def pipeline_simulation(Valid_resource, names_to_valid, path_save,
|
| 602 |
+
model, Info, device, initializer_model,
|
| 603 |
+
treat_text_fun=None):
|
| 604 |
+
n_attempt = 5
|
| 605 |
+
batch_test = 100
|
| 606 |
|
| 607 |
+
# Create a directory to save informations
|
| 608 |
+
for name in names_to_valid:
|
| 609 |
+
name = re.sub("\.csv", "",name)
|
| 610 |
+
Path(path_save + name + "/img").mkdir(parents=True, exist_ok=True)
|
| 611 |
|
| 612 |
+
# Dict to sabe roc curves
|
| 613 |
+
roc_stats = defaultdict(lambda: defaultdict(
|
| 614 |
+
lambda: defaultdict(
|
| 615 |
+
list
|
| 616 |
+
)
|
| 617 |
+
)
|
| 618 |
+
)
|
| 619 |
|
| 620 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 621 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 622 |
|
| 623 |
+
all_metrics = []
|
| 624 |
+
# Loop over a list of domains
|
| 625 |
+
for name in names_to_valid:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 626 |
|
| 627 |
+
# Select a domain dataset
|
| 628 |
+
data = Valid_resource[Valid_resource['domain'] == name].reset_index().drop("index", axis=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 629 |
|
| 630 |
+
# Attempts simulation
|
| 631 |
+
for attempt in range(n_attempt):
|
| 632 |
+
print("---"*4,"attempt", attempt, "---"*4)
|
| 633 |
+
|
| 634 |
+
# Prepare data to pass to the model
|
| 635 |
+
data_train_loader, data_test_loader, _ , _ = prepare_data(data,
|
| 636 |
+
train_size_per_class = Info['k_spt'],
|
| 637 |
+
batch_size = {'train': Info['inner_batch_size'],
|
| 638 |
+
'test': batch_test},
|
| 639 |
+
max_seq_length = Info['max_seq_length'],
|
| 640 |
+
tokenizer = Info['tokenizer'],
|
| 641 |
+
input = "text",
|
| 642 |
+
output = "label",
|
| 643 |
+
treat_text_fun=treat_text_fun)
|
| 644 |
+
|
| 645 |
+
# Train the model and evaluate on the test set of the domain
|
| 646 |
+
logits, X_embedded, labels, features = train_loop(data_train_loader, data_test_loader,
|
| 647 |
+
model, device,
|
| 648 |
+
epoch = Info['inner_update_step'],
|
| 649 |
+
lr=Info['inner_update_lr'],
|
| 650 |
+
print_info=False,
|
| 651 |
+
name = name)
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
name_domain = re.sub("\.csv", "",name)
|
| 655 |
|
| 656 |
+
# Compute the metrics
|
| 657 |
+
metrics = plot(logits, X_embedded, labels,
|
| 658 |
+
threshold=Info['threshold'], show = False,
|
| 659 |
+
namefig= path_save + name_domain + "/img/" + str(attempt) + 'plots',
|
| 660 |
+
make_plot = True, print_stats = False, save = True)
|
| 661 |
|
| 662 |
+
# Compute the roc-curve
|
| 663 |
+
fpr, tpr, _ = roc_curve(labels, torch.sigmoid(logits).squeeze())
|
| 664 |
+
|
| 665 |
+
# Save the correspoud information of the domain
|
| 666 |
+
metrics['name'] = name_domain
|
| 667 |
+
metrics['layer_size'] = Info['bert_layers']
|
| 668 |
+
metrics['attempt'] = attempt
|
| 669 |
+
roc_stats[name_domain][str(Info['bert_layers'])]['fpr'].append(fpr.tolist())
|
| 670 |
+
roc_stats[name_domain][str(Info['bert_layers'])]['tpr'].append(tpr.tolist())
|
| 671 |
+
all_metrics.append(metrics)
|
| 672 |
+
|
| 673 |
+
# Save the metrics and the roc curve of the attemp
|
| 674 |
+
pd.DataFrame(all_metrics).to_csv(path_save+ "metrics.csv")
|
| 675 |
+
roc_path = path_save + "roc_stats.json"
|
| 676 |
+
with open(roc_path, 'w') as fp:
|
| 677 |
+
json.dump(roc_stats, fp)
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
del fpr, tpr, logits, X_embedded, labels
|
| 681 |
+
del features, metrics, _
|
| 682 |
+
|
| 683 |
+
|
| 684 |
+
# Save the information used to evaluate the validation resource
|
| 685 |
+
save_info = Info.copy()
|
| 686 |
+
save_info['model'] = initializer_model.tokenizer.name_or_path
|
| 687 |
+
save_info.pop("tokenizer")
|
| 688 |
+
save_info.pop("bert_layers")
|
| 689 |
+
|
| 690 |
+
info_path = path_save+"info.json"
|
| 691 |
+
with open(info_path, 'w') as fp:
|
| 692 |
+
json.dump(save_info, fp)
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
# Loading dataset statistics
|
| 696 |
+
def load_data_statistics(paths, names):
|
| 697 |
+
size = []
|
| 698 |
+
pos = []
|
| 699 |
+
neg = []
|
| 700 |
+
for p in paths:
|
| 701 |
+
data = pd.read_csv(p)
|
| 702 |
+
data = data.dropna()
|
| 703 |
+
# Dataset size
|
| 704 |
+
size.append(len(data))
|
| 705 |
+
# Number of positive labels
|
| 706 |
+
pos.append(data['labels'].value_counts()[1])
|
| 707 |
+
# Number of negative labels
|
| 708 |
+
neg.append(data['labels'].value_counts()[0])
|
| 709 |
+
del data
|
| 710 |
+
|
| 711 |
+
info_load = pd.DataFrame({
|
| 712 |
+
"size":size,
|
| 713 |
+
"pos":pos,
|
| 714 |
+
"neg":neg,
|
| 715 |
+
"names":names,
|
| 716 |
+
"paths": paths })
|
| 717 |
+
return info_load
|
| 718 |
+
|
| 719 |
+
# Loading the datasets
|
| 720 |
+
def load_data(train_info_load):
|
| 721 |
+
|
| 722 |
+
col = ['abstract','title', 'labels', 'domain']
|
| 723 |
+
|
| 724 |
+
data_train = pd.DataFrame(columns=col)
|
| 725 |
+
for p in train_info_load['paths']:
|
| 726 |
+
data_temp = pd.read_csv(p).loc[:, ['labels', 'title', 'abstract']]
|
| 727 |
+
data_temp = pd.read_csv(p).loc[:, ['labels', 'title', 'abstract']]
|
| 728 |
+
data_temp['domain'] = os.path.basename(p)
|
| 729 |
+
data_train = pd.concat([data_train, data_temp])
|
| 730 |
+
|
| 731 |
+
data_train['text'] = data_train['title'] + data_train['abstract'].replace(np.nan, '')
|
| 732 |
|
| 733 |
+
return( data_train \
|
| 734 |
+
.replace({"labels":{0:"negative", 1:'positive'}})\
|
| 735 |
+
.rename({"labels":"label"} , axis=1)\
|
| 736 |
+
.loc[ :,("text","domain","label")]
|
| 737 |
+
)
|
| 738 |
|
|
|
|
| 739 |
|
| 740 |
|