Update trainer.py
Browse files- trainer.py +17 -9
trainer.py
CHANGED
|
@@ -187,6 +187,7 @@ if __name__ == "__main__":
|
|
| 187 |
# read values from cfg as usual:
|
| 188 |
conll_train_path = cfg["data"]["conll_train"]
|
| 189 |
conll_dev_path = cfg["data"].get("conll_dev")
|
|
|
|
| 190 |
word_col_idx = cfg["data"]["word_col_idx"]
|
| 191 |
srl_first_col_idx= cfg["data"]["srl_first_col_idx"]
|
| 192 |
|
|
@@ -215,13 +216,19 @@ if __name__ == "__main__":
|
|
| 215 |
tokenizer = AutoTokenizer.from_pretrained(replace_encoder_with or bert_name)
|
| 216 |
print(f"Using tokenizer: {replace_encoder_with or bert_name}")
|
| 217 |
|
| 218 |
-
print(f"Loading multilingual CoNLL data: {conll_train_path}")
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
|
| 226 |
# pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
|
| 227 |
|
|
@@ -241,8 +248,9 @@ if __name__ == "__main__":
|
|
| 241 |
|
| 242 |
collate = lambda b: srl_collate(b, pad_token_id=pad_token_id, pad_label_id=-100)
|
| 243 |
|
| 244 |
-
train_loader = DataLoader(
|
| 245 |
-
dev_loader
|
|
|
|
| 246 |
|
| 247 |
# ------------------------------
|
| 248 |
# 🧠 Model initialization
|
|
|
|
| 187 |
# read values from cfg as usual:
|
| 188 |
conll_train_path = cfg["data"]["conll_train"]
|
| 189 |
conll_dev_path = cfg["data"].get("conll_dev")
|
| 190 |
+
conll_test_path = cfg["data"].get("conll_test")
|
| 191 |
word_col_idx = cfg["data"]["word_col_idx"]
|
| 192 |
srl_first_col_idx= cfg["data"]["srl_first_col_idx"]
|
| 193 |
|
|
|
|
| 216 |
tokenizer = AutoTokenizer.from_pretrained(replace_encoder_with or bert_name)
|
| 217 |
print(f"Using tokenizer: {replace_encoder_with or bert_name}")
|
| 218 |
|
| 219 |
+
# print(f"Loading multilingual CoNLL data: {conll_train_path}")
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
train_bf_loader, dev_bf_loader, test_bf_loader, label2id, id2label = \
|
| 223 |
+
data_processing_for_loader_conll(
|
| 224 |
+
train_conll=conll_train_path,
|
| 225 |
+
dev_conll=conll_dev_path,
|
| 226 |
+
test_conll=conll_test_path,
|
| 227 |
+
tokenizer=tokenizer,
|
| 228 |
+
word_col_idx=word_col_idx,
|
| 229 |
+
srl_first_col_idx=srl_first_col_idx,
|
| 230 |
+
max_length=256,
|
| 231 |
+
)
|
| 232 |
|
| 233 |
# pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
|
| 234 |
|
|
|
|
| 248 |
|
| 249 |
collate = lambda b: srl_collate(b, pad_token_id=pad_token_id, pad_label_id=-100)
|
| 250 |
|
| 251 |
+
train_loader = DataLoader(train_bf_loader, batch_size=batch_size, shuffle=True, collate_fn=collate)
|
| 252 |
+
dev_loader = DataLoader(dev_bf_loader, batch_size=batch_size, shuffle=False, collate_fn=collate) if dev_bf_loader else None
|
| 253 |
+
test_loader = DataLoader(test_bf_loader, batch_size=batch_size, shuffle=False, collate_fn=collate) if test_bf_loader else None
|
| 254 |
|
| 255 |
# ------------------------------
|
| 256 |
# 🧠 Model initialization
|