Ssddsasd commited on
Commit
9331647
·
1 Parent(s): 859689b

Upload Util_funs.py

Browse files
Files changed (1) hide show
  1. Util_funs.py +599 -0
Util_funs.py ADDED
@@ -0,0 +1,599 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import random
5
+ import json, pickle
6
+ # from ML_SLRC import SLR_DataSet, SLR_Classifier
7
+
8
+ import torch.nn.functional as F
9
+ import torch.nn as nn
10
+ import math
11
+ import torch
12
+ import numpy as np
13
+ import pandas as pd
14
+ import time
15
+ import transformers
16
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
17
+ from sklearn.manifold import TSNE
18
+ from copy import deepcopy, copy
19
+ import seaborn as sns
20
+ import matplotlib.pylab as plt
21
+ from pprint import pprint
22
+ import shutil
23
+ import datetime
24
+ import re
25
+ import json
26
+ from pathlib import Path
27
+ import torch
28
+ import torch.nn as nn
29
+ from torch.utils.data import Dataset, DataLoader
30
+ from torch import nn
31
+ from torch.nn import functional as F
32
+ from torch.utils.data import TensorDataset, DataLoader, RandomSampler
33
+ from torch.optim import Adam
34
+ from torch.nn import CrossEntropyLoss
35
+ from transformers import BertForSequenceClassification
36
+ from copy import deepcopy
37
+ import gc
38
+ from sklearn.metrics import accuracy_score
39
+ import torch
40
+ import numpy as np
41
+ import torchmetrics
42
+ from torchmetrics import functional as fn
43
+
44
+
45
+ SEED = 2222
46
+
47
+ gen_seed = torch.Generator().manual_seed(SEED)
48
+
49
+
50
+ # Random seed function
51
+ def random_seed(value):
52
+ torch.backends.cudnn.deterministic=True
53
+ torch.manual_seed(value)
54
+ torch.cuda.manual_seed(value)
55
+ np.random.seed(value)
56
+ random.seed(value)
57
+
58
+ # Batch creation function
59
+ def create_batch_of_tasks(taskset, is_shuffle = True, batch_size = 4):
60
+ idxs = list(range(0,len(taskset)))
61
+ if is_shuffle:
62
+ random.shuffle(idxs)
63
+ for i in range(0,len(idxs), batch_size):
64
+ yield [taskset[idxs[i]] for i in range(i, min(i + batch_size,len(taskset)))]
65
+
66
+
67
+
68
+ def prepare_data(data, batch_size,tokenizer,max_seq_length,
69
+ input = 'text', output = 'label',
70
+ train_size_per_class = 5):
71
+ data = data.reset_index().drop("index", axis=1)
72
+
73
+ labaled_data = data.loc[~data['label'].isna()]
74
+
75
+ data_train = labaled_data.groupby('label').sample(train_size_per_class)
76
+
77
+ rest_labaled_data = labaled_data.loc[~labaled_data.index.isin(data_train.index),:]
78
+ unlabaled_data = data.loc[data['label'].isna()]
79
+
80
+ data_test = pd.concat([rest_labaled_data, unlabaled_data])
81
+
82
+
83
+ # Train
84
+ ## Transforma em dataset
85
+ dataset_train = SLR_DataSet(
86
+ data = data_train.sample(frac=1),
87
+ input = input,
88
+ output = output,
89
+ tokenizer=tokenizer,
90
+ max_seq_length =max_seq_length)
91
+
92
+ # Test
93
+ # Dataloaders
94
+ ## Transforma em dataset
95
+ dataset_test = SLR_DataSet(
96
+ data = data_test,
97
+ input = input,
98
+ output = output,
99
+ tokenizer=tokenizer,
100
+ max_seq_length =max_seq_length)
101
+
102
+ # Dataloaders
103
+ ## Treino
104
+ data_train_loader = DataLoader(dataset_train,
105
+ shuffle=True,
106
+ batch_size=batch_size['train']
107
+ )
108
+
109
+ if len(dataset_test) % batch_size['test'] == 1 :
110
+ data_test_loader = DataLoader(dataset_test,
111
+ batch_size=batch_size['test'],
112
+ drop_last=True)
113
+ else:
114
+ data_test_loader = DataLoader(dataset_test,
115
+ batch_size=batch_size['test'],
116
+ drop_last=False)
117
+
118
+ return data_train_loader, data_test_loader, data_train, data_test
119
+
120
+
121
+
122
+
123
+
124
+ from tqdm import tqdm
125
+
126
+ def meta_train(data, model, device, Info, print_epoch =True, size_layer=0, Test_resource =None):
127
+
128
+ learner = Learner(model = model, device = device, **Info)
129
+
130
+ # Testing tasks
131
+ if isinstance(Test_resource, pd.DataFrame):
132
+ test = MetaTask(Test_resource, num_task = 0, k_support=10, k_query=10,
133
+ training=False, **Info)
134
+
135
+
136
+ torch.clear_autocast_cache()
137
+ gc.collect()
138
+ torch.cuda.empty_cache()
139
+
140
+ # Meta epoca
141
+ for epoch in tqdm(range(Info['meta_epoch']), desc= "Meta epoch ", ncols=80):
142
+ # print("Meta Epoca:", epoch)
143
+
144
+ # Tarefas de treino
145
+ train = MetaTask(data,
146
+ num_task = Info['num_task_train'],
147
+ k_support=Info['k_qry'],
148
+ k_query=Info['k_spt'], **Info)
149
+
150
+ # Batchs de tarefas
151
+ db = create_batch_of_tasks(train, is_shuffle = True, batch_size = Info["outer_batch_size"])
152
+
153
+ if print_epoch:
154
+ # Outer loop bach training
155
+ for step, task_batch in enumerate(db):
156
+ print("\n-----------------Training Mode","Meta_epoch:", epoch ,"-----------------\n")
157
+ # meta-feedfoward
158
+ acc = learner(task_batch, valid_train= print_epoch)
159
+ print('Step:', step, '\ttraining Acc:', acc)
160
+ if isinstance(Test_resource, pd.DataFrame):
161
+ # Validating Model
162
+ if ((epoch+1) % 4) + step == 0:
163
+ random_seed(123)
164
+ print("\n-----------------Testing Mode-----------------\n")
165
+ db_test = create_batch_of_tasks(test, is_shuffle = False, batch_size = 1)
166
+ acc_all_test = []
167
+
168
+ # Looping testing tasks
169
+ for test_batch in db_test:
170
+ acc = learner(test_batch, training = False)
171
+ acc_all_test.append(acc)
172
+
173
+ print('Test acc:', np.mean(acc_all_test))
174
+ del acc_all_test, db_test
175
+
176
+ # Restarting training randomly
177
+ random_seed(int(time.time() % 10))
178
+
179
+
180
+ else:
181
+ for step, task_batch in enumerate(db):
182
+ acc = learner(task_batch, print_epoch, valid_train= print_epoch)
183
+
184
+ torch.clear_autocast_cache()
185
+ gc.collect()
186
+ torch.cuda.empty_cache()
187
+
188
+
189
+
190
+ def train_loop(data_train_loader, data_test_loader, model, device, epoch = 4, lr = 1, print_info = True, name = 'name'):
191
+ # Inicia o modelo
192
+ model_meta = deepcopy(model)
193
+ optimizer = Adam(model_meta.parameters(), lr=lr)
194
+
195
+ model_meta.to(device)
196
+ model_meta.train()
197
+
198
+ # Loop de treino da tarefa
199
+ for i in range(0, epoch):
200
+ all_loss = []
201
+
202
+ # Inner training batch (support set)
203
+ for inner_step, batch in enumerate(data_train_loader):
204
+ batch = tuple(t.to(device) for t in batch)
205
+ input_ids, attention_mask,q_token_type_ids, label_id = batch
206
+
207
+ # Feedfoward
208
+ loss, _, _ = model_meta(input_ids, attention_mask,q_token_type_ids, labels = label_id.squeeze())
209
+
210
+ # Calcula gradientes
211
+ loss.backward()
212
+
213
+ # Atualiza os parametros
214
+ optimizer.step()
215
+ optimizer.zero_grad()
216
+
217
+ all_loss.append(loss.item())
218
+
219
+
220
+ if (i % 2 == 0) & print_info:
221
+ print("Loss: ", np.mean(all_loss))
222
+
223
+
224
+ # Predicao no banco de teste
225
+ model_meta.eval()
226
+ all_loss = []
227
+ # all_acc = []
228
+ features = []
229
+ labels = []
230
+ predi_logit = []
231
+
232
+ with torch.no_grad():
233
+ for inner_step, batch in enumerate(tqdm(data_test_loader,
234
+ desc="Test validation | " + name,
235
+ ncols=80)) :
236
+ batch = tuple(t.to(device) for t in batch)
237
+ input_ids, attention_mask,q_token_type_ids, label_id = batch
238
+
239
+ # Predicoes
240
+ _, feature, prediction = model_meta(input_ids, attention_mask,q_token_type_ids, labels = label_id.squeeze())
241
+
242
+ prediction = prediction.detach().cpu().squeeze()
243
+ label_id = label_id.detach().cpu()
244
+ logit = feature[1].detach().cpu()
245
+ feature_lat = feature[0].detach().cpu()
246
+
247
+ labels.append(label_id.numpy().squeeze())
248
+ features.append(feature_lat.numpy())
249
+ predi_logit.append(logit.numpy())
250
+
251
+ # acc = fn.accuracy(prediction, label_id).item()
252
+ # all_acc.append(acc)
253
+ del input_ids, attention_mask, label_id, batch
254
+
255
+ # if print_info:
256
+ # print("acc:", np.mean(all_acc))
257
+
258
+ model_meta.to('cpu')
259
+ gc.collect()
260
+ torch.cuda.empty_cache()
261
+
262
+ del model_meta, optimizer
263
+
264
+
265
+ features = np.concatenate(np.array(features,dtype=object))
266
+ labels = np.concatenate(np.array(labels,dtype=object))
267
+ logits = np.concatenate(np.array(predi_logit,dtype=object))
268
+
269
+ features = torch.tensor(features.astype(np.float32)).detach().clone()
270
+ labels = torch.tensor(labels.astype(int)).detach().clone()
271
+ logits = torch.tensor(logits.astype(np.float32)).detach().clone()
272
+
273
+ # Reducao de dimensionalidade
274
+ X_embedded = TSNE(n_components=2, learning_rate='auto',
275
+ init='random').fit_transform(features.detach().clone())
276
+
277
+ return logits.detach().clone(), X_embedded, labels.detach().clone(), features.detach().clone()
278
+
279
+
280
+ def wss_calc(logit, labels, trsh = 0.5):
281
+
282
+ # Predicao com base nos treshould
283
+ predict_trash = torch.sigmoid(logit).squeeze() >= trsh
284
+ CM = confusion_matrix(labels, predict_trash.to(int) )
285
+ tn, fp, fne, tp = CM.ravel()
286
+
287
+ P = (tp + fne)
288
+ N = (tn + fp)
289
+ recall = tp/(tp+fne)
290
+
291
+ # Wss antigo
292
+ wss_old = (tn + fne)/len(labels) -(1- recall)
293
+
294
+ # WSS novo
295
+ wss_new = (tn/N - fne/P)
296
+
297
+ return {
298
+ "wss": round(wss_old,4),
299
+ "awss": round(wss_new,4),
300
+ "R": round(recall,4),
301
+ "CM": CM
302
+ }
303
+
304
+
305
+
306
+
307
+ from sklearn.metrics import confusion_matrix
308
+ from torchmetrics import functional as fn
309
+ import matplotlib.pyplot as plt
310
+ from sklearn.metrics import roc_curve, auc
311
+ from sklearn.metrics import roc_auc_score
312
+ import ipywidgets as widgets
313
+ from IPython.display import HTML, display, clear_output
314
+ import matplotlib.pyplot as plt
315
+ import seaborn as sns
316
+ import warnings
317
+
318
+ warnings.simplefilter(action='ignore', category=FutureWarning)
319
+
320
+ def plot(logits, X_embedded, labels, tresh, show = True,
321
+ namefig = "plot", make_plot = True, print_stats = True, save = True):
322
+ col = pd.MultiIndex.from_tuples([
323
+ ("Predict", "0"),
324
+ ("Predict", "1")
325
+ ])
326
+ index = pd.MultiIndex.from_tuples([
327
+ ("Real", "0"),
328
+ ("Real", "1")
329
+ ])
330
+
331
+ predict = torch.sigmoid(logits).detach().clone()
332
+
333
+ roc_auc = dict()
334
+
335
+ fpr, tpr, thresholds = roc_curve(labels, predict.squeeze())
336
+
337
+ # Sem especificar o tresh
338
+ # WSS
339
+ ## indice do recall 0.95
340
+ idx_wss95 = sum(tpr < 0.95)
341
+ thresholds95 = thresholds[idx_wss95]
342
+
343
+ wss95_info = wss_calc(logits,labels, thresholds95 )
344
+ acc_wss95 = fn.accuracy(predict, labels, threshold=thresholds95)
345
+ f1_wss95 = fn.f1_score(predict, labels, threshold=thresholds95)
346
+
347
+
348
+ # Especificando o tresh
349
+ # Treshold avaliation
350
+
351
+
352
+ ## WSS
353
+ wss_info = wss_calc(logits,labels, tresh )
354
+ # Accuraci
355
+ acc_wssR = fn.accuracy(predict, labels, threshold=tresh)
356
+ f1_wssR = fn.f1_score(predict, labels, threshold=tresh)
357
+
358
+
359
+ metrics= {
360
+ # WSS
361
+ "WSS@95": wss95_info['wss'],
362
+ "AWSS@95": wss95_info['awss'],
363
+ "WSS@R": wss_info['wss'],
364
+ "AWSS@R": wss_info['awss'],
365
+ # Recall
366
+ "Recall_WSS@95": wss95_info['R'],
367
+ "Recall_WSS@R": wss_info['R'],
368
+ # acc
369
+ "acc@95": acc_wss95.item(),
370
+ "acc@R": acc_wssR.item(),
371
+ # f1
372
+ "f1@95": f1_wss95.item(),
373
+ "f1@R": f1_wssR.item(),
374
+ # treshould 95
375
+ "treshould@95": thresholds95
376
+ }
377
+
378
+ # print stats
379
+
380
+ if print_stats:
381
+ wss95= f"WSS@95:{wss95_info['wss']}, R: {wss95_info['R']}"
382
+ wss95_adj= f"ASSWSS@95:{wss95_info['awss']}"
383
+ print(wss95)
384
+ print(wss95_adj)
385
+ print('Acc.:', round(acc_wss95.item(), 4))
386
+ print('F1-score:', round(f1_wss95.item(), 4))
387
+ print(f"Treshold to wss95: {round(thresholds95, 4)}")
388
+ cm = pd.DataFrame(wss95_info['CM'],
389
+ index=index,
390
+ columns=col)
391
+
392
+ print("\nConfusion matrix:")
393
+ print(cm)
394
+ print("\n---Metrics with threshold:", tresh, "----\n")
395
+ wss= f"WSS@R:{wss_info['wss']}, R: {wss_info['R']}"
396
+ print(wss)
397
+ wss_adj= f"AWSS@R:{wss_info['awss']}"
398
+ print(wss_adj)
399
+ print('Acc.:', round(acc_wssR.item(), 4))
400
+ print('F1-score:', round(f1_wssR.item(), 4))
401
+ cm = pd.DataFrame(wss_info['CM'],
402
+ index=index,
403
+ columns=col)
404
+
405
+ print("\nConfusion matrix:")
406
+ print(cm)
407
+
408
+
409
+ # Graficos
410
+
411
+ if make_plot:
412
+
413
+ fig, axes = plt.subplots(1, 4, figsize=(25,10))
414
+ alpha = torch.squeeze(predict).numpy()
415
+
416
+ # plots
417
+
418
+ p1 = sns.scatterplot(x=X_embedded[:, 0],
419
+ y=X_embedded[:, 1],
420
+ hue=labels,
421
+ alpha=alpha, ax = axes[0]).set_title('Predictions-TSNE')
422
+
423
+ t_wss = predict >= thresholds95
424
+ t_wss = t_wss.squeeze().numpy()
425
+
426
+ p2 = sns.scatterplot(x=X_embedded[t_wss, 0],
427
+ y=X_embedded[t_wss, 1],
428
+ hue=labels[t_wss],
429
+ alpha=alpha[t_wss], ax = axes[1]).set_title('WSS@95')
430
+
431
+ t = predict >= tresh
432
+ t = t.squeeze().numpy()
433
+
434
+ p3 = sns.scatterplot(x=X_embedded[t, 0],
435
+ y=X_embedded[t, 1],
436
+ hue=labels[t],
437
+ alpha=alpha[t], ax = axes[2]).set_title(f'Predictions-Treshold {tresh}')
438
+
439
+
440
+ roc_auc = auc(fpr, tpr)
441
+ lw = 2
442
+
443
+ axes[3].plot(
444
+ fpr,
445
+ tpr,
446
+ color="darkorange",
447
+ lw=lw,
448
+ label="ROC curve (area = %0.2f)" % roc_auc)
449
+
450
+ axes[3].plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
451
+ axes[3].axhline(y=0.95, color='r', linestyle='-')
452
+ axes[3].set(xlabel="False Positive Rate", ylabel="True Positive Rate", title= "ROC")
453
+ axes[3].legend(loc="lower right")
454
+
455
+ if show:
456
+ plt.show()
457
+
458
+ if save:
459
+ fig.savefig(namefig, dpi=fig.dpi)
460
+
461
+ return metrics
462
+
463
+ def auc_plot(logits,labels, color = "darkorange", label = "test"):
464
+ predict = torch.sigmoid(logits).detach().clone()
465
+ fpr, tpr, thresholds = roc_curve(labels, predict.squeeze())
466
+ roc_auc = auc(fpr, tpr)
467
+ lw = 2
468
+
469
+ label = label + str(round(roc_auc,2))
470
+ # print(label)
471
+
472
+ plt.plot(
473
+ fpr,
474
+ tpr,
475
+ color=color,
476
+ lw=lw,
477
+ label= label
478
+ )
479
+ plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
480
+ plt.axhline(y=0.95, color='r', linestyle='-')
481
+
482
+
483
+ from sklearn.metrics import confusion_matrix
484
+ from torchmetrics import functional as fn
485
+ import matplotlib.pyplot as plt
486
+ from sklearn.metrics import roc_curve, auc
487
+ from sklearn.metrics import roc_auc_score
488
+ import ipywidgets as widgets
489
+ from IPython.display import HTML, display, clear_output
490
+ import matplotlib.pyplot as plt
491
+ import seaborn as sns
492
+ import warnings
493
+
494
+
495
+ class diagnosis():
496
+ def __init__(self, names, Valid_resource, batch_size_test, model,Info,start = 0):
497
+ self.names=names
498
+ self.Valid_resource=Valid_resource
499
+ self.batch_size_test=batch_size_test
500
+ self.model=model
501
+ self.start=start
502
+
503
+ self.value_trash = widgets.FloatText(
504
+ value=0.95,
505
+ description='tresh',
506
+ disabled=False
507
+ )
508
+
509
+ self.valueb = widgets.IntText(
510
+ value=10,
511
+ description='size',
512
+ disabled=False
513
+ )
514
+
515
+ self.train_b = widgets.Button(description="Train")
516
+ self.next_b = widgets.Button(description="Next")
517
+ self.eval_b = widgets.Button(description="Evaluation")
518
+
519
+ self.hbox = widgets.HBox([self.train_b, self.valueb])
520
+
521
+ self.next_b.on_click(self.Next_button)
522
+ self.train_b.on_click(self.Train_button)
523
+ self.eval_b.on_click(self.Evaluation_button)
524
+
525
+
526
+ # Next button
527
+ def Next_button(self,p):
528
+ clear_output()
529
+ self.i=self.i+1
530
+
531
+ # global domain
532
+ self.domain = names[self.i]
533
+ print("Name:", self.domain)
534
+
535
+ # global data
536
+ self.data = self.Valid_resource[self.Valid_resource['domain'] == self.domain]
537
+ print(self.data['label'].value_counts())
538
+
539
+ display(self.hbox)
540
+ display(self.next_b)
541
+
542
+ # Train button
543
+ def Train_button(self, y):
544
+ clear_output()
545
+ print(self.domain)
546
+
547
+ # Preparing data for training
548
+ self.data_train_loader, self.data_test_loader, self.data_train, self.data_test = prepare_data(self.data,
549
+ train_size_per_class = self.valueb.value,
550
+ batch_size = {'train': Info['inner_batch_size'],
551
+ 'test': batch_size_test},
552
+ max_seq_length = Info['max_seq_length'],
553
+ tokenizer = Info['tokenizer'],
554
+ input = "text",
555
+ output = "label")
556
+
557
+ self.logits, self.X_embedded, self.labels, self.features = train_loop(self.data_train_loader, self.data_test_loader,
558
+ model, device,
559
+ epoch = Info['inner_update_step'],
560
+ lr=Info['inner_update_lr'],
561
+ print_info=True,
562
+ name = self.domain)
563
+
564
+ tresh_box = widgets.HBox([self.eval_b, self.value_trash])
565
+ display(self.hbox)
566
+ display(tresh_box)
567
+ display(self.next_b)
568
+
569
+ # Evaluation button
570
+ def Evaluation_button(self, te):
571
+ clear_output()
572
+ tresh_box = widgets.HBox([self.eval_b, self.value_trash])
573
+
574
+ print(self.domain)
575
+ # print("\n")
576
+ print("-------Train data-------")
577
+ print(self.data_train['label'].value_counts())
578
+ print("-------Test data-------")
579
+ print(self.data_test['label'].value_counts())
580
+ # print("\n")
581
+
582
+ display(self.next_b)
583
+ display(tresh_box)
584
+ display(self.hbox)
585
+
586
+
587
+ metrics = plot(self.logits, self.X_embedded, self.labels,
588
+ tresh=Info['tresh'], show = True,
589
+ # namefig= "./"+base_path +"/"+"Results/size_layer/"+ name_domain+'/' +str(n_layers) + '/img/' + str(attempt) + 'plots',
590
+ namefig= 'test',
591
+ make_plot = True,
592
+ print_stats = True,
593
+ save=False)
594
+
595
+ def __call__(self):
596
+ self.i= self.start-1
597
+
598
+ clear_output()
599
+ display(self.next_b)