SS3M commited on
Commit
16b4b69
·
verified ·
1 Parent(s): da27686

Upload 4.2_add_span_rerank_branch_4.3's state dict

Browse files
.gitattributes CHANGED
@@ -92,3 +92,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
92
  20_sampled_instead_all_27/logs/20_sampled_instead_all_27_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
93
  4_margin_loss_C_4.1/logs/4_margin_loss_C_4.1_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
94
  4.1_entities_no_ce_4.2/logs/4.1_entities_no_ce_4.2_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
 
 
92
  20_sampled_instead_all_27/logs/20_sampled_instead_all_27_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
93
  4_margin_loss_C_4.1/logs/4_margin_loss_C_4.1_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
94
  4.1_entities_no_ce_4.2/logs/4.1_entities_no_ce_4.2_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
95
+ 4.2_add_span_rerank_branch_4.3/logs/4.2_add_span_rerank_branch_4.3_log_plot.jpg filter=lfs diff=lfs merge=lfs -text
4.2_add_span_rerank_branch_4.3/4.2_add_span_rerank_branch_4.3.py ADDED
@@ -0,0 +1,2692 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [code]
2
+ get_ipython().system('pip install evaluate seqeval underthesea positional-encodings[pytorch] pytorch-crf')
3
+
4
+ # %% [code]
5
+ import warnings
6
+ warnings.filterwarnings('ignore')
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.optim as optim
11
+ from torch.utils.data import Dataset, TensorDataset, DataLoader
12
+ import torch.nn.functional as F
13
+ import albumentations as albu
14
+ from transformers import AutoTokenizer, AutoModel
15
+ import torch.distributed as dist
16
+ from torch.nn.parallel import DistributedDataParallel as DDP
17
+ from positional_encodings.torch_encodings import PositionalEncoding1D
18
+ from torchcrf import CRF
19
+
20
+ from sklearn.metrics import f1_score
21
+ from sklearn.preprocessing import MinMaxScaler, StandardScaler
22
+ from scipy.spatial.transform import Rotation as R
23
+ from sklearn.model_selection import KFold, StratifiedGroupKFold, GroupKFold, StratifiedKFold
24
+ from sklearn.metrics import precision_recall_fscore_support
25
+ from timm.utils import ModelEmaV3
26
+ import timm
27
+
28
+ import os
29
+ import gc
30
+ import json
31
+ from pathlib import Path
32
+ import pickle
33
+ from tqdm.auto import tqdm
34
+ import copy
35
+ import numpy as np
36
+ import pandas as pd
37
+ import polars as pl
38
+ from PIL import Image
39
+ import time
40
+ from tqdm import tqdm
41
+ from matplotlib import pyplot as plt
42
+ import seaborn as sns
43
+ from multiprocessing import Manager as MemoryManager
44
+ from functools import lru_cache
45
+ import shutil
46
+ import glob
47
+ import cv2
48
+ import random
49
+ import re
50
+ import joblib
51
+ import math
52
+ from huggingface_hub import HfApi, snapshot_download
53
+ import evaluate
54
+ from underthesea import word_tokenize as vi_tokenize_tool
55
+ import spacy
56
+ en_tokenize_tool = spacy.load("en_core_web_sm")
57
+ from collections import defaultdict, Counter
58
+
59
+ # %% [code]
60
+ # Global config
61
+ SEEDS = [26092004]
62
+ topk = 1
63
+ nfolds = 5
64
+ only_fold_idx = 0
65
+ test_only = 0
66
+ debug_only = 0
67
+
68
+ # Config thư mục
69
+ dataset = 'kltn/only_entities' # conll003, ontonotes, phoner, vietbio, vietmed, vimed, kltn/only_entities, kltn/raw
70
+ root_dir = f'/kaggle/input/notebooks/sambui22022517/kltn-data/{dataset}' ## Thư mục chứa file train, val, test
71
+ train_dir = f'{root_dir}'
72
+ # val_dir = f'{root_dir}/val'
73
+ test_dir = f'{root_dir}'
74
+
75
+ # Config checkpoints
76
+
77
+ # Config training
78
+ epochs = 18 if not debug_only else 2
79
+ batch_size = 32
80
+ device = "cuda" if torch.cuda.is_available() else "cpu"
81
+ # # Thêm biến toàn cục nào đó vào đây
82
+ repo_name = 'SS3M/kltn-experiments'
83
+ state_dict_save_name = "4.2_add_span_rerank_branch_4.3"
84
+ checkpoints_dir = state_dict_save_name
85
+ pretrained_dir = "/kaggle/working"
86
+ os.makedirs(f'{checkpoints_dir}', exist_ok=True)
87
+
88
+ backbone_model_name = "bert-base-uncased" if dataset in ["conll003", "ontonotes"] else "vinai/phobert-base"
89
+ word_tokenize = lambda text: [token.text for token in en_tokenize_tool(text)] if dataset == dataset in ["conll003", "ontonotes"] else vi_tokenize_tool(text)
90
+ max_len_dict = {
91
+ 'kltn/raw': 256,
92
+ 'kltn/only_entities': 68,
93
+ 'conll003': 46,
94
+ 'ontonotes': 61,
95
+ 'phoner': 68,
96
+ 'vietbio': 125,
97
+ 'vietmed': 36,
98
+ 'vimed': 100,
99
+ }
100
+ zero_entities_rate_dict = {
101
+ 'kltn/raw': 1000,
102
+ 'kltn/only_entities': 0.2,
103
+ 'conll003': 1000, # mean keep all zero-entities samples
104
+ 'ontonotes': 1000,
105
+ 'phoner': 1000,
106
+ 'vietbio': 1000,
107
+ 'vietmed': 1000,
108
+ 'vimed': 1000,
109
+ }
110
+
111
+ max_len = max_len_dict[dataset]
112
+ max_n_parts = 1
113
+ max_span_len = 10
114
+ zero_entities_rate = zero_entities_rate_dict[dataset]
115
+
116
+ # Trainer
117
+ trainer_params = {
118
+ "training_time": "00:11:30:00",
119
+ "eval_mode": "max",
120
+ "topk": topk,
121
+ "save_name": state_dict_save_name,
122
+ "save_best": True,
123
+ "save_last": True,
124
+ "device": device,
125
+ "logging": True,
126
+ "logging_file": True,
127
+ "checkpoints_dir": checkpoints_dir,
128
+ "early_stopping": 30,
129
+ "eval_from_ratio": 0.4,
130
+ "eval_every": 1,
131
+ "schedule_in_step": False,
132
+ "use_ema": True,
133
+ "ema_from_ratio": 0.3,
134
+ "ema_decay": 0.9995,
135
+ "max_grad_norm": 200.0,
136
+ "return_best": True,
137
+ "return_last": True,
138
+ }
139
+
140
+ # Memory
141
+ train_memory_params = {
142
+ 'max_len': max_len,
143
+ 'max_n_parts': max_n_parts,
144
+ }
145
+ val_memory_params = {
146
+ 'max_len': max_len,
147
+ 'max_n_parts': max_n_parts,
148
+ }
149
+
150
+ # Data Loader
151
+ def seed_worker(worker_id):
152
+ worker_seed = torch.initial_seed() % 2**32
153
+ np.random.seed(worker_seed)
154
+ random.seed(worker_seed)
155
+
156
+ train_loader_params = {
157
+ 'batch_size': batch_size,
158
+ 'shuffle': True,
159
+ 'pin_memory':True,
160
+ 'num_workers': 2,
161
+ 'drop_last': False,
162
+ 'worker_init_fn': seed_worker,
163
+ 'persistent_workers': False,
164
+ }
165
+ val_loader_params = {
166
+ 'batch_size': batch_size,
167
+ 'shuffle': False,
168
+ 'pin_memory':True,
169
+ 'num_workers': 1,
170
+ 'drop_last': False,
171
+ 'worker_init_fn': seed_worker,
172
+ 'persistent_workers': False,
173
+ }
174
+
175
+ # Model
176
+ model_params = {
177
+ 'backbone_model_name': backbone_model_name,
178
+ 'max_r': 6,
179
+ }
180
+
181
+ # Loss Func
182
+ loss_func_params = {
183
+ 'lambda_span': 1.0,
184
+ 'lambda_margin': 1.0,
185
+ 'margin': 0.5,
186
+ }
187
+ eval_func_params = {}
188
+
189
+ # Optim
190
+ optim_params = {
191
+ 'name': 'AdamW',
192
+ 'lr': 1e-4,
193
+ 'weight_decay': 1e-4,
194
+ }
195
+ scheduler_params = {
196
+ 'name': 'CosineAnnealingLR',
197
+ 'T_max': 20, # Số epoch để hoàn thành một chu kỳ giảm LR
198
+ 'eta_min': 1e-6 # Learning rate nhỏ nhất trong chu kỳ
199
+ }
200
+
201
+ # %% [code]
202
+ def set_seed(seed=42):
203
+ random.seed(seed)
204
+ np.random.seed(seed)
205
+ torch.manual_seed(seed)
206
+ torch.cuda.manual_seed(seed)
207
+ torch.cuda.manual_seed_all(seed) # if using multi-GPU
208
+ torch.use_deterministic_algorithms(False)
209
+ torch.backends.cudnn.deterministic = True
210
+ torch.backends.cudnn.benchmark = False
211
+ os.environ['PYTHONHASHSEED'] = str(seed)
212
+
213
+ # %% [code]
214
+ class CustomLoss(nn.Module):
215
+ def __init__(
216
+ self,
217
+ lambda_margin=1.0,
218
+ lambda_span=1.0,
219
+ margin=1.0
220
+ ):
221
+ super().__init__()
222
+
223
+ self.lambda_margin = lambda_margin
224
+ self.lambda_span = lambda_span
225
+ self.margin = margin
226
+
227
+ # =========================================================
228
+ # token margin
229
+ # =========================================================
230
+
231
+ def margin_loss(self, logits, labels):
232
+ """
233
+ logits: (N, C)
234
+ labels: (N,)
235
+ """
236
+
237
+ valid_mask = labels != -100
238
+
239
+ logits = logits[valid_mask]
240
+ labels = labels[valid_mask]
241
+
242
+ if len(labels) == 0:
243
+ return logits.new_tensor(0.0)
244
+
245
+ # positive logit
246
+ pos_logit = logits.gather(
247
+ dim=1,
248
+ index=labels.unsqueeze(-1)
249
+ ).squeeze(-1)
250
+
251
+ # hardest negative
252
+ neg_logits = logits.clone()
253
+
254
+ neg_logits.scatter_(
255
+ 1,
256
+ labels.unsqueeze(-1),
257
+ float("-inf")
258
+ )
259
+
260
+ hardest_neg = neg_logits.max(dim=-1).values
261
+
262
+ # margin ranking
263
+ loss = F.relu(
264
+ self.margin - pos_logit + hardest_neg
265
+ )
266
+
267
+ return loss.mean()
268
+
269
+ # =========================================================
270
+ # span listwise margin
271
+ # =========================================================
272
+
273
+ def span_loss(self, span_scores, labels):
274
+ """
275
+ span_scores: (B, N, M)
276
+ labels: (B, N, M)
277
+
278
+ labels:
279
+ -100 -> ignore
280
+ 0 -> negative
281
+ >0 -> positive
282
+ """
283
+
284
+ device = span_scores.device
285
+
286
+ B, N, M = span_scores.shape
287
+ if N == 0 or M == 0:
288
+ return span_scores.new_tensor(0.0)
289
+
290
+ # =====================================================
291
+ # masks
292
+ # =====================================================
293
+
294
+ valid_mask = labels != -100
295
+ pos_mask = labels > 0
296
+ neg_mask = labels == 0
297
+
298
+ # =====================================================
299
+ # mỗi N phải có ít nhất 1 positive
300
+ # =====================================================
301
+
302
+ has_pos = pos_mask.any(dim=-1)
303
+ # (B, N)
304
+
305
+ # =====================================================
306
+ # masked scores
307
+ # =====================================================
308
+ pos_scores = span_scores.masked_fill(
309
+ ~pos_mask,
310
+ float("-inf")
311
+ )
312
+
313
+ neg_scores = span_scores.masked_fill(
314
+ ~neg_mask,
315
+ float("-inf")
316
+ )
317
+
318
+ # =====================================================
319
+ # hardest positive / hardest negative
320
+ # =====================================================
321
+
322
+ best_pos = pos_scores.max(dim=-1).values
323
+ # (B, N)
324
+
325
+ best_neg = neg_scores.max(dim=-1).values
326
+ # (B, N)
327
+
328
+ # =====================================================
329
+ # nếu không có negative
330
+ # =====================================================
331
+
332
+ no_neg = ~neg_mask.any(dim=-1)
333
+
334
+ best_neg = torch.where(
335
+ no_neg,
336
+ torch.zeros_like(best_neg),
337
+ best_neg
338
+ )
339
+
340
+ # =====================================================
341
+ # margin ranking
342
+ # =====================================================
343
+
344
+ loss = F.relu(
345
+ self.margin - best_pos + best_neg
346
+ )
347
+
348
+ # =====================================================
349
+ # chỉ tính những N có positive
350
+ # =====================================================
351
+
352
+ loss = loss[has_pos]
353
+
354
+ if loss.numel() == 0:
355
+ return span_scores.new_tensor(0.0)
356
+
357
+ return loss.mean()
358
+
359
+ # =========================================================
360
+ # forward
361
+ # =========================================================
362
+
363
+ def forward(
364
+ self,
365
+ start_logits, start_labels,
366
+ end_logits, end_labels,
367
+ span_scores, labels
368
+ ):
369
+
370
+ # =====================================================
371
+ # flatten token logits
372
+ # =====================================================
373
+
374
+ B, L, C = start_logits.shape
375
+
376
+ start_logits_flat = start_logits.view(B * L, C)
377
+ start_labels_flat = start_labels.view(-1)
378
+
379
+ end_logits_flat = end_logits.view(B * L, C)
380
+ end_labels_flat = end_labels.view(-1)
381
+
382
+ # =====================================================
383
+ # token margin
384
+ # =====================================================
385
+
386
+ start_margin = self.margin_loss(
387
+ start_logits_flat,
388
+ start_labels_flat
389
+ )
390
+
391
+ end_margin = self.margin_loss(
392
+ end_logits_flat,
393
+ end_labels_flat
394
+ )
395
+
396
+ token_margin_loss = (
397
+ start_margin + end_margin
398
+ )
399
+
400
+ # =====================================================
401
+ # span loss
402
+ # =====================================================
403
+
404
+ span_loss = self.span_loss(
405
+ span_scores,
406
+ labels
407
+ )
408
+
409
+ # =====================================================
410
+ # total
411
+ # =====================================================
412
+
413
+ total_loss = (
414
+ self.lambda_margin * token_margin_loss
415
+ +
416
+ self.lambda_span * span_loss
417
+ )
418
+
419
+ return {
420
+ "total": total_loss,
421
+ "token_margin_loss": token_margin_loss,
422
+ "span_loss": span_loss,
423
+ "start_margin": start_margin,
424
+ "end_margin": end_margin,
425
+ }
426
+
427
+ # %% [code]
428
+ ## Viết eval_fn vào đây
429
+
430
+ # Bỏ hết eval_fn và trọng số vào đây
431
+ class CustomEvalFn(nn.Module):
432
+ def __init__(self):
433
+ super().__init__()
434
+
435
+ def compute_f1(self, tp, fp, fn):
436
+ precision = tp / (tp + fp + 1e-8)
437
+ recall = tp / (tp + fn + 1e-8)
438
+ f1 = 2 * precision * recall / (precision + recall + 1e-8)
439
+ return precision, recall, f1
440
+
441
+ def forward(self, pred, gold):
442
+ pred_set = set(pred)
443
+ gold_set = set(gold)
444
+
445
+ tp = len(pred_set & gold_set)
446
+ fp = len(pred_set - gold_set)
447
+ fn = len(gold_set - pred_set)
448
+
449
+ precision, recall, f1 = self.compute_f1(tp, fp, fn)
450
+
451
+ return {
452
+ f"precision": precision,
453
+ f"recall": recall,
454
+ f"f1": f1,
455
+ }
456
+
457
+ class SpanErrorAnalyzer:
458
+ def __init__(self, pad_token_id=0):
459
+ self.pad_token_id = pad_token_id
460
+
461
+ # ===== helper =====
462
+ def _to_set(self, data):
463
+ """
464
+ data: list of (b, tuple(ids))
465
+ -> dict[b] = set(tuple(ids))
466
+ """
467
+ res = defaultdict(set)
468
+ for b, ids in data:
469
+ ids = tuple([i for i in ids if i != self.pad_token_id])
470
+ if len(ids) > 0:
471
+ res[b].add(ids)
472
+ return res
473
+
474
+ def _iou(self, a, b):
475
+ """
476
+ a, b: tuple(ids)
477
+ """
478
+ set_a, set_b = set(a), set(b)
479
+ inter = len(set_a & set_b)
480
+ union = len(set_a | set_b)
481
+ if union == 0:
482
+ return 0.0
483
+ return inter / union
484
+
485
+ def _boundary_error(self, pred, gold):
486
+ """
487
+ đo lệch boundary dựa trên overlap prefix/suffix
488
+ """
489
+ # left match
490
+ left = 0
491
+ for i in range(min(len(pred), len(gold))):
492
+ if pred[i] == gold[i]:
493
+ left += 1
494
+ else:
495
+ break
496
+
497
+ # right match
498
+ right = 0
499
+ for i in range(1, min(len(pred), len(gold)) + 1):
500
+ if pred[-i] == gold[-i]:
501
+ right += 1
502
+ else:
503
+ break
504
+
505
+ return {
506
+ "left_match": left,
507
+ "right_match": right,
508
+ "pred_len": len(pred),
509
+ "gold_len": len(gold),
510
+ }
511
+
512
+ # ===== main =====
513
+ def analyze(self, preds, golds):
514
+ pred_map = self._to_set(preds)
515
+ gold_map = self._to_set(golds)
516
+
517
+ all_batches = set(pred_map.keys()) | set(gold_map.keys())
518
+
519
+ stats = Counter()
520
+
521
+ detailed_errors = []
522
+
523
+ for b in all_batches:
524
+ pset = pred_map.get(b, set())
525
+ gset = gold_map.get(b, set())
526
+
527
+ matched_gold = set()
528
+
529
+ # ===== check predictions =====
530
+ for p in pset:
531
+ if p in gset:
532
+ stats["exact_match"] += 1
533
+ matched_gold.add(p)
534
+ else:
535
+ # tìm gold gần nhất
536
+ best_iou = 0
537
+ best_g = None
538
+
539
+ for g in gset:
540
+ iou = self._iou(p, g)
541
+ if iou > best_iou:
542
+ best_iou = iou
543
+ best_g = g
544
+
545
+ if best_iou > 0:
546
+ stats["partial_match"] += 1
547
+
548
+ boundary = self._boundary_error(p, best_g)
549
+
550
+ detailed_errors.append({
551
+ "type": "boundary_error",
552
+ "batch": b,
553
+ "pred": p,
554
+ "gold": best_g,
555
+ "iou": best_iou,
556
+ **boundary
557
+ })
558
+ else:
559
+ if b not in gold_map:
560
+ stats["no_event_sample"] += 1
561
+ err_type = "no_event_sample"
562
+ else:
563
+ stats["completely_wrong"] += 1
564
+ err_type = "completely_wrong"
565
+
566
+ detailed_errors.append({
567
+ "type": err_type,
568
+ "batch": b,
569
+ "pred": p
570
+ })
571
+
572
+ # ===== check missing =====
573
+ for g in gset:
574
+ if g not in matched_gold:
575
+ # check if any pred overlaps
576
+ overlap = any(self._iou(p, g) > 0 for p in pset)
577
+
578
+ if overlap:
579
+ stats["miss_with_overlap"] += 1
580
+ else:
581
+ stats["miss"] += 1
582
+
583
+ detailed_errors.append({
584
+ "type": "miss",
585
+ "batch": b,
586
+ "gold": g
587
+ })
588
+
589
+ return {
590
+ "summary": {
591
+ "exact_match": (stats["exact_match"], stats["exact_match"] / len(preds)),
592
+ "partial_match": (stats["partial_match"], stats["partial_match"] / len(preds)),
593
+ "no_event_sample": (stats["no_event_sample"], stats["no_event_sample"] / len(preds)),
594
+ "completely_wrong": (stats["completely_wrong"], stats["completely_wrong"] / len(preds)),
595
+ "miss": (stats["miss"], stats["miss"] / len(golds)),
596
+ "miss_with_overlap": (stats["miss_with_overlap"], stats["miss_with_overlap"] / len(golds)),
597
+ },
598
+ "details": detailed_errors
599
+ }
600
+
601
+ # %% [code]
602
+ class DataParallelProxy(nn.DataParallel):
603
+
604
+ def __getattr__(self, name):
605
+ try:
606
+ return super().__getattr__(name)
607
+
608
+ except AttributeError:
609
+
610
+ attr = getattr(self.module, name)
611
+
612
+ if callable(attr):
613
+
614
+ def wrapper(*args, **kwargs):
615
+ return self._parallel_apply_method(
616
+ name,
617
+ *args,
618
+ **kwargs
619
+ )
620
+
621
+ return wrapper
622
+
623
+ return attr
624
+
625
+ # =========================================================
626
+ # parallel custom method
627
+ # =========================================================
628
+
629
+ def _parallel_apply_method(self, method_name, *inputs, **kwargs):
630
+
631
+ if not self.device_ids:
632
+ return getattr(self.module, method_name)(*inputs, **kwargs)
633
+
634
+ inputs_scattered, kwargs_scattered = self.scatter(
635
+ inputs,
636
+ kwargs,
637
+ self.device_ids
638
+ )
639
+
640
+ replicas = self.replicate(
641
+ self.module,
642
+ self.device_ids[:len(inputs_scattered)]
643
+ )
644
+
645
+ outputs = self.parallel_apply(
646
+ [getattr(replica, method_name) for replica in replicas],
647
+ inputs_scattered,
648
+ kwargs_scattered
649
+ )
650
+
651
+ return self._custom_gather(
652
+ outputs,
653
+ self.output_device
654
+ )
655
+
656
+ # =========================================================
657
+ # OVERRIDE FORWARD GATHER
658
+ # =========================================================
659
+
660
+ def gather(self, outputs, output_device):
661
+
662
+ return self._custom_gather(
663
+ outputs,
664
+ output_device
665
+ )
666
+
667
+ # =========================================================
668
+ # recursive gather
669
+ # =========================================================
670
+
671
+ def _custom_gather(self, outputs, output_device):
672
+
673
+ first = outputs[0]
674
+
675
+ # =====================================================
676
+ # tensor
677
+ # =====================================================
678
+
679
+ if torch.is_tensor(first):
680
+
681
+ return self._gather_tensor(
682
+ outputs,
683
+ output_device
684
+ )
685
+
686
+ # =====================================================
687
+ # tuple
688
+ # =====================================================
689
+
690
+ if isinstance(first, tuple):
691
+
692
+ return tuple(
693
+ self._custom_gather(
694
+ list(items),
695
+ output_device
696
+ )
697
+ for items in zip(*outputs)
698
+ )
699
+
700
+ # =====================================================
701
+ # list
702
+ # =====================================================
703
+
704
+ if isinstance(first, list):
705
+
706
+ # list[tensor]
707
+ if len(first) > 0 and torch.is_tensor(first[0]):
708
+
709
+ return self._gather_tensor_list(
710
+ outputs,
711
+ output_device
712
+ )
713
+
714
+ merged = []
715
+
716
+ for out in outputs:
717
+ merged.extend(out)
718
+
719
+ return merged
720
+
721
+ # =====================================================
722
+ # dict
723
+ # =====================================================
724
+
725
+ if isinstance(first, dict):
726
+
727
+ return {
728
+ k: self._custom_gather(
729
+ [o[k] for o in outputs],
730
+ output_device
731
+ )
732
+ for k in first.keys()
733
+ }
734
+
735
+ # =====================================================
736
+ # fallback
737
+ # =====================================================
738
+
739
+ return outputs
740
+
741
+ # =========================================================
742
+ # gather tensor with auto pad
743
+ # =========================================================
744
+
745
+ def _gather_tensor(self, tensors, output_device):
746
+
747
+ # move same device first
748
+ tensors = [
749
+ t.to(output_device)
750
+ for t in tensors
751
+ ]
752
+
753
+ # =====================================================
754
+ # fast path
755
+ # =====================================================
756
+
757
+ try:
758
+ return torch.cat(tensors, dim=0)
759
+
760
+ except RuntimeError:
761
+ pass
762
+
763
+ # =====================================================
764
+ # auto max shape
765
+ # =====================================================
766
+
767
+ max_shape = list(tensors[0].shape)
768
+
769
+ for t in tensors[1:]:
770
+
771
+ for d in range(len(max_shape)):
772
+
773
+ max_shape[d] = max(
774
+ max_shape[d],
775
+ t.shape[d]
776
+ )
777
+
778
+ # =====================================================
779
+ # pad tensors
780
+ # =====================================================
781
+
782
+ padded = []
783
+
784
+ for t in tensors:
785
+
786
+ pad = []
787
+
788
+ # reverse order for F.pad
789
+ for d in reversed(range(len(max_shape))):
790
+
791
+ # never pad batch dim
792
+ if d == 0:
793
+ pad.extend([0, 0])
794
+ continue
795
+
796
+ diff = max_shape[d] - t.shape[d]
797
+
798
+ pad.extend([0, diff])
799
+
800
+ t = F.pad(t, pad)
801
+
802
+ padded.append(t)
803
+
804
+ return torch.cat(padded, dim=0)
805
+
806
+ # =========================================================
807
+ # list[tensor]
808
+ # =========================================================
809
+
810
+ def _gather_tensor_list(self, outputs, output_device):
811
+
812
+ merged = []
813
+
814
+ for out in outputs:
815
+ merged.extend(out)
816
+
817
+ return self._gather_tensor(
818
+ merged,
819
+ output_device
820
+ )
821
+
822
+ # %% [code]
823
+ ## Viết cấu trúc model vào đây
824
+ def extract_spans_and_labels(start_logits, end_logits):
825
+ """
826
+ Args:
827
+ start_logits: Tensor (B, L, C)
828
+ end_logits: Tensor (B, L, C)
829
+
830
+ Returns:
831
+ spans: Tensor (B, N, 2)
832
+ padding = (0, 0)
833
+
834
+ labels: Tensor (B, N)
835
+ padding = -100
836
+
837
+ Nếu không extract được span nào:
838
+ spans.shape = (B, 0, 2)
839
+ labels.shape = (B, 0)
840
+ """
841
+
842
+ start_labels = start_logits.argmax(dim=-1) # (B, L)
843
+ end_labels = end_logits.argmax(dim=-1) # (B, L)
844
+
845
+ B, L = start_labels.shape
846
+
847
+ batch_spans = []
848
+ batch_labels = []
849
+
850
+ max_n = 0
851
+
852
+ for bidx in range(B):
853
+
854
+ used_start = set()
855
+ used_end = set()
856
+
857
+ spans = []
858
+ labels = []
859
+
860
+ for s in range(L):
861
+
862
+ s_label = start_labels[bidx, s].item()
863
+
864
+ # bỏ qua O
865
+ if s_label == 0:
866
+ continue
867
+
868
+ if s in used_start:
869
+ continue
870
+
871
+ nearest_e = None
872
+
873
+ # tìm end gần nhất cùng class
874
+ for e in range(s, L):
875
+
876
+ if e in used_end:
877
+ continue
878
+
879
+ e_label = end_labels[bidx, e].item()
880
+
881
+ if e_label == s_label:
882
+ nearest_e = e
883
+ break
884
+
885
+ if nearest_e is None:
886
+ continue
887
+
888
+ used_start.add(s)
889
+ used_end.add(nearest_e)
890
+
891
+ spans.append([s, nearest_e])
892
+ labels.append(s_label)
893
+
894
+ batch_spans.append(spans)
895
+ batch_labels.append(labels)
896
+
897
+ max_n = max(max_n, len(spans))
898
+
899
+ # =========================================================
900
+ # không extract được gì
901
+ # =========================================================
902
+
903
+ if max_n == 0:
904
+
905
+ spans = torch.empty(
906
+ (B, 0, 2),
907
+ dtype=torch.long,
908
+ device=start_logits.device
909
+ )
910
+
911
+ labels = torch.empty(
912
+ (B, 0),
913
+ dtype=torch.long,
914
+ device=start_logits.device
915
+ )
916
+
917
+ return spans, labels
918
+
919
+ # =========================================================
920
+ # padding
921
+ # =========================================================
922
+
923
+ padded_spans = []
924
+ padded_labels = []
925
+
926
+ for spans, labels in zip(batch_spans, batch_labels):
927
+
928
+ pad_n = max_n - len(spans)
929
+
930
+ spans = spans + [[0, 0]] * pad_n
931
+ labels = labels + [-100] * pad_n
932
+
933
+ padded_spans.append(spans)
934
+ padded_labels.append(labels)
935
+
936
+ spans = torch.tensor(
937
+ padded_spans,
938
+ dtype=torch.long,
939
+ device=start_logits.device
940
+ ) # (B, N, 2)
941
+
942
+ labels = torch.tensor(
943
+ padded_labels,
944
+ dtype=torch.long,
945
+ device=start_logits.device
946
+ ) # (B, N)
947
+
948
+ return spans, labels
949
+
950
+ def expand_spans(spans, r, attention_mask):
951
+ """
952
+ Args:
953
+ spans: Tensor (B, N, 2)
954
+ r: Tensor (B, N)
955
+ attention_mask: Tensor (B, L)
956
+
957
+ Returns:
958
+ expanded_spans: Tensor (B, N, M, 2)
959
+
960
+ padding = (0, 0)
961
+
962
+ Nếu N = 0:
963
+ return shape (B, 0, 0, 2)
964
+ """
965
+
966
+ device = spans.device
967
+
968
+ B, N, _ = spans.shape
969
+ L = attention_mask.shape[1]
970
+
971
+ # =========================================================
972
+ # empty case
973
+ # =========================================================
974
+
975
+ if N == 0:
976
+ return torch.empty(
977
+ (B, 0, 0, 2),
978
+ dtype=torch.long,
979
+ device=device
980
+ )
981
+
982
+ max_r = r.max().item()
983
+ M = 2 * max_r + 1
984
+
985
+ # =========================================================
986
+ # shifts: [-max_r, ..., 0, ..., +max_r]
987
+ # =========================================================
988
+
989
+ shifts = torch.arange(
990
+ -max_r,
991
+ max_r + 1,
992
+ device=device
993
+ ) # (M,)
994
+
995
+ # =========================================================
996
+ # expand
997
+ # =========================================================
998
+
999
+ spans = spans.unsqueeze(2) # (B, N, 1, 2)
1000
+
1001
+ expanded = spans + shifts.view(1, 1, M, 1)
1002
+ # (B, N, M, 2)
1003
+
1004
+ expanded_s = expanded[..., 0]
1005
+ expanded_e = expanded[..., 1]
1006
+
1007
+ # =========================================================
1008
+ # valid shift range
1009
+ # =========================================================
1010
+
1011
+ valid_shift = (
1012
+ shifts.view(1, 1, M).abs()
1013
+ <= r.unsqueeze(-1)
1014
+ )
1015
+
1016
+ # =========================================================
1017
+ # original span != padding
1018
+ # =========================================================
1019
+
1020
+ non_pad_span = (
1021
+ spans.squeeze(2).sum(dim=-1) != 0
1022
+ ) # (B, N)
1023
+
1024
+ # =========================================================
1025
+ # boundary check
1026
+ # =========================================================
1027
+
1028
+ valid_length = attention_mask.sum(dim=-1)
1029
+ valid_length = valid_length.view(B, 1, 1)
1030
+
1031
+ valid_boundary = (
1032
+ (expanded_s > 0)
1033
+ &
1034
+ (expanded_e < valid_length)
1035
+ )
1036
+
1037
+ # =========================================================
1038
+ # final mask
1039
+ # =========================================================
1040
+
1041
+ valid_mask = (
1042
+ valid_shift
1043
+ &
1044
+ non_pad_span.unsqueeze(-1)
1045
+ &
1046
+ valid_boundary
1047
+ )
1048
+
1049
+ # =========================================================
1050
+ # invalid -> (0,0)
1051
+ # =========================================================
1052
+
1053
+ expanded = torch.where(
1054
+ valid_mask.unsqueeze(-1),
1055
+ expanded,
1056
+ torch.zeros_like(expanded)
1057
+ )
1058
+
1059
+ return expanded
1060
+
1061
+ def get_span_reprs(hidden, spans):
1062
+ """
1063
+ Args:
1064
+ hidden: (B, L, H)
1065
+ spans: (B, N, M, 2)
1066
+
1067
+ Return:
1068
+ span_repr: (B, N, M, 4H)
1069
+
1070
+ Nếu N = 0:
1071
+ return shape (B, 0, 0, 4H)
1072
+ """
1073
+
1074
+ B, N, M, _ = spans.shape
1075
+ H = hidden.size(-1)
1076
+
1077
+ # =========================================================
1078
+ # empty case
1079
+ # =========================================================
1080
+
1081
+ if N == 0:
1082
+ return torch.empty(
1083
+ (B, 0, 0, 4 * H),
1084
+ dtype=hidden.dtype,
1085
+ device=hidden.device
1086
+ )
1087
+
1088
+ batch_idx = torch.arange(
1089
+ B,
1090
+ device=hidden.device
1091
+ ).view(B, 1, 1)
1092
+
1093
+ start_idx = spans[..., 0] # (B, N, M)
1094
+ end_idx = spans[..., 1] # (B, N, M)
1095
+
1096
+ # =========================================================
1097
+ # gather hidden
1098
+ # =========================================================
1099
+
1100
+ start_h = hidden[batch_idx, start_idx]
1101
+ end_h = hidden[batch_idx, end_idx]
1102
+
1103
+ # (B, N, M, H)
1104
+
1105
+ # =========================================================
1106
+ # span representation
1107
+ # =========================================================
1108
+
1109
+ span_repr = torch.cat(
1110
+ [
1111
+ start_h,
1112
+ end_h,
1113
+ end_h - start_h,
1114
+ end_h * start_h
1115
+ ],
1116
+ dim=-1
1117
+ ) # (B, N, M, 4H)
1118
+
1119
+ return span_repr
1120
+
1121
+ class MLP(nn.Module):
1122
+ def __init__(self, in_size, hid_size, out_size):
1123
+ super().__init__()
1124
+ self.mlp = nn.Sequential(
1125
+ nn.Linear(in_size, hid_size),
1126
+ nn.ReLU(),
1127
+ nn.Linear(hid_size, out_size)
1128
+ )
1129
+
1130
+ def forward(self, x):
1131
+ return self.mlp(x)
1132
+
1133
+ class IEModel(nn.Module):
1134
+ def __init__(self, backbone_model_name, num_labels, max_r):
1135
+ super().__init__()
1136
+ self.max_r = max_r
1137
+
1138
+ self.encoder = AutoModel.from_pretrained(backbone_model_name)
1139
+ hidden_size = self.encoder.config.hidden_size
1140
+
1141
+ self.start_classifier = MLP(hidden_size, hidden_size, num_labels)
1142
+ self.end_classifier = MLP(hidden_size, hidden_size, num_labels)
1143
+
1144
+ self.span_scorer = MLP(4*hidden_size, hidden_size, 1)
1145
+
1146
+ def encode(self, input_ids, attention_mask):
1147
+ B, n_parts, L = input_ids.shape
1148
+ input_ids = input_ids.view(-1, L)
1149
+ attention_mask = attention_mask.view(-1, L)
1150
+
1151
+ outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
1152
+ hidden_states = outputs.last_hidden_state # B * n_parts, L, H
1153
+
1154
+ hidden_states = hidden_states.view(B, n_parts, L, -1).reshape(B, n_parts*L, -1) # B, L, H
1155
+ attention_mask = attention_mask.view(B, n_parts, L).reshape(B, n_parts*L) # B, L
1156
+ return hidden_states, attention_mask
1157
+
1158
+ def get_logits(self, hidden_states):
1159
+ start_logits = self.start_classifier(hidden_states) # B, N, classes
1160
+ end_logits = self.end_classifier(hidden_states) # B, N, classes
1161
+ return start_logits, end_logits
1162
+
1163
+ def get_scores(self, expanded_span_reprs):
1164
+ return self.span_scorer(expanded_span_reprs).squeeze(-1)
1165
+
1166
+ def forward(self, input_ids, attention_mask, spans=None):
1167
+ hidden_states, attention_mask = self.encode(input_ids, attention_mask)
1168
+ start_logits, end_logits = self.get_logits(hidden_states)
1169
+
1170
+ if spans is None:
1171
+ spans, _ = extract_spans_and_labels(start_logits, end_logits) # (B, N, 2)
1172
+ B, N, _ = spans.shape
1173
+ r = torch.randint(1, self.max_r+1, (B, N), device=spans.device)
1174
+ expanded_spans = expand_spans(spans, r, attention_mask) # (B, N, M, 2)
1175
+
1176
+ expanded_span_reprs = get_span_reprs(hidden_states, expanded_spans) # (B, N, M, 4H)
1177
+ expanded_span_scores = self.get_scores(expanded_span_reprs) # (B, N, M)
1178
+
1179
+ return start_logits, end_logits, expanded_spans, expanded_span_scores
1180
+
1181
+ def test_model():
1182
+ model = DataParallelProxy(IEModel(backbone_model_name, 7, 6)).to(device)
1183
+ model.eval()
1184
+ total_params = sum(p.numel() for p in model.parameters())
1185
+ print(f"Total params: {total_params:,}")
1186
+
1187
+ vocab_size = model.module.encoder.config.vocab_size
1188
+ max_len = model.module.encoder.config.max_position_embeddings
1189
+
1190
+ bz = 32
1191
+ i = torch.randint(0, vocab_size, (bz, 5, 10)).to(device)
1192
+ a = torch.ones(bz, 5, 10).to(device)
1193
+ g = torch.ones(bz, 3, 2, dtype=torch.long).to(device)
1194
+
1195
+ with torch.no_grad():
1196
+ r = model(i, a)
1197
+
1198
+ if type(r) == tuple:
1199
+ print([r[i].shape if type(r[i]) == type(torch.Tensor()) else len(r[i]) for i in range(len(r))])
1200
+ else:
1201
+ print(r.shape)
1202
+
1203
+ test_model()
1204
+
1205
+ # %% [code]
1206
+ def configure_optimizers(network, optim_params, scheduler_params):
1207
+ try:
1208
+ optim_params = copy.copy(optim_params)
1209
+ scheduler_params = copy.copy(scheduler_params)
1210
+
1211
+ optim_name = optim_params.pop('name')
1212
+ scheduler_name = scheduler_params.pop('name')
1213
+
1214
+ optimizer_cls = globals().get(optim_name) or getattr(optim, optim_name, None)
1215
+ scheduler_cls = globals().get(scheduler_name) or getattr(optim.lr_scheduler, scheduler_name, None)
1216
+
1217
+ if optimizer_cls is None:
1218
+ raise ValueError(f"Optimizer '{optim_name}' is not available!")
1219
+
1220
+ optimizer = optimizer_cls(network.parameters(), **optim_params)
1221
+
1222
+ scheduler = None
1223
+ if scheduler_params and scheduler_cls: # Chỉ tạo scheduler nếu có tham số
1224
+ scheduler = scheduler_cls(optimizer, **scheduler_params)
1225
+
1226
+ return optimizer, scheduler
1227
+
1228
+ except KeyError as e:
1229
+ raise ValueError(f"Missing {e} in config!!")
1230
+
1231
+ def freeze(self, model):
1232
+ model.eval()
1233
+ for param in model.parameters():
1234
+ param.requires_grad = False
1235
+
1236
+ def unfreeze(self, model):
1237
+ model.train()
1238
+ for param in model.parameters():
1239
+ param.requires_grad = True
1240
+
1241
+ def reduce_batch_size(loader, ratio=0.5):
1242
+ new_bs = max(1, int(loader.batch_size * ratio))
1243
+
1244
+ shuffle = isinstance(loader.sampler, RandomSampler)
1245
+
1246
+ new_loader = DataLoader(
1247
+ dataset=loader.dataset,
1248
+ batch_size=new_bs,
1249
+ shuffle=shuffle,
1250
+ sampler=None if shuffle else loader.sampler,
1251
+ num_workers=loader.num_workers,
1252
+ collate_fn=loader.collate_fn,
1253
+ pin_memory=loader.pin_memory,
1254
+ drop_last=loader.drop_last,
1255
+ timeout=loader.timeout,
1256
+ worker_init_fn=loader.worker_init_fn,
1257
+ multiprocessing_context=loader.multiprocessing_context,
1258
+ generator=loader.generator,
1259
+ prefetch_factor=loader.prefetch_factor if loader.num_workers > 0 else None,
1260
+ persistent_workers=loader.persistent_workers,
1261
+ pin_memory_device=loader.pin_memory_device
1262
+ )
1263
+
1264
+ return new_loader
1265
+
1266
+ def list_to_tuple(x):
1267
+ if isinstance(x, (list, tuple)):
1268
+ return tuple(list_to_tuple(i) for i in x)
1269
+ return x
1270
+
1271
+ def fmt(x):
1272
+ if isinstance(x, float):
1273
+ return round(x, 5)
1274
+ if isinstance(x, dict):
1275
+ return {k: fmt(v) for k, v in x.items()}
1276
+ if isinstance(x, list):
1277
+ return [fmt(v) for v in x]
1278
+ return x
1279
+
1280
+ class ModelEmaV3Proxy(ModelEmaV3):
1281
+ def __getattr__(self, name):
1282
+ try:
1283
+ return super().__getattr__(name)
1284
+ except AttributeError:
1285
+ return getattr(self.module, name)
1286
+
1287
+ def align(spans, gold_spans, gold_labels):
1288
+ """
1289
+ Args:
1290
+ spans: (B, N1, M, 2)
1291
+ gold_spans: (B, N2, 2)
1292
+ gold_labels: (B, N2)
1293
+
1294
+ Returns:
1295
+ labels: (B, N1, M)
1296
+
1297
+ matched -> gold_labels
1298
+ padding (0,0) -> -100
1299
+ unmatched -> 0
1300
+ """
1301
+
1302
+ device = spans.device
1303
+
1304
+ B, N1, M, _ = spans.shape
1305
+ N2 = gold_spans.shape[1]
1306
+
1307
+ # =========================================================
1308
+ # compare spans
1309
+ # =========================================================
1310
+
1311
+ pred = spans.unsqueeze(3)
1312
+ # (B, N1, M, 1, 2)
1313
+
1314
+ gold = gold_spans.unsqueeze(1).unsqueeze(1)
1315
+ # (B, 1, 1, N2, 2)
1316
+
1317
+ matched = (pred == gold).all(dim=-1)
1318
+ # (B, N1, M, N2)
1319
+
1320
+ # =========================================================
1321
+ # default labels = 0
1322
+ # =========================================================
1323
+
1324
+ labels = torch.zeros(
1325
+ (B, N1, M),
1326
+ dtype=gold_labels.dtype,
1327
+ device=device
1328
+ )
1329
+
1330
+ # =========================================================
1331
+ # assign matched labels
1332
+ # =========================================================
1333
+
1334
+ has_match = matched.any(dim=-1)
1335
+ # (B, N1, M)
1336
+
1337
+ matched_idx = matched.float().argmax(dim=-1)
1338
+ # (B, N1, M)
1339
+
1340
+ expanded_gold_labels = gold_labels.unsqueeze(1).unsqueeze(1).expand(
1341
+ B, N1, M, N2
1342
+ )
1343
+
1344
+ gathered_labels = torch.gather(
1345
+ expanded_gold_labels,
1346
+ dim=-1,
1347
+ index=matched_idx.unsqueeze(-1)
1348
+ ).squeeze(-1)
1349
+
1350
+ labels = torch.where(
1351
+ has_match,
1352
+ gathered_labels,
1353
+ labels
1354
+ )
1355
+
1356
+ # =========================================================
1357
+ # padding span -> -100
1358
+ # =========================================================
1359
+
1360
+ is_padding = (spans.sum(dim=-1) == 0)
1361
+
1362
+ labels = torch.where(
1363
+ is_padding,
1364
+ torch.full_like(labels, -100),
1365
+ labels
1366
+ )
1367
+
1368
+ return labels
1369
+
1370
+ def select_best_spans(spans, span_scores):
1371
+ """
1372
+ Args:
1373
+ spans: (B, N, M, 2)
1374
+ span_scores: (B, N, M)
1375
+
1376
+ Returns:
1377
+ best_spans: (B, N, 2)
1378
+ """
1379
+
1380
+ # =========================================================
1381
+ # padding spans -> -inf
1382
+ # =========================================================
1383
+
1384
+ is_padding = (spans.sum(dim=-1) == 0)
1385
+ # (B, N, M)
1386
+
1387
+ masked_scores = span_scores.masked_fill(
1388
+ is_padding,
1389
+ float("-inf")
1390
+ )
1391
+
1392
+ # =========================================================
1393
+ # best index theo chiều M
1394
+ # =========================================================
1395
+
1396
+ best_idx = masked_scores.argmax(dim=-1)
1397
+ # (B, N)
1398
+
1399
+ # =========================================================
1400
+ # gather spans
1401
+ # =========================================================
1402
+
1403
+ best_spans = torch.gather(
1404
+ spans,
1405
+ dim=2,
1406
+ index=best_idx.unsqueeze(-1).unsqueeze(-1).expand(
1407
+ -1, -1, 1, 2
1408
+ )
1409
+ ).squeeze(2)
1410
+
1411
+ # =========================================================
1412
+ # nếu toàn bộ M đều là padding
1413
+ # -> output = (0,0)
1414
+ # =========================================================
1415
+
1416
+ all_padding = is_padding.all(dim=-1)
1417
+ # (B, N)
1418
+
1419
+ best_spans = torch.where(
1420
+ all_padding.unsqueeze(-1),
1421
+ torch.zeros_like(best_spans),
1422
+ best_spans
1423
+ )
1424
+
1425
+ return best_spans
1426
+
1427
+ def extract_entities(input_ids, spans, labels, id2label):
1428
+ """
1429
+ Args:
1430
+ input_ids: (B, L)
1431
+ spans: (B, N, 2)
1432
+ labels: (B, N)
1433
+
1434
+ Returns:
1435
+ List[
1436
+ (
1437
+ bidx,
1438
+ (
1439
+ tuple(token_ids),
1440
+ label_name
1441
+ )
1442
+ )
1443
+ ]
1444
+ """
1445
+
1446
+ B, N, _ = spans.shape
1447
+
1448
+ results = []
1449
+
1450
+ for bidx in range(B):
1451
+
1452
+ for n in range(N):
1453
+
1454
+ lb = labels[bidx, n].item()
1455
+
1456
+ # ignore padding / negative
1457
+ if lb <= 0:
1458
+ continue
1459
+
1460
+ s, e = spans[bidx, n].tolist()
1461
+
1462
+ # skip padding span
1463
+ if s == 0 and e == 0:
1464
+ continue
1465
+
1466
+ token_ids = tuple(
1467
+ input_ids[bidx, s:e + 1].tolist()
1468
+ )
1469
+
1470
+ results.append(
1471
+ (
1472
+ bidx,
1473
+ (
1474
+ token_ids,
1475
+ id2label[lb]
1476
+ )
1477
+ )
1478
+ )
1479
+
1480
+ return results
1481
+
1482
+ class Trainer:
1483
+ def __init__(
1484
+ self, training_time="00:11:30:00", eval_mode="max", topk=1, save_name="network", save_best=True, save_last=False, max_grad_norm=200.0,
1485
+ logging=0, logging_file=False, checkpoints_dir="", early_stopping=False, eval_from_ratio=-1, eval_every=1, device='cpu',
1486
+ schedule_in_step=True, use_ema=True, ema_from_ratio=-1, ema_decay=0.999, return_best=True, return_last=True
1487
+ ):
1488
+ self.ema_net = None
1489
+
1490
+ self.training_time = self._time_str_to_seconds(training_time)
1491
+ self.mode = eval_mode
1492
+ self.topk = topk
1493
+ self.device = device
1494
+ self.logging = logging if logging < epochs else 1
1495
+ self.logging_file = logging_file
1496
+ self.checkpoints_dir = checkpoints_dir
1497
+ self.early_stopping = early_stopping
1498
+ self.eval_from_ratio = eval_from_ratio
1499
+ self.eval_every = eval_every
1500
+ self.save_name = save_name
1501
+ self.save_best = save_best
1502
+ self.save_last = save_last
1503
+ self.return_best = return_best
1504
+ self.return_last = return_last
1505
+ self.max_grad_norm = max_grad_norm
1506
+ self.schedule_in_step = schedule_in_step
1507
+ self.use_ema = use_ema
1508
+ self.ema_from_ratio = ema_from_ratio
1509
+ self.ema_decay = ema_decay
1510
+
1511
+ self.best_stage = [[float('-inf') if self.mode == 'max' else float('inf'), None, None]]
1512
+ self.grad_scaler = torch.amp.GradScaler(self.device, init_scale=1024.0)
1513
+
1514
+ def fit(self, network, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader=None, eval_fn=None, start_epoch=1, start_training_time=None, id2label=None):
1515
+ if eval_fn is None:
1516
+ if self.mode == "max":
1517
+ eval_fn = lambda *x: -loss_fn(*x)
1518
+ else:
1519
+ eval_fn = lambda *x: loss_fn(*x)
1520
+
1521
+ if torch.cuda.device_count() > 1:
1522
+ network = DataParallelProxy(network)
1523
+ network = network.to(self.device)
1524
+
1525
+ if not start_training_time:
1526
+ start_training_time = time.time()
1527
+
1528
+ start_ema = int(epochs * self.ema_from_ratio)
1529
+ start_eval = int(epochs * self.eval_from_ratio)
1530
+
1531
+ if val_loader is None:
1532
+ print(f'[Trainer CallBack] 📢 Không có Val Set, không thể đánh giá và Early Stopping!')
1533
+ else:
1534
+ model_to_use_str = 'mô hình EMA' if self.use_ema else 'mô hình gốc'
1535
+ start_model_update_str = f'Bắt đầu cập nhật EMA từ epoch {start_epoch + start_ema}!' if self.use_ema else ''
1536
+ print(f'[Trainer CallBack] 📢 Đánh giá bằng {model_to_use_str} từ epoch {start_epoch + start_eval}!', start_model_update_str)
1537
+
1538
+ training_log = {}
1539
+ for epoch in range(start_epoch, epochs+start_epoch):
1540
+ if self.use_ema and self.ema_net is None and epoch - start_epoch >= start_ema:
1541
+ self.ema_net = ModelEmaV3Proxy(network, self.ema_decay, device=self.device)
1542
+
1543
+ try:
1544
+ teaching_rate = math.cos(math.pi / 2 * epoch / epochs)
1545
+ train_loss_epoch, train_loss_epoch_dict = self._train_epoch(network, train_loader, optimizer, scheduler, loss_fn, teaching_rate)
1546
+ logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': train_loss_epoch}
1547
+ logging_dict.update(train_loss_epoch_dict)
1548
+
1549
+ if val_loader is not None and epoch - start_epoch >= start_eval and (epoch - start_epoch - start_eval) % self.eval_every == 0:
1550
+ eval_net = self.ema_net.module if (self.use_ema and self.ema_net is not None) else network
1551
+
1552
+ val_score, val_score_dict, _ = self._eval_epoch(eval_net, val_loader, eval_fn, id2label)
1553
+ update = self._update_best_network(eval_net, val_score, epoch)
1554
+ logging_dict.update({'val_score': val_score, 'best_score': self.best_stage[0][0], 'new_best_model': update})
1555
+ logging_dict.update(val_score_dict)
1556
+ if not self.schedule_in_step and scheduler:
1557
+ scheduler.step()
1558
+
1559
+ except RuntimeError as e:
1560
+ if "out of memory" in str(e).lower():
1561
+ print(f"[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: CUDA Out of Memory! Clearing GPU cache...")
1562
+ torch.cuda.empty_cache()
1563
+ gc.collect()
1564
+ if torch.cuda.is_available():
1565
+ torch.cuda.synchronize()
1566
+ print(f"[Trainer CallBack] ✅ Epoch {epoch}/{epochs}: GPU memory cleared")
1567
+
1568
+ train_loader = reduce_batch_size(train_loader, ratio=0.5)
1569
+ if val_loader is not None:
1570
+ val_loader = reduce_batch_size(val_loader, ratio=0.5)
1571
+
1572
+ logging_dict = {'lr': [group['lr'] for group in optimizer.param_groups], 'train_loss': float('inf')}
1573
+ else:
1574
+ raise
1575
+
1576
+ training_log[epoch] = logging_dict
1577
+ if self.is_early_stopping(epoch):
1578
+ print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}: Detect Overfitting! Breaking Training Process...')
1579
+ break
1580
+ if self.logging:
1581
+ if epoch % self.logging == 0:
1582
+ print(f'[Trainer CallBack] 📢 Epoch {epoch}/{epochs}:', fmt(logging_dict))
1583
+ else:
1584
+ print(f'{epoch}...', end=' ')
1585
+
1586
+ if self._at_time_limit(start_training_time):
1587
+ print(f'[Trainer CallBack] ⚠️ Epoch {epoch}/{epochs}: Thời gian training giới hạn là {self.training_time}, hết giờ tại epoch {epoch}/{epochs}')
1588
+ break
1589
+
1590
+ if self.logging_file:
1591
+ os.makedirs(f'{self.checkpoints_dir}/logs', exist_ok=True)
1592
+ with open(f"{self.checkpoints_dir}/logs/{self.save_name}_logging.json", "a", encoding="utf-8") as f:
1593
+ f.write(json.dumps(training_log))
1594
+
1595
+ if self.use_ema and self.ema_net is not None:
1596
+ self._save_state_dict(self.ema_net.module)
1597
+ else:
1598
+ self._save_state_dict(network)
1599
+ print(f'[Trainer CallBack] 📢 Kết thúc training.\n')
1600
+
1601
+ best_model, last_model = None, None
1602
+ eval_net = self.ema_net.module if (self.use_ema and self.ema_net is not None) else network
1603
+ if self.return_best :
1604
+ best_model = self.best_stage[0][2] if self.best_stage[0][2] is not None else eval_net.state_dict()
1605
+ best_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in best_model.items()}
1606
+ if self.return_last:
1607
+ last_model = eval_net.state_dict()
1608
+ last_model = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in last_model.items()}
1609
+
1610
+ del network
1611
+ torch.cuda.empty_cache()
1612
+ gc.collect()
1613
+ return training_log, best_model, last_model
1614
+
1615
+ def _time_str_to_seconds(self, time_str):
1616
+ days, hours, minutes, seconds = map(int, time_str.split(":"))
1617
+ return days * 86400 + hours * 3600 + minutes * 60 + seconds
1618
+
1619
+ def _update_best_network(self, network, val_score, epoch):
1620
+ topk = max(1, self.topk)
1621
+ self.best_stage.append([val_score, epoch, {k: v.detach().cpu().clone() for k, v in network.state_dict().items()}])
1622
+ self.best_stage = sorted(self.best_stage, reverse=(self.mode == 'max'), key=lambda x: x[0])[:topk]
1623
+ if val_score in [x[0] for x in self.best_stage]:
1624
+ return True
1625
+ return False
1626
+
1627
+ def is_early_stopping(self, epoch):
1628
+ if self.best_stage[0][1] is None:
1629
+ return False
1630
+ if not self.early_stopping:
1631
+ return False
1632
+ return epoch - self.best_stage[0][1] >= self.early_stopping
1633
+
1634
+ def _at_time_limit(self, start_training_time):
1635
+ return time.time() - start_training_time >= self.training_time
1636
+
1637
+ def _save_state_dict(self, network):
1638
+ if self.topk <= 0:
1639
+ return
1640
+
1641
+ if self.save_best:
1642
+ for r in range(self.topk):
1643
+ os.makedirs(f'{self.checkpoints_dir}/r{r+1}s', exist_ok=True)
1644
+
1645
+ for rank, (score, epoch, state_dict) in enumerate(self.best_stage):
1646
+ if state_dict is None:
1647
+ continue
1648
+ state_dict = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in state_dict.items()}
1649
+ torch.save(state_dict, f'{self.checkpoints_dir}/r{rank+1}s/{self.save_name}_r{rank+1}_vs{score:.5f}_{"ema" if self.ema_net is not None else ""}.pth')
1650
+ if self.save_last:
1651
+ os.makedirs(f'{self.checkpoints_dir}/lasts', exist_ok=True)
1652
+ state_dict = {k.replace("module.", ""): v.detach().cpu().clone() for k, v in network.state_dict().items()}
1653
+ torch.save(state_dict, f'{self.checkpoints_dir}/lasts/{self.save_name}_last_{"ema" if self.ema_net is not None else ""}.pth')
1654
+
1655
+ def _train_epoch(self, network, train_loader, optimizer, scheduler, loss_fn, teaching_rate):
1656
+ network.train()
1657
+ total_loss = 0
1658
+ total_loss_dict = {}
1659
+ for batch_idx, batch in enumerate(train_loader):
1660
+ optimizer.zero_grad()
1661
+ with torch.autocast(device_type=self.device, dtype=torch.float16):
1662
+ loss, loss_dict = self._cal_loss(network, batch, batch_idx, loss_fn, teaching_rate)
1663
+
1664
+ for k, v in loss_dict.items():
1665
+ t = total_loss_dict.get(k, 0)
1666
+ total_loss_dict[k] = t + v
1667
+ self.grad_scaler.scale(loss).backward()
1668
+ self.grad_scaler.unscale_(optimizer)
1669
+ grad_norm = nn.utils.clip_grad_norm_(network.parameters(), self.max_grad_norm)
1670
+ # print(grad_norm) # Bỏ cmt dòng này để biết nên chọn max_grad_norm bằng bao nhiêu...
1671
+ self.grad_scaler.step(optimizer)
1672
+ self.grad_scaler.update()
1673
+ if self.schedule_in_step and scheduler:
1674
+ scheduler.step()
1675
+ if self.use_ema and self.ema_net is not None:
1676
+ self.ema_net.update(network)
1677
+ total_loss += loss
1678
+ return (total_loss / len(train_loader)).item(), {k: v.item() / len(train_loader) for k, v in total_loss_dict.items()}
1679
+
1680
+ def _eval_epoch(self, network, val_loader, eval_fn, id2label):
1681
+ network.eval()
1682
+ total_score = 0.0
1683
+ total_score_dict = {}
1684
+ object_lists = None # sẽ init sau
1685
+
1686
+ with torch.no_grad():
1687
+ for batch_idx, batch in enumerate(val_loader):
1688
+ score, score_dict, objects = self._cal_val_score(network, batch, batch_idx, eval_fn, id2label)
1689
+ total_score += score
1690
+
1691
+ for k, v in score_dict.items():
1692
+ t = total_score_dict.get(k, 0)
1693
+ total_score_dict[k] = t + v
1694
+
1695
+ if objects:
1696
+ if object_lists is None:
1697
+ object_lists = [[] for _ in range(len(objects))]
1698
+
1699
+ for i, obj in enumerate(objects):
1700
+ object_lists[i].append(obj.detach())
1701
+
1702
+ if object_lists is not None:
1703
+ object_arrays = [
1704
+ torch.concat(obj_list, dim=0).cpu().numpy()
1705
+ for obj_list in object_lists
1706
+ ]
1707
+ else:
1708
+ object_arrays = []
1709
+
1710
+ return total_score / len(val_loader), {k: v / len(val_loader) for k, v in total_score_dict.items()}, object_arrays
1711
+
1712
+ def _cal_loss(self, network, batch, batch_idx, loss_fn, teaching_rate):
1713
+ # Bạn cần override _cal_loss để tính loss
1714
+ input_ids = batch['input_ids'].to(self.device)
1715
+ attention_mask = batch['attention_mask'].to(self.device)
1716
+ gold_spans = batch['gold_spans'].to(self.device)
1717
+ gold_labels = batch['gold_labels'].to(self.device)
1718
+ start_labels = batch['start_labels'].to(self.device)
1719
+ end_labels = batch['end_labels'].to(self.device)
1720
+
1721
+ choice = random.random()
1722
+ if choice < teaching_rate:
1723
+ start_logits, end_logits, spans, span_scores = network(input_ids, attention_mask, gold_spans)
1724
+ else:
1725
+ start_logits, end_logits, spans, span_scores = network(input_ids, attention_mask)
1726
+
1727
+ labels = align(spans, gold_spans, gold_labels)
1728
+
1729
+ loss_dict = loss_fn(
1730
+ start_logits, start_labels,
1731
+ end_logits, end_labels,
1732
+ span_scores, labels
1733
+ )
1734
+ return loss_dict['total'], loss_dict
1735
+
1736
+ def _cal_val_score(self, network, batch, batch_idx, eval_fn, id2label):
1737
+ # Bạn cần override _cal_val_score để tính val score, list bên cạnh là để trả về y hay pred gì đó (nếu cần)
1738
+ input_ids = batch['input_ids'].to(self.device)
1739
+ attention_mask = batch['attention_mask'].to(self.device)
1740
+ gold_entities = batch['gold_entities']
1741
+
1742
+ B, _, _ = input_ids.shape
1743
+
1744
+ start_logits, end_logits, spans, span_scores = network(input_ids, attention_mask)
1745
+ _, labels = extract_spans_and_labels(start_logits, end_logits)
1746
+ spans = select_best_spans(spans, span_scores)
1747
+
1748
+ pred_ids = extract_entities(input_ids.reshape(B, -1), spans, labels, id2label)
1749
+ pred_ids = list_to_tuple(pred_ids)
1750
+
1751
+ gold_ids = list_to_tuple(gold_entities)
1752
+
1753
+ score_dict = eval_fn(pred_ids, gold_ids)
1754
+ return score_dict['f1'], score_dict, []
1755
+
1756
+ # %% [code]
1757
+ class PhoBERTSpanAligner:
1758
+ def __init__(self, tokenizer, max_len):
1759
+ self.tokenizer = tokenizer
1760
+ self.max_len = max_len
1761
+
1762
+ # ===== 1. Extract discontinuous spans =====
1763
+ def extract_spans(self, sample):
1764
+ entity_spans = []
1765
+
1766
+ for event in sample["entities"]:
1767
+ entity_type = event["label"]
1768
+ spans = [tuple(event["offset"])]
1769
+ entity_spans.append({
1770
+ "spans": spans,
1771
+ "label": entity_type
1772
+ })
1773
+
1774
+ return entity_spans
1775
+
1776
+ # ===== 2. Word offsets =====
1777
+ def build_word_offsets(self, text, words):
1778
+ offsets = []
1779
+ pointer = 0
1780
+
1781
+ for word in words:
1782
+ start = text.find(word, pointer)
1783
+ end = start + len(word)
1784
+ offsets.append((start, end))
1785
+ pointer = end
1786
+
1787
+ return offsets
1788
+
1789
+ # ===== 3. Char → word =====
1790
+ def char_span_to_word_span(self, word_offsets, start, end):
1791
+ start_word = None
1792
+ end_word = None
1793
+
1794
+ for i, (w_start, w_end) in enumerate(word_offsets):
1795
+ if w_start <= start < w_end:
1796
+ start_word = i
1797
+ if w_start < end <= w_end:
1798
+ end_word = i
1799
+
1800
+ return start_word, end_word
1801
+
1802
+ # ===== 4. Word → subword =====
1803
+ def word_to_subword_map(self, words):
1804
+ mapping = []
1805
+ subword_index = 1 # <s>
1806
+
1807
+ for word in words:
1808
+ sub_tokens = self.tokenizer.tokenize(word)
1809
+ start = subword_index
1810
+ end = subword_index + len(sub_tokens) - 1
1811
+ mapping.append((start, end))
1812
+ subword_index += len(sub_tokens)
1813
+
1814
+ return mapping
1815
+
1816
+ # ===== 5. Span → subword =====
1817
+ def span_to_subword(self, word_offsets, word_subword_map, spans):
1818
+ sub_spans = []
1819
+
1820
+ for span_start, span_end in spans:
1821
+ w_start, w_end = self.char_span_to_word_span(
1822
+ word_offsets, span_start, span_end
1823
+ )
1824
+ if w_start is None or w_end is None:
1825
+ continue
1826
+
1827
+ sub_start = word_subword_map[w_start][0]
1828
+ sub_end = word_subword_map[w_end][1]
1829
+ sub_spans.append((sub_start, sub_end))
1830
+
1831
+ return sub_spans
1832
+
1833
+ def extract_valid_spans(self, sub_spans):
1834
+ valid_spans = []
1835
+ for s, e in sub_spans:
1836
+ if s < 0 or e < 0 or s >= self.max_len or e >= self.max_len or s > e:
1837
+ continue
1838
+ valid_spans.append((s, e))
1839
+ return valid_spans
1840
+
1841
+ def encode(self, sample):
1842
+ text = sample["text"]
1843
+ entities = self.extract_spans(sample)
1844
+
1845
+ # ===== 1. Word tokenize =====
1846
+ words = word_tokenize(text)
1847
+ sentence = " ".join(words)
1848
+
1849
+ # ===== 2. Mapping =====
1850
+ word_offsets = self.build_word_offsets(text, words)
1851
+ word_subword_map = self.word_to_subword_map(words)
1852
+
1853
+ # ===== 3. Tokenize FULL =====
1854
+ encoding = self.tokenizer(
1855
+ sentence,
1856
+ max_length=self.max_len,
1857
+ truncation=True,
1858
+ padding="max_length",
1859
+ return_tensors="pt"
1860
+ )
1861
+ input_ids = encoding["input_ids"][0]
1862
+ attention_mask = encoding["attention_mask"][0]
1863
+
1864
+ # ===== 5. Convert spans =====
1865
+ entities_gold_spans = []
1866
+
1867
+ for ent in entities:
1868
+ label = ent["label"]
1869
+
1870
+ sub_spans = self.span_to_subword(
1871
+ word_offsets,
1872
+ word_subword_map,
1873
+ ent["spans"]
1874
+ )
1875
+ valid_spans = self.extract_valid_spans(sub_spans)
1876
+ if len(valid_spans) == 0:
1877
+ continue
1878
+ entities_gold_spans.append((tuple(valid_spans), label))
1879
+
1880
+ return {
1881
+ "input_ids": input_ids,
1882
+ "attention_mask": attention_mask,
1883
+ "entities_gold_spans": entities_gold_spans,
1884
+ }
1885
+
1886
+ def generate_spans(attention_mask, max_span_len):
1887
+ seq_len = attention_mask.sum().item() - 2
1888
+ spans = []
1889
+ for i in range(1, seq_len+1):
1890
+ for j in range(i, min(i+max_span_len, seq_len+1)):
1891
+ spans.append((i, j))
1892
+ return spans
1893
+
1894
+ def match_gold_labels(
1895
+ gold_spans, # (N, 2)
1896
+ gold_labels, # (N,)
1897
+ pred_spans, # (M, 2)
1898
+ default_label=-100
1899
+ ):
1900
+ """
1901
+ Return:
1902
+ pred_labels: (M,)
1903
+ """
1904
+
1905
+ pred_labels = torch.full(
1906
+ (pred_spans.size(0),),
1907
+ default_label,
1908
+ dtype=gold_labels.dtype,
1909
+ device=gold_labels.device
1910
+ )
1911
+ if gold_spans.size(0) == 0:
1912
+ return pred_labels
1913
+
1914
+ # (M, N)
1915
+ matched = (pred_spans[:, None, :] == gold_spans[None, :, :]).all(dim=-1)
1916
+ has_match = matched.any(dim=1)
1917
+
1918
+ # lấy index gold đầu tiên match
1919
+ gold_idx = matched.float().argmax(dim=1)
1920
+
1921
+ pred_labels[has_match] = gold_labels[gold_idx[has_match]]
1922
+
1923
+ return pred_labels
1924
+
1925
+ class KLTNDataset(Dataset):
1926
+ def __init__(self, all_data, using_idxes, label2id, tokenizer, max_len, max_n_parts):
1927
+ super().__init__()
1928
+ self.tokenizer = tokenizer
1929
+ self.aligner = PhoBERTSpanAligner(tokenizer, max_len*max_n_parts)
1930
+ self.all_data = all_data
1931
+ self.using_idxes = using_idxes
1932
+ self.label2id = label2id
1933
+ self.max_len = max_len
1934
+ self.max_n_parts = max_n_parts
1935
+
1936
+ def __len__(self):
1937
+ return len(self.using_idxes)
1938
+
1939
+ def __getitem__(self, idx):
1940
+ ridx = self.using_idxes[idx]
1941
+ sample = self.all_data[ridx]
1942
+ result = self.aligner.encode(sample)
1943
+
1944
+ input_ids = result["input_ids"].squeeze(0)
1945
+ attention_mask = result["attention_mask"].squeeze(0)
1946
+ entities_gold_spans = result["entities_gold_spans"]
1947
+
1948
+ all_spans = torch.tensor(generate_spans(attention_mask, 10))
1949
+ gold_spans = torch.tensor([spans[0] for spans, _ in entities_gold_spans], dtype=torch.long) if entities_gold_spans else torch.empty(0, 2, dtype=torch.long)
1950
+ gold_labels = torch.tensor([self.label2id[label] for _, label in entities_gold_spans], dtype=torch.long) if entities_gold_spans else torch.empty(0, dtype=torch.long)
1951
+ all_labels = match_gold_labels(
1952
+ gold_spans, # (N, 2)
1953
+ gold_labels, # (N,)
1954
+ all_spans, # (M, 2)
1955
+ default_label=0
1956
+ )
1957
+
1958
+ # Get label
1959
+ gold_entities = []
1960
+ start_labels = torch.ones_like(input_ids) * (1-attention_mask) * (-100)
1961
+ end_labels = torch.ones_like(input_ids) * (1-attention_mask) * (-100)
1962
+ for spans, label in entities_gold_spans:
1963
+ s, e = spans[0]
1964
+
1965
+ start_labels[s] = self.label2id[f'{label}']
1966
+ end_labels[e] = self.label2id[f'{label}']
1967
+
1968
+ gold_entities.append((tuple(input_ids[s:e+1].tolist()), label))
1969
+
1970
+ input_ids = input_ids.reshape(self.max_n_parts, self.max_len)
1971
+ attention_mask = attention_mask.reshape(self.max_n_parts, self.max_len)
1972
+
1973
+ n_valid_parts = math.ceil(attention_mask.sum().item() / self.max_len)
1974
+ input_ids = input_ids[:n_valid_parts]
1975
+ attention_mask = attention_mask[:n_valid_parts]
1976
+ start_labels = start_labels[:n_valid_parts*self.max_len]
1977
+ end_labels = end_labels[:n_valid_parts*self.max_len]
1978
+
1979
+ return {
1980
+ "input_ids": input_ids,
1981
+ "attention_mask": attention_mask,
1982
+
1983
+ "all_spans": all_spans,
1984
+ "all_labels": all_labels,
1985
+
1986
+ "gold_spans": gold_spans,
1987
+ "gold_labels": gold_labels,
1988
+
1989
+ "start_labels": start_labels,
1990
+ "end_labels": end_labels,
1991
+
1992
+ "gold_entities": gold_entities,
1993
+ }
1994
+
1995
+ def _pad_batch(tensor_list, pad_value=0):
1996
+ """
1997
+ tensor_list: list of tensors
1998
+ mỗi tensor shape: (Nk, n_parts_i, max_len_i)
1999
+
2000
+ return:
2001
+ padded tensor shape: (B, max_Nk, max_n_parts, max_len)
2002
+ """
2003
+
2004
+ # lấy max toàn batch
2005
+ max_Nk = max(t.size(0) for t in tensor_list)
2006
+ max_n_parts = max(t.size(1) for t in tensor_list)
2007
+ max_len = max(t.size(2) for t in tensor_list)
2008
+
2009
+ padded = []
2010
+
2011
+ for t in tensor_list:
2012
+ Nk, n_parts_i, max_len_i = t.shape
2013
+
2014
+ # pad chiều n_parts và max_len trước
2015
+ if n_parts_i < max_n_parts or max_len_i < max_len:
2016
+ new_t = t.new_full(
2017
+ (Nk, max_n_parts, max_len),
2018
+ pad_value
2019
+ )
2020
+ new_t[:, :n_parts_i, :max_len_i] = t
2021
+ t = new_t
2022
+
2023
+ # pad chiều Nk
2024
+ if Nk < max_Nk:
2025
+ pad_tensor = t.new_full(
2026
+ (max_Nk - Nk, max_n_parts, max_len),
2027
+ pad_value
2028
+ )
2029
+ t = torch.cat([t, pad_tensor], dim=0)
2030
+
2031
+ padded.append(t)
2032
+
2033
+ return torch.stack(padded) # (B, max_Nk, max_n_parts, max_len)
2034
+
2035
+ def collate_fn(batch):
2036
+ gold_entities = []
2037
+ for bidx, b in enumerate(batch):
2038
+ for entity in b['gold_entities']:
2039
+ gold_entities.append([bidx, entity])
2040
+
2041
+ input_ids = [b["input_ids"].unsqueeze(-1) for b in batch]
2042
+ attention_mask = [b["attention_mask"].unsqueeze(-1) for b in batch]
2043
+
2044
+ all_spans = [b["all_spans"].unsqueeze(-1) for b in batch]
2045
+ all_labels = [b["all_labels"].unsqueeze(-1).unsqueeze(-1) for b in batch]
2046
+
2047
+ gold_spans = [b["gold_spans"].unsqueeze(-1) for b in batch]
2048
+ gold_labels = [b["gold_labels"].unsqueeze(-1).unsqueeze(-1) for b in batch]
2049
+
2050
+ start_labels = [b["start_labels"].unsqueeze(-1).unsqueeze(-1) for b in batch]
2051
+ end_labels = [b["end_labels"].unsqueeze(-1).unsqueeze(-1) for b in batch]
2052
+
2053
+ # pad theo Nk
2054
+ input_ids = _pad_batch(input_ids, pad_value=0).squeeze(-1)
2055
+ attention_mask = _pad_batch(attention_mask, pad_value=0).squeeze(-1)
2056
+
2057
+ all_spans = _pad_batch(all_spans, pad_value=0).squeeze(-1)
2058
+ all_labels = _pad_batch(all_labels, pad_value=-100).squeeze(-1).squeeze(-1)
2059
+
2060
+ gold_spans = _pad_batch(gold_spans, pad_value=0).squeeze(-1)
2061
+ gold_labels = _pad_batch(gold_labels, pad_value=-100).squeeze(-1).squeeze(-1)
2062
+
2063
+ start_labels = _pad_batch(start_labels, pad_value=-100).squeeze(-1).squeeze(-1)
2064
+ end_labels = _pad_batch(end_labels, pad_value=-100).squeeze(-1).squeeze(-1)
2065
+
2066
+ return {
2067
+ "input_ids": input_ids,
2068
+ "attention_mask": attention_mask,
2069
+
2070
+ "all_spans": all_spans,
2071
+ "all_labels": all_labels,
2072
+
2073
+ "gold_spans": gold_spans,
2074
+ "gold_labels": gold_labels,
2075
+
2076
+ "start_labels": start_labels,
2077
+ "end_labels": end_labels,
2078
+
2079
+ "gold_entities": gold_entities,
2080
+ }
2081
+
2082
+ # %% [code]
2083
+ def shift_bidx(spans, batch_idx):
2084
+ shifted = []
2085
+ for bidx, ent in spans:
2086
+ new_bidx = bidx + batch_idx * batch_size
2087
+ shifted.append((new_bidx, ent))
2088
+ return shifted
2089
+
2090
+ def refactor_entities(entities, save_dict):
2091
+ i, c = [], []
2092
+ for bidx, (ids, lb) in entities:
2093
+ if (bidx, ids) not in i:
2094
+ i.append((bidx, ids))
2095
+
2096
+ if (bidx, (ids, lb)) not in c:
2097
+ c.append((bidx, (ids, lb)))
2098
+
2099
+ save_dict['Ent-I'].extend(i)
2100
+ save_dict['Ent-C'].extend(c)
2101
+
2102
+ def check_spans(
2103
+ all_spans, # (N, 2)
2104
+ all_labels, # (N,)
2105
+ ensemble_start_logits, # (L, C)
2106
+ ensemble_end_logits, # (L, C)
2107
+ expand_dict,
2108
+ max_expand=10
2109
+ ):
2110
+ """
2111
+ Check minimum expansion radius required to achieve recall = 1.
2112
+
2113
+ Args:
2114
+ all_spans: gold spans
2115
+ all_labels: gold labels
2116
+ ensemble_start_logits: (L, C)
2117
+ ensemble_end_logits: (L, C)
2118
+ expand_dict:
2119
+ dict like:
2120
+ {
2121
+ 0: count,
2122
+ 1: count,
2123
+ ...
2124
+ }
2125
+
2126
+ Return:
2127
+ matched_expand_k
2128
+ """
2129
+
2130
+ # =========================================================
2131
+ # Gold spans
2132
+ # =========================================================
2133
+
2134
+ gold_set = set()
2135
+
2136
+ valid = all_labels > 0
2137
+
2138
+ valid_idxes = valid.nonzero(as_tuple=False).squeeze(-1)
2139
+
2140
+ for idx in valid_idxes:
2141
+
2142
+ s = int(all_spans[idx, 0])
2143
+ e = int(all_spans[idx, 1])
2144
+
2145
+ gold_set.add((s, e))
2146
+
2147
+ # no gold
2148
+ if len(gold_set) == 0:
2149
+ expand_dict[0] += 1
2150
+ return 0
2151
+
2152
+ # =========================================================
2153
+ # Decode base spans
2154
+ # =========================================================
2155
+
2156
+ start_labels = ensemble_start_logits.argmax(dim=-1) # (L,)
2157
+ end_labels = ensemble_end_logits.argmax(dim=-1) # (L,)
2158
+
2159
+ L = start_labels.shape[0]
2160
+
2161
+ base_pred = []
2162
+
2163
+ used_start = set()
2164
+ used_end = set()
2165
+
2166
+ for s in range(L):
2167
+
2168
+ s_label = start_labels[s].item()
2169
+
2170
+ if s_label == 0:
2171
+ continue
2172
+
2173
+ if s in used_start:
2174
+ continue
2175
+
2176
+ nearest_e = None
2177
+
2178
+ for e in range(s, L):
2179
+
2180
+ if e in used_end:
2181
+ continue
2182
+
2183
+ e_label = end_labels[e].item()
2184
+
2185
+ if e_label == s_label:
2186
+ nearest_e = e
2187
+ break
2188
+
2189
+ if nearest_e is None:
2190
+ continue
2191
+
2192
+ used_start.add(s)
2193
+ used_end.add(nearest_e)
2194
+
2195
+ base_pred.append((s, nearest_e))
2196
+
2197
+ # =========================================================
2198
+ # Try expansion radius
2199
+ # =========================================================
2200
+
2201
+ for k in range(max_expand + 1):
2202
+
2203
+ pred_set = set()
2204
+
2205
+ for s, e in base_pred:
2206
+
2207
+ for ds in range(-k, k + 1):
2208
+ for de in range(-k, k + 1):
2209
+
2210
+ ns = s + ds
2211
+ ne = e + de
2212
+
2213
+ if ns > 0 and ne > 0 and ns <= ne:
2214
+ pred_set.add((ns, ne))
2215
+
2216
+ # recall = 1
2217
+ if gold_set.issubset(pred_set):
2218
+
2219
+ expand_dict[k] += 1
2220
+ return k
2221
+
2222
+ # =========================================================
2223
+ # cannot recover
2224
+ # =========================================================
2225
+
2226
+ expand_dict[-1] = expand_dict.get(-1, 0) + 1
2227
+
2228
+ return -1
2229
+
2230
+ def test(network, state_dicts, test_loader, eval_fn, analyzer, device, id2label, tokenizer):
2231
+ if torch.cuda.device_count() > 1:
2232
+ network = DataParallelProxy(network)
2233
+ network = network.to(device)
2234
+ network.eval()
2235
+
2236
+ eval_types = ['Ent-I', 'Ent-C']
2237
+
2238
+ all_pred = {eval_type: [] for eval_type in eval_types}
2239
+ all_gold = {eval_type: [] for eval_type in eval_types}
2240
+
2241
+ list_input_ids = []
2242
+ expand_dict = {i: 0 for i in range(13)}
2243
+
2244
+ with torch.no_grad():
2245
+ for batch_idx, batch in enumerate(test_loader):
2246
+ input_ids = batch['input_ids'].to(device)
2247
+ attention_mask = batch['attention_mask'].to(device)
2248
+ all_spans = batch['all_spans'].to(device)
2249
+ all_labels = batch['all_labels'].to(device)
2250
+ gold_entities = batch['gold_entities']
2251
+
2252
+ B, _, _ = input_ids.shape
2253
+ list_input_ids.extend(input_ids.reshape(B, -1).tolist())
2254
+
2255
+ list_hidden_states = []
2256
+ list_start_logits = []
2257
+ list_end_logits = []
2258
+ list_scores = []
2259
+ for sd in state_dicts:
2260
+ if torch.cuda.device_count() > 1:
2261
+ network.module.load_state_dict(sd)
2262
+ else:
2263
+ network.load_state_dict(sd)
2264
+
2265
+ hidden_states, attention_mask = network.encode(input_ids, attention_mask)
2266
+ start_logits, end_logits = network.get_logits(hidden_states)
2267
+ list_hidden_states.append(hidden_states)
2268
+ list_start_logits.append(start_logits)
2269
+ list_end_logits.append(end_logits)
2270
+
2271
+ ensemble_start_logits = torch.stack(list_start_logits, dim=0).mean(dim=0)
2272
+ ensemble_end_logits = torch.stack(list_end_logits, dim=0).mean(dim=0)
2273
+ spans, pred_labels = extract_spans_and_labels(ensemble_start_logits, ensemble_end_logits)
2274
+ B, N, _ = spans.shape
2275
+ r = torch.ones((B, N), device=spans.device, dtype=torch.long) * network.max_r
2276
+ expanded_spans = expand_spans(spans, r, attention_mask)
2277
+
2278
+ for sd, hidden_states in zip(state_dicts, list_hidden_states):
2279
+ if torch.cuda.device_count() > 1:
2280
+ network.module.load_state_dict(sd)
2281
+ else:
2282
+ network.load_state_dict(sd)
2283
+ expanded_span_reprs = get_span_reprs(hidden_states, expanded_spans)
2284
+ expanded_span_scores = network.get_scores(expanded_span_reprs)
2285
+ list_scores.append(expanded_span_scores)
2286
+
2287
+ ensemble_scores = torch.stack(list_scores, dim=0).mean(dim=0)
2288
+ pred_spans = select_best_spans(expanded_spans, ensemble_scores)
2289
+
2290
+ pred_entities = extract_entities(input_ids.reshape(B, -1), pred_spans, pred_labels, id2label)
2291
+ pred_entities = shift_bidx(pred_entities, batch_idx)
2292
+ refactor_entities(pred_entities, all_pred)
2293
+
2294
+ gold_entities = shift_bidx(gold_entities, batch_idx)
2295
+ refactor_entities(gold_entities, all_gold)
2296
+
2297
+ for b in range(B):
2298
+ check_spans(
2299
+ all_spans[b], # (N, 2)
2300
+ all_labels[b], # (N,)
2301
+ ensemble_start_logits[b], # (L, C)
2302
+ ensemble_end_logits[b], # (L, C)
2303
+ expand_dict,
2304
+ max_expand=10
2305
+ )
2306
+
2307
+ # ===== GLOBAL EVAL =====
2308
+ final_score = {}
2309
+ for eval_type in eval_types:
2310
+ score = eval_fn(list_to_tuple(all_pred[eval_type]), list_to_tuple(all_gold[eval_type]))
2311
+ final_score[eval_type] = score
2312
+
2313
+ analyze_result = analyzer.analyze(list_to_tuple(all_pred['Ent-I']), list_to_tuple(all_gold['Ent-I']))
2314
+
2315
+ # ===== PREDICT =====
2316
+ predictions = []
2317
+ for input_ids in list_input_ids:
2318
+ predictions.append([tokenizer.decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)])
2319
+ for bidx, (ids, lb) in all_pred['Ent-C']:
2320
+ predictions[bidx].append((tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True), lb))
2321
+
2322
+ return final_score, analyze_result, predictions, expand_dict
2323
+
2324
+ # %% [code]
2325
+ with open(f'{train_dir}/train.json', "r", encoding="utf-8") as f:
2326
+ data_train = json.load(f)
2327
+
2328
+ with open(f'{test_dir}/test.json', "r", encoding="utf-8") as f:
2329
+ data_test = json.load(f)
2330
+
2331
+ print('Train:', len(data_train))
2332
+ print('Test:', len(data_test))
2333
+
2334
+ # %% [code]
2335
+ entity_types = ['O'] + sorted(list(set([e['label'] for d in data_train + data_test for e in d['entities']])))
2336
+ # bio_entity_type = ['O'] + [f'{prefix}-{ent}' for ent in entity_types for prefix in ['B', 'I']]
2337
+ label2id = {l: i for i, l in enumerate(entity_types)}
2338
+ id2label = {i: l for l, i in label2id.items()}
2339
+
2340
+ # %% [code]
2341
+ zero_entities_idxes = []
2342
+ for idx, d in enumerate(data_train):
2343
+ if len(d['entities']) == 0:
2344
+ zero_entities_idxes.append(idx)
2345
+
2346
+ n_zero_entities_samples = len(zero_entities_idxes)
2347
+ n_has_entities_samples = len(data_train) - n_zero_entities_samples
2348
+
2349
+ random.seed(42)
2350
+ k = min(int(n_has_entities_samples * zero_entities_rate), len(zero_entities_idxes))
2351
+ sampled_zero_entities_idxes = random.sample(zero_entities_idxes, k)
2352
+
2353
+ new_data_train = []
2354
+ for idx, d in enumerate(data_train):
2355
+ if len(d['entities']) == 0:
2356
+ if idx in sampled_zero_entities_idxes:
2357
+ new_data_train.append(d)
2358
+ else:
2359
+ new_data_train.append(d)
2360
+ data_train = new_data_train
2361
+
2362
+ print('Train:', len(data_train))
2363
+
2364
+ # %% [code]
2365
+ if debug_only:
2366
+ data_train = data_train[:10]
2367
+ data_test = data_test[:10]
2368
+
2369
+ print('Train:', len(data_train))
2370
+ print('Test:', len(data_test))
2371
+
2372
+ # %% [code]
2373
+ tokenizer = AutoTokenizer.from_pretrained(backbone_model_name)
2374
+
2375
+ # %% [code]
2376
+ print('Experiment name:', state_dict_save_name)
2377
+
2378
+ # %% [code]
2379
+ if not test_only:
2380
+ full_idxes = np.array(range(len(data_train)))
2381
+ training_logs, best_models, last_models = [], [], []
2382
+ start_training_time = time.time()
2383
+ for seed in SEEDS:
2384
+ kf = KFold(n_splits=nfolds, shuffle=True, random_state=seed)
2385
+ for fold_idx, (tr_idx, va_idx) in enumerate(kf.split(full_idxes)):
2386
+ if only_fold_idx is not None and only_fold_idx >= 0 and only_fold_idx != fold_idx:
2387
+ continue
2388
+ set_seed(seed)
2389
+
2390
+ train_idxes, val_idxes = full_idxes[tr_idx], full_idxes[va_idx]
2391
+
2392
+ trainset = KLTNDataset(data_train, train_idxes, label2id, tokenizer, **train_memory_params)
2393
+ valset = KLTNDataset(data_train, val_idxes, label2id, tokenizer, **val_memory_params)
2394
+
2395
+ generator = torch.Generator()
2396
+ generator.manual_seed(seed)
2397
+ train_loader = DataLoader(trainset, generator=generator, collate_fn=collate_fn, **train_loader_params)
2398
+ val_loader = DataLoader(valset, generator=generator, collate_fn=collate_fn, **val_loader_params)
2399
+
2400
+ my_model = IEModel(
2401
+ num_labels=len(label2id),
2402
+ **model_params
2403
+ )
2404
+ total_params = sum(p.numel() for p in my_model.parameters())
2405
+ print(f"Total params: {total_params:,}")
2406
+
2407
+ # optimizer, scheduler = configure_optimizers(my_model, optim_params, scheduler_params)
2408
+ encoder_params = set(map(id, my_model.encoder.parameters()))
2409
+ other_params = [
2410
+ p for p in my_model.parameters()
2411
+ if id(p) not in encoder_params
2412
+ ]
2413
+ optimizer = optim.AdamW([
2414
+ {"params": my_model.encoder.parameters(), "lr": 2e-5},
2415
+ {"params": other_params}
2416
+ ], lr=5e-4)
2417
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=1e-6)
2418
+
2419
+ loss_fn = CustomLoss(
2420
+ **loss_func_params
2421
+ )
2422
+ eval_fn = CustomEvalFn(**eval_func_params)
2423
+ trainer_params['save_name'] = f'{state_dict_save_name}_s{seed}_f{fold_idx}'
2424
+ trainer = Trainer(**trainer_params)
2425
+
2426
+ print(f'Start Training Fold {fold_idx}...')
2427
+ training_log, best_model, last_model = trainer.fit(
2428
+ my_model, optimizer, scheduler, loss_fn, epochs, train_loader, val_loader, eval_fn,
2429
+ start_epoch=1, start_training_time=start_training_time, id2label=id2label
2430
+ )
2431
+
2432
+ training_logs.append(training_log)
2433
+ best_models.append(best_model)
2434
+ last_models.append(last_model)
2435
+
2436
+ # %% [code]
2437
+ def load_all_state_dicts(folder):
2438
+ files = []
2439
+
2440
+ for file in os.listdir(folder):
2441
+ if file.endswith(".pt") or file.endswith(".pth"):
2442
+ m = re.search(r"f(\d+)", file) # tìm f<số>
2443
+ if m:
2444
+ fold = int(m.group(1))
2445
+ files.append((fold, file))
2446
+
2447
+ # sort theo fold
2448
+ files.sort(key=lambda x: x[0])
2449
+
2450
+ state_dicts = []
2451
+ for fold, file in files:
2452
+ path = os.path.join(folder, file)
2453
+ print(f"Loading fold {fold}: {file}")
2454
+ state_dict = torch.load(path, map_location="cpu")
2455
+ state_dicts.append(state_dict)
2456
+
2457
+ return state_dicts
2458
+
2459
+ if test_only:
2460
+ snapshot_download(repo_id=repo_name, local_dir="", repo_type="model", allow_patterns=[f"{state_dict_save_name}/**"])
2461
+ get_ipython().system('rm -rf .cache .gitattributes')
2462
+
2463
+ best_models = load_all_state_dicts(f"{state_dict_save_name}/r1s")
2464
+ last_models = load_all_state_dicts(f"{state_dict_save_name}/lasts")
2465
+
2466
+ # %% [code]
2467
+ os.makedirs(f'{checkpoints_dir}/results', exist_ok=True)
2468
+ testset = KLTNDataset(data_test, range(len(data_test)), label2id, tokenizer, **val_memory_params)
2469
+ generator = torch.Generator()
2470
+ test_loader = DataLoader(testset, generator=generator, collate_fn=collate_fn, **val_loader_params)
2471
+ eval_fn = CustomEvalFn(**eval_func_params)
2472
+ analyzer = SpanErrorAnalyzer()
2473
+ my_model = IEModel(
2474
+ num_labels=len(label2id),
2475
+ **model_params
2476
+ )
2477
+ total_params = sum(p.numel() for p in my_model.parameters())
2478
+ print(f"Total params: {total_params:,}")
2479
+
2480
+ # %% [code]
2481
+ start_time = time.time()
2482
+ result_test = None
2483
+ analyze_result = None
2484
+
2485
+ best_score, best_analyze_result, best_pred_test, expand_dict = test(my_model, best_models, test_loader, eval_fn, analyzer, device, id2label, tokenizer)
2486
+ last_score, last_analyze_result, last_pred_test, _ = test(my_model, last_models, test_loader, eval_fn, analyzer, device, id2label, tokenizer)
2487
+
2488
+ result_test = {"Best model": best_score, "Last model": last_score}
2489
+ analyze_result = {"Best model": best_analyze_result, "Last model": last_analyze_result}
2490
+ analyze_result_sumary = {"Best model": best_analyze_result['summary'], "Last model": last_analyze_result['summary']}
2491
+ pred_test = {"Best model": best_pred_test, "Last model": last_pred_test}
2492
+
2493
+ with open(f"{checkpoints_dir}/results/{state_dict_save_name}_test.json", "w", encoding="utf-8") as f:
2494
+ json.dump(result_test, f, ensure_ascii=False, indent=2)
2495
+
2496
+ with open(f"{checkpoints_dir}/results/{state_dict_save_name}_error_analyze_result.json", "w", encoding="utf-8") as f:
2497
+ json.dump(analyze_result, f, ensure_ascii=False, indent=2)
2498
+
2499
+ with open(f"{checkpoints_dir}/results/{state_dict_save_name}_pred_test.json", "w", encoding="utf-8") as f:
2500
+ json.dump(pred_test, f, ensure_ascii=False, indent=2)
2501
+
2502
+ print('Test:', time.time() - start_time, 's --> Done!')
2503
+ print(json.dumps(analyze_result_sumary, ensure_ascii=False, indent=4))
2504
+
2505
+ # %% [code]
2506
+ expand_dict
2507
+
2508
+ # %% [code]
2509
+ expand_dict_sum = sum(list(expand_dict.values()))
2510
+ {key: value / expand_dict_sum for key, value in expand_dict.items()}
2511
+
2512
+ # %% [code]
2513
+ best_pred_test[:10]
2514
+
2515
+ # %% [code]
2516
+ last_pred_test[:10]
2517
+
2518
+ # %% [code]
2519
+ def dict_to_df(data):
2520
+ row_tuples = []
2521
+ row_values = []
2522
+
2523
+ metrics = ["precision", "recall", "f1"]
2524
+
2525
+ # Lấy model đầu tiên
2526
+ first_model = next(iter(data.values()))
2527
+
2528
+ # eval_keys
2529
+ eval_keys = list(first_model.keys())
2530
+
2531
+ for eval_key in eval_keys:
2532
+ row_tuples.append(eval_key)
2533
+ row = {}
2534
+
2535
+ for model_name, model_data in data.items():
2536
+ for metric in metrics:
2537
+ row[(model_name, metric)] = model_data[eval_key][metric]
2538
+
2539
+ row_values.append(row)
2540
+
2541
+ # ===== DataFrame =====
2542
+ df = pd.DataFrame(row_values)
2543
+
2544
+ # MultiIndex columns
2545
+ df.columns = pd.MultiIndex.from_tuples(df.columns)
2546
+
2547
+ # Index
2548
+ df.index = pd.Index(row_tuples, name="evaluation")
2549
+
2550
+ # ===== Sort =====
2551
+ sort_keys = []
2552
+ if ("Best model", "f1") in df.columns:
2553
+ sort_keys.append(("Best model", "f1"))
2554
+ if ("Last model", "f1") in df.columns:
2555
+ sort_keys.append(("Last model", "f1"))
2556
+
2557
+ if sort_keys:
2558
+ df = df.sort_values(by=sort_keys, ascending=False)
2559
+
2560
+ return df
2561
+
2562
+ result_test_df = dict_to_df(result_test)
2563
+ result_test_df.to_excel(f"{checkpoints_dir}/results/{state_dict_save_name}_test_df.xlsx")
2564
+ result_test_df
2565
+
2566
+ # %% [code]
2567
+ key = ("Best model", "f1")
2568
+ result_test_df_best = result_test_df.sort_values(by=key, ascending=False).groupby(level="evaluation").head(1)
2569
+ result_test_df_best.to_excel(f"{checkpoints_dir}/results/{state_dict_save_name}_test_df_best.xlsx")
2570
+ result_test_df_best
2571
+
2572
+ # %% [code]
2573
+ def get_avg_best_score(logs):
2574
+ return float(np.mean([list(log.values())[-1]['best_score'] for log in logs]))
2575
+
2576
+ def get_avg_log(logs, epochs):
2577
+ avg_log = {}
2578
+
2579
+ for epoch in range(1, epochs + 1):
2580
+ val_score = 0.0
2581
+ train_loss = 0.0
2582
+ n_eval = 0
2583
+
2584
+ for idx in range(len(logs)):
2585
+ log = logs[idx].get(epoch, logs[idx].get(str(epoch)))
2586
+ if log is None:
2587
+ continue
2588
+
2589
+ val_score += log.get('val_score', 0.0)
2590
+ train_loss += log.get('train_loss', 0.0)
2591
+ n_eval += 1
2592
+
2593
+ if n_eval == 0:
2594
+ continue
2595
+
2596
+ avg_log[epoch] = {
2597
+ 'train_loss': train_loss / n_eval,
2598
+ 'val_score': val_score / n_eval if val_score != 0 else float('inf')
2599
+ }
2600
+
2601
+ return avg_log
2602
+
2603
+ def parse_label_key(label: str):
2604
+ try:
2605
+ first = float(label.split('_', 1)[0]) # số đầu: trước dấu _
2606
+ last = float(re.findall(r'_(\d+(?:\.\d+)?)$', label)[0])
2607
+ return first, last
2608
+ except:
2609
+ return (0, 0)
2610
+
2611
+ def plot_training_logs(logs_dict, save_path=None, figsize=(24, 10)):
2612
+ fig, axes = plt.subplots(1, 2, figsize=figsize)
2613
+
2614
+ # ===== Plot Train Loss =====
2615
+ for name, log in logs_dict.items():
2616
+ epochs = sorted(log.keys())
2617
+ train_loss = [log[e]['train_loss'] for e in epochs]
2618
+ axes[0].plot(epochs, train_loss, label=name)
2619
+
2620
+ axes[0].set_xlabel('Epoch')
2621
+ axes[0].set_ylabel('Train Loss')
2622
+ axes[0].set_title('Training Loss')
2623
+ axes[0].grid(True)
2624
+
2625
+ # ===== Plot Validation Score =====
2626
+ for name, log in logs_dict.items():
2627
+ epochs = sorted(log.keys())
2628
+ val_score = [log[e]['val_score'] for e in epochs]
2629
+ axes[1].plot(epochs, val_score, label=name)
2630
+
2631
+ axes[1].set_xlabel('Epoch')
2632
+ axes[1].set_ylabel('Validation Score')
2633
+ axes[1].set_title('Validation Score')
2634
+ axes[1].grid(True)
2635
+
2636
+ # ===== Shared Legend =====
2637
+ handles, labels = axes[0].get_legend_handles_labels()
2638
+ pairs = list(zip(handles, labels))
2639
+ pairs_sorted = sorted(
2640
+ pairs,
2641
+ key=lambda x: parse_label_key(x[1])
2642
+ )
2643
+ handles_sorted, labels_sorted = zip(*pairs_sorted)
2644
+
2645
+ axes[0].legend(
2646
+ handles_sorted,
2647
+ labels_sorted,
2648
+ loc='center left',
2649
+ bbox_to_anchor=(1.01, 0.5),
2650
+ borderaxespad=0.
2651
+ )
2652
+
2653
+ plt.tight_layout(rect=[0, 0, 1, 1])
2654
+
2655
+ if save_path is not None:
2656
+ os.makedirs(os.path.dirname(save_path), exist_ok=True) if os.path.dirname(save_path) else None
2657
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
2658
+
2659
+ plt.show()
2660
+
2661
+ # %% [code]
2662
+ if not test_only:
2663
+ snapshot_download(repo_id=repo_name, local_dir="", repo_type="model", allow_patterns=["**/*entities*.json"])
2664
+ get_ipython().system('rm -rf .cache .gitattributes')
2665
+
2666
+ # %% [code]
2667
+ if not test_only:
2668
+ experiments = {}
2669
+ for experiment in os.listdir(pretrained_dir):
2670
+ if '.virtual_documents' in experiment:
2671
+ continue
2672
+ experiment_logs = []
2673
+ try:
2674
+ for seed in SEEDS:
2675
+ for fold_idx in range(nfolds):
2676
+ with open(f"{pretrained_dir}/{experiment}/logs/{experiment}_s{seed}_f{fold_idx}_logging.json", "r", encoding="utf-8") as f:
2677
+ experiment_log = json.load(f)
2678
+ experiment_logs.append(experiment_log)
2679
+ except:
2680
+ pass
2681
+ experiments[experiment] = get_avg_log(experiment_logs, 1000)
2682
+ experiments[state_dict_save_name] = get_avg_log(training_logs, 1000)
2683
+
2684
+ # %% [code]
2685
+ if not test_only:
2686
+ score = get_avg_best_score(training_logs)
2687
+ state_dict_save_name, score
2688
+
2689
+ # %% [code]
2690
+ if not test_only:
2691
+ plot_training_logs(experiments, save_path=f'{checkpoints_dir}/logs/{state_dict_save_name}_log_plot.jpg', figsize=(18, 7.5))
2692
+
4.2_add_span_rerank_branch_4.3/lasts/4.2_add_span_rerank_branch_4.3_s26092004_f0_last_ema.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9dbf61110eb38c4662d35ee2ef763bf443b2a9366bd52a0fbe7b7a75ed822011
3
+ size 554283006
4.2_add_span_rerank_branch_4.3/logs/4.2_add_span_rerank_branch_4.3_log_plot.jpg ADDED

