yeomtong committed on
Commit
fe216b5
·
verified ·
1 Parent(s): b70354a

Update SRL_preprocessing.py

Browse files
Files changed (1) hide show
  1. SRL_preprocessing.py +6 -6
SRL_preprocessing.py CHANGED
@@ -203,12 +203,12 @@ def srl_collate(batch: List[Dict], pad_token_id: int, pad_label_id: int = -100):
203
  def data_processing_for_loader_conll(
204
  train_conll: str,
205
  dev_conll: Optional[str],
206
- test_conll: Optional[str],
207
  tokenizer,
208
  word_col_idx: int = 3,
209
  srl_first_col_idx: int = 11,
210
  max_length: int = 256
211
- ) -> Tuple[SRLDataset, Optional[SRLDataset], Optional[SRLDataset], Dict[str, int], Dict[int, str]]:
212
  """
213
  Reads train/dev/test .gold_conll files and returns:
214
  train_dataset, dev_dataset, test_dataset, label2id, id2label
@@ -219,10 +219,10 @@ def data_processing_for_loader_conll(
219
  # Load samples
220
  train_samples = load_conll_samples(train_conll, word_col_idx, srl_first_col_idx)
221
  dev_samples = load_conll_samples(dev_conll, word_col_idx, srl_first_col_idx) if dev_conll else []
222
- test_samples = load_conll_samples(test_conll, word_col_idx, srl_first_col_idx) if test_conll else []
223
 
224
  # Build label maps from ALL splits
225
- all_samples = train_samples + dev_samples + test_samples
226
  label2id = {}
227
  for s in all_samples:
228
  for lab in s.labels:
@@ -233,6 +233,6 @@ def data_processing_for_loader_conll(
233
  # Datasets
234
  train_ds = SRLDataset(train_samples, tokenizer, label2id, max_length=max_length)
235
  dev_ds = SRLDataset(dev_samples, tokenizer, label2id, max_length=max_length) if dev_samples else None
236
- test_ds = SRLDataset(test_samples, tokenizer, label2id, max_length=max_length) if test_samples else None
237
 
238
- return train_ds, dev_ds, test_ds, label2id, id2label
 
203
  def data_processing_for_loader_conll(
204
  train_conll: str,
205
  dev_conll: Optional[str],
206
+ # test_conll: Optional[str],
207
  tokenizer,
208
  word_col_idx: int = 3,
209
  srl_first_col_idx: int = 11,
210
  max_length: int = 256
211
+ ) -> Tuple[SRLDataset, Optional[SRLDataset], Dict[str, int], Dict[int, str]]:
212
  """
213
  Reads train/dev .gold_conll files and returns:
214
  train_dataset, dev_dataset, label2id, id2label
 
219
  # Load samples
220
  train_samples = load_conll_samples(train_conll, word_col_idx, srl_first_col_idx)
221
  dev_samples = load_conll_samples(dev_conll, word_col_idx, srl_first_col_idx) if dev_conll else []
222
+ # test_samples = load_conll_samples(test_conll, word_col_idx, srl_first_col_idx) if test_conll else []
223
 
224
  # Build label maps from ALL splits
225
+ all_samples = train_samples + dev_samples
226
  label2id = {}
227
  for s in all_samples:
228
  for lab in s.labels:
 
233
  # Datasets
234
  train_ds = SRLDataset(train_samples, tokenizer, label2id, max_length=max_length)
235
  dev_ds = SRLDataset(dev_samples, tokenizer, label2id, max_length=max_length) if dev_samples else None
236
+ # test_ds = SRLDataset(test_samples, tokenizer, label2id, max_length=max_length) if test_samples else None
237
 
238
+ return train_ds, dev_ds, label2id, id2label