yeomtong committed on
Commit
d117d6b
·
verified ·
1 Parent(s): c7804f9

Update trainer.py

Browse files
Files changed (1) hide show
  1. trainer.py +17 -9
trainer.py CHANGED
@@ -187,6 +187,7 @@ if __name__ == "__main__":
187
  # read values from cfg as usual:
188
  conll_train_path = cfg["data"]["conll_train"]
189
  conll_dev_path = cfg["data"].get("conll_dev")
 
190
  word_col_idx = cfg["data"]["word_col_idx"]
191
  srl_first_col_idx= cfg["data"]["srl_first_col_idx"]
192
 
@@ -215,13 +216,19 @@ if __name__ == "__main__":
215
  tokenizer = AutoTokenizer.from_pretrained(replace_encoder_with or bert_name)
216
  print(f"Using tokenizer: {replace_encoder_with or bert_name}")
217
 
218
- print(f"Loading multilingual CoNLL data: {conll_train_path}")
219
- train_dataset, label2id, id2label = data_processing_for_loader_from_conll(
220
- conll_path=conll_train_path,
221
- tokenizer=tokenizer,
222
- word_col_idx=word_col_idx,
223
- srl_first_col_idx=srl_first_col_idx,
224
- )
 
 
 
 
 
 
225
 
226
  # pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
227
 
@@ -241,8 +248,9 @@ if __name__ == "__main__":
241
 
242
  collate = lambda b: srl_collate(b, pad_token_id=pad_token_id, pad_label_id=-100)
243
 
244
- train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate)
245
- dev_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate) # no dev split yet
 
246
 
247
  # ------------------------------
248
  # 🧠 Model initialization
 
187
  # read values from cfg as usual:
188
  conll_train_path = cfg["data"]["conll_train"]
189
  conll_dev_path = cfg["data"].get("conll_dev")
190
+ conll_test_path = cfg["data"].get("conll_test")
191
  word_col_idx = cfg["data"]["word_col_idx"]
192
  srl_first_col_idx= cfg["data"]["srl_first_col_idx"]
193
 
 
216
  tokenizer = AutoTokenizer.from_pretrained(replace_encoder_with or bert_name)
217
  print(f"Using tokenizer: {replace_encoder_with or bert_name}")
218
 
219
+ # print(f"Loading multilingual CoNLL data: {conll_train_path}")
220
+
221
+
222
+ train_bf_loader, dev_bf_loader, test_bf_loader, label2id, id2label = \
223
+ data_processing_for_loader_conll(
224
+ train_conll=conll_train_path,
225
+ dev_conll=conll_dev_path,
226
+ test_conll=conll_test_path,
227
+ tokenizer=tokenizer,
228
+ word_col_idx=word_col_idx,
229
+ srl_first_col_idx=srl_first_col_idx,
230
+ max_length=256,
231
+ )
232
 
233
  # pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
234
 
 
248
 
249
  collate = lambda b: srl_collate(b, pad_token_id=pad_token_id, pad_label_id=-100)
250
 
251
+ train_loader = DataLoader(train_bf_loader, batch_size=batch_size, shuffle=True, collate_fn=collate)
252
+ dev_loader = DataLoader(dev_bf_loader, batch_size=batch_size, shuffle=False, collate_fn=collate) if dev_bf_loader else None
253
+ test_loader = DataLoader(test_bf_loader, batch_size=batch_size, shuffle=False, collate_fn=collate) if test_bf_loader else None
254
 
255
  # ------------------------------
256
  # 🧠 Model initialization