Git LFS Details

  • SHA256: 014b76bb225bb83d2307a2b37e937886d3fe1ba7723d68e84693efbb6626139b
  • Pointer size: 131 Bytes
  • Size of remote file: 555 kB
4.2_add_span_rerank_branch_4.3/logs/4.2_add_span_rerank_branch_4.3_s26092004_f0_logging.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"1": {"lr": [2e-05, 0.0005], "train_loss": 0.071533203125, "total": 0.07154835103409726, "token_margin_loss": 0.05453465623253214, "span_loss": 0.017170905533817775, "start_margin": 0.02475195640022359, "end_margin": 0.02924119619899385}, "2": {"lr": [1.988303923565381e-05, 0.0004969282409784868], "train_loss": 0.0523681640625, "total": 0.052368641699273334, "token_margin_loss": 0.04108440469536054, "span_loss": 0.008921709055338178, "start_margin": 0.019319452207937394, "end_margin": 0.022306456120737842}, "3": {"lr": [1.9535036904803962e-05, 0.0004877886008156408], "train_loss": 0.0474853515625, "total": 0.04747764114030185, "token_margin_loss": 0.03835941866964785, "span_loss": 0.0066639882615986586, "start_margin": 0.01818404136389044, "end_margin": 0.02064700950251537}, "4": {"lr": [1.8964561979789496e-05, 0.00047280612778499774], "train_loss": 0.044342041015625, "total": 0.04433342649524874, "token_margin_loss": 0.03633314700950251, "span_loss": 0.0058386319172722195, "start_margin": 0.01763380380100615, "end_margin": 0.01938932364449413}, "5": {"lr": [1.8185661446562005e-05, 0.00045234974009654937], "train_loss": 0.04083251953125, "total": 0.04083985466741196, "token_margin_loss": 0.03509292901062046, "span_loss": 0.004480505869200671, "start_margin": 0.017170905533817775, "end_margin": 0.018323784237003912}, "6": {"lr": [1.7217514421272206e-05, 0.00042692314190604356], "train_loss": 0.03759765625, "total": 0.03759083286752376, "token_margin_loss": 0.03409726103968697, "span_loss": 0.003048141419787591, "start_margin": 0.016708007266629403, "end_margin": 0.017756078814980435}, "7": {"lr": [1.60839598967785e-05, 0.00039715242044697206], "train_loss": 0.0364990234375, "total": 0.036507825600894356, "token_margin_loss": 0.03346841811067636, "span_loss": 0.0028865637227501397, "start_margin": 0.01633244829513695, "end_margin": 0.017214575181665734}, "8": {"lr": [1.4812909747525698e-05, 0.00036377062968501693], "train_loss": 0.03515625, "total": 0.03514533258803801, "token_margin_loss": 0.032769703745109, "span_loss": 0.0023690783957518165, "start_margin": 0.016044228619340413, "end_margin": 0.01678661263275573, "val_score": 0.7112802091338476, "best_score": 0.7112802091338476, "new_best_model": true, "precision": 0.7939549476372776, "recall": 0.6466736660594441, "f1": 0.7112802091338476}, "9": {"lr": [1.3435661446562005e-05, 0.0003275997400965494], "train_loss": 0.033721923828125, "total": 0.03371296813862493, "token_margin_loss": 0.03191377864728899, "span_loss": 0.0018363086920067076, "start_margin": 0.015625, "end_margin": 0.016367384013415316, "val_score": 0.7125197294769945, "best_score": 0.7125197294769945, "new_best_model": true, "precision": 0.7903158777847237, "recall": 0.6510541411624837, "f1": 0.7125197294769945}, "10": {"lr": [1.1986127417882198e-05, 0.00028953039902753766], "train_loss": 0.0322265625, "total": 0.0322282001117943, "token_margin_loss": 0.030918110676355505, "span_loss": 0.0013166398826159867, "start_margin": 0.015214505310229179, "end_margin": 0.015843348239239798, "val_score": 0.7153066690348736, "best_score": 0.7153066690348736, "new_best_model": true, "precision": 0.789784007952825, "recall": 0.6561640838241907, "f1": 0.7153066690348736}, "11": {"lr": [1.0500000000000003e-05, 0.0002505], "train_loss": 0.0311431884765625, "total": 0.031145192845164895, "token_margin_loss": 0.030376607043040803, "span_loss": 0.0007669481903297932, "start_margin": 0.0148651481274455, "end_margin": 0.015493991056456121, "val_score": 0.7159717438966623, "best_score": 0.7159717438966623, "new_best_model": true, "precision": 0.7889888822678125, "recall": 0.6577450796077198, "f1": 0.7159717438966623}, "12": {"lr": [9.013872582117811e-06, 0.00021146960097246246], "train_loss": 0.0299224853515625, "total": 0.029922442705422023, "token_margin_loss": 0.02924119619899385, "span_loss": 0.0006353933761878145, "start_margin": 0.014384782001117943, "end_margin": 0.014935019564002236, "val_score": 0.7181912753106598, "best_score": 0.7181912753106598, "new_best_model": true, "precision": 0.7888571373177122, "recall": 0.6614039413276346, "f1": 0.7181912753106598}, "13": {"lr": [7.564338553438001e-06, 0.00017340025990345064], "train_loss": 0.0293426513671875, "total": 0.029346003353828955, "token_margin_loss": 0.02866475684740078, "span_loss": 0.000713998742314142, "start_margin": 0.014009223029625489, "end_margin": 0.014620598099496925, "val_score": 0.7186736409410447, "best_score": 0.7186736409410447, "new_best_model": true, "precision": 0.7887910418257643, "recall": 0.6622222187077492, "f1": 0.7186736409410447}, "14": {"lr": [6.1870902524743065e-06, 0.00013722937031498307], "train_loss": 0.0281524658203125, "total": 0.02815818893236445, "token_margin_loss": 0.027599217439910565, "span_loss": 0.0004702675202627166, "start_margin": 0.013555058692006707, "end_margin": 0.014035424818334264, "val_score": 0.7187089750512429, "best_score": 0.7187089750512429, "new_best_model": true, "precision": 0.7880049613591312, "recall": 0.6628539836653775, "f1": 0.7187089750512429}, "15": {"lr": [4.916040103221507e-06, 0.00010384757955302797], "train_loss": 0.027435302734375, "total": 0.02744200670765791, "token_margin_loss": 0.02714505310229178, "span_loss": 0.0003141485292062605, "start_margin": 0.013179499720514253, "end_margin": 0.01369480156512018, "val_score": 0.7191342289503211, "best_score": 0.7191342289503211, "new_best_model": true, "precision": 0.7863258914055251, "recall": 0.6647396711299539, "f1": 0.7191342289503211}, "16": {"lr": [3.7824855787278e-06, 7.40768580939564e-05], "train_loss": 0.0265350341796875, "total": 0.026533678032420346, "token_margin_loss": 0.026289128004471772, "span_loss": 0.0003215177822806037, "start_margin": 0.012838876467300168, "end_margin": 0.013441517607602012, "val_score": 0.7196401005954386, "best_score": 0.7196401005954386, "new_best_model": true, "precision": 0.7859884264169718, "recall": 0.6658980147713309, "f1": 0.7196401005954386}, "17": {"lr": [2.814338553438001e-06, 4.865025990345063e-05], "train_loss": 0.0258026123046875, "total": 0.02580002794857462, "token_margin_loss": 0.02569522079373952, "span_loss": 0.00011777158153996646, "start_margin": 0.012384712129681386, "end_margin": 0.013013555058692007, "val_score": 0.7184733712726745, "best_score": 0.7196401005954386, "new_best_model": false, "precision": 0.7837322202065794, "recall": 0.6654246339380059, "f1": 0.7184733712726745}, "18": {"lr": [2.0354380202105066e-06, 2.8193872215002235e-05], "train_loss": 0.0253753662109375, "total": 0.02538079932923421, "token_margin_loss": 0.025275992174399107, "span_loss": 9.600499144074902e-05, "start_margin": 0.012306106763555058, "end_margin": 0.0128039407490218, "val_score": 0.7172896282471793, "best_score": 0.7196401005954386, "new_best_model": false, "precision": 0.7822158957899207, "recall": 0.6644140740062107, "f1": 0.7172896282471793}}
4.2_add_span_rerank_branch_4.3/r1s/4.2_add_span_rerank_branch_4.3_s26092004_f0_r1_vs0.71964_ema.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a8bf7ec29dad2dea80c0a5e6f33d828a8d01f16956f69f693ea0f09716548a4
3
+ size 554284870
4.2_add_span_rerank_branch_4.3/results/4.2_add_span_rerank_branch_4.3_error_analyze_result.json ADDED
The diff for this file is too large to render. See raw diff
 
