Update SRL_preprocessing.py
Browse files- SRL_preprocessing.py +6 -6
SRL_preprocessing.py
CHANGED
|
@@ -203,12 +203,12 @@ def srl_collate(batch: List[Dict], pad_token_id: int, pad_label_id: int = -100):
|
|
| 203 |
def data_processing_for_loader_conll(
|
| 204 |
train_conll: str,
|
| 205 |
dev_conll: Optional[str],
|
| 206 |
-
test_conll: Optional[str],
|
| 207 |
tokenizer,
|
| 208 |
word_col_idx: int = 3,
|
| 209 |
srl_first_col_idx: int = 11,
|
| 210 |
max_length: int = 256
|
| 211 |
-
) -> Tuple[SRLDataset, Optional[SRLDataset], Optional[SRLDataset], Dict[str, int], Dict[int, str]]:
|
| 212 |
"""
|
| 213 |
Reads train/dev/test .gold_conll files and returns:
|
| 214 |
train_dataset, dev_dataset, test_dataset, label2id, id2label
|
|
@@ -219,10 +219,10 @@ def data_processing_for_loader_conll(
|
|
| 219 |
# Load samples
|
| 220 |
train_samples = load_conll_samples(train_conll, word_col_idx, srl_first_col_idx)
|
| 221 |
dev_samples = load_conll_samples(dev_conll, word_col_idx, srl_first_col_idx) if dev_conll else []
|
| 222 |
-
test_samples = load_conll_samples(test_conll, word_col_idx, srl_first_col_idx) if test_conll else []
|
| 223 |
|
| 224 |
# Build label maps from ALL splits
|
| 225 |
-
all_samples = train_samples + dev_samples + test_samples
|
| 226 |
label2id = {}
|
| 227 |
for s in all_samples:
|
| 228 |
for lab in s.labels:
|
|
@@ -233,6 +233,6 @@ def data_processing_for_loader_conll(
|
|
| 233 |
# Datasets
|
| 234 |
train_ds = SRLDataset(train_samples, tokenizer, label2id, max_length=max_length)
|
| 235 |
dev_ds = SRLDataset(dev_samples, tokenizer, label2id, max_length=max_length) if dev_samples else None
|
| 236 |
-
test_ds = SRLDataset(test_samples, tokenizer, label2id, max_length=max_length) if test_samples else None
|
| 237 |
|
| 238 |
-
return train_ds, dev_ds, test_ds, label2id, id2label
|
|
|
|
| 203 |
def data_processing_for_loader_conll(
|
| 204 |
train_conll: str,
|
| 205 |
dev_conll: Optional[str],
|
| 206 |
+
# test_conll: Optional[str],
|
| 207 |
tokenizer,
|
| 208 |
word_col_idx: int = 3,
|
| 209 |
srl_first_col_idx: int = 11,
|
| 210 |
max_length: int = 256
|
| 211 |
+
) -> Tuple[SRLDataset, Optional[SRLDataset], Dict[str, int], Dict[int, str]]:
|
| 212 |
"""
|
| 213 |
Reads train/dev/test .gold_conll files and returns:
|
| 214 |
train_dataset, dev_dataset, test_dataset, label2id, id2label
|
|
|
|
| 219 |
# Load samples
|
| 220 |
train_samples = load_conll_samples(train_conll, word_col_idx, srl_first_col_idx)
|
| 221 |
dev_samples = load_conll_samples(dev_conll, word_col_idx, srl_first_col_idx) if dev_conll else []
|
| 222 |
+
# test_samples = load_conll_samples(test_conll, word_col_idx, srl_first_col_idx) if test_conll else []
|
| 223 |
|
| 224 |
# Build label maps from ALL splits
|
| 225 |
+
all_samples = train_samples + dev_samples
|
| 226 |
label2id = {}
|
| 227 |
for s in all_samples:
|
| 228 |
for lab in s.labels:
|
|
|
|
| 233 |
# Datasets
|
| 234 |
train_ds = SRLDataset(train_samples, tokenizer, label2id, max_length=max_length)
|
| 235 |
dev_ds = SRLDataset(dev_samples, tokenizer, label2id, max_length=max_length) if dev_samples else None
|
| 236 |
+
# test_ds = SRLDataset(test_samples, tokenizer, label2id, max_length=max_length) if test_samples else None
|
| 237 |
|
| 238 |
+
return train_ds, dev_ds, label2id, id2label
|