4.2_add_span_rerank_branch_4.3/results/4.2_add_span_rerank_branch_4.3_pred_test.json ADDED
The diff for this file is too large to render. See raw diff
 
4.2_add_span_rerank_branch_4.3/results/4.2_add_span_rerank_branch_4.3_test.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "Best model": {
3
+ "Ent-I": {
4
+ "precision": 0.7843592330970894,
5
+ "recall": 0.7224649130953411,
6
+ "f1": 0.7521408823268821
7
+ },
8
+ "Ent-C": {
9
+ "precision": 0.7065336658347067,
10
+ "recall": 0.6577212368830367,
11
+ "f1": 0.681254202950345
12
+ }
13
+ },
14
+ "Last model": {
15
+ "Ent-I": {
16
+ "precision": 0.7815895372225536,
17
+ "recall": 0.722093131331237,
18
+ "f1": 0.7506642783058125
19
+ },
20
+ "Ent-C": {
21
+ "precision": 0.7028236229859784,
22
+ "recall": 0.6564212090252981,
23
+ "f1": 0.6788303594359623
24
+ }
25
+ }
26
+ }
4.2_add_span_rerank_branch_4.3/results/4.2_add_span_rerank_branch_4.3_test_df.xlsx ADDED
Binary file (5.28 kB). View file
 
4.2_add_span_rerank_branch_4.3/results/4.2_add_span_rerank_branch_4.3_test_df_best.xlsx ADDED
Binary file (5.28 kB). View file