| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
"""IFLYTEK dataset."""
|
|
| from megatron import print_rank_0 |
| from tasks.data_utils import clean_text |
| from .data import GLUEAbstractDataset |
| import json |
| from tasks.label_dict import get_label_dict |
|
|
# Label dictionary for the "IFLYTEK" task, looked up once at import time.
# IFLYTEKDataset uses it both to validate a row's label (`label in LABELS`)
# and to map that label to the value stored in each sample (`LABELS[label]`).
LABELS = get_label_dict("IFLYTEK")
|
|
class IFLYTEKDataset(GLUEAbstractDataset):
    """IFLYTEK single-sentence classification dataset read from JSON-lines files.

    Every non-blank line of an input file is one JSON object with a
    ``"sentence"`` field, an optional ``"id"`` field and, for labelled
    files, a ``"label"`` field whose stripped value must be a key of
    ``LABELS``.
    """

    def __init__(self, name, datapaths, tokenizer, max_seq_length,
                 test_label='0'):
        """Remember the placeholder test label and initialise the base class.

        Args:
            name: Forwarded unchanged to ``GLUEAbstractDataset``.
            datapaths: Forwarded unchanged to ``GLUEAbstractDataset``.
            tokenizer: Forwarded unchanged to ``GLUEAbstractDataset``.
            max_seq_length: Forwarded unchanged to ``GLUEAbstractDataset``.
            test_label: Label given to every row of a file whose first row
                has no ``"label"`` field; it must be a key of ``LABELS``.
        """
        # NOTE(review): assigned before the base constructor runs, presumably
        # because the base class calls process_samples_from_single_path()
        # during __init__ and that method reads this attribute -- confirm
        # against GLUEAbstractDataset.
        self.test_label = test_label
        super().__init__('IFLYTEK', name, datapaths,
                         tokenizer, max_seq_length)

    def process_samples_from_single_path(self, filename):
        """Implement abstract method.

        Parse one JSON-lines file and turn each record into a sample dict.

        Whether the file is labelled is decided from its first record only:
        if that record has no ``"label"`` key, every record in the file is
        given ``self.test_label``; otherwise every record must carry its own
        ``"label"``.

        Args:
            filename: Path of a UTF-8 encoded JSON-lines file.

        Returns:
            A list with one dict per record, holding ``'text_a'`` (the
            cleaned sentence), ``'text_b'`` (always ``None``), ``'label'``
            (``LABELS[label]``) and ``'uid'`` (the record id as ``int``;
            the record's position in the file when it has no ``"id"``).

        Raises:
            AssertionError: If a cleaned sentence is empty, a label is not a
                key of ``LABELS``, or an id is negative.
        """
        print_rank_0(' > Processing {} ...'.format(filename))

        # JSON text is UTF-8; decode explicitly instead of relying on the
        # locale's default encoding. Blank lines (e.g. a trailing newline at
        # the end of the file) are skipped because json.loads rejects them.
        with open(filename, 'r', encoding='utf-8') as f:
            rows = [json.loads(line.strip()) for line in f if line.strip()]

        # Nothing is ever dropped here (invalid records fail the assertions
        # below); the counter only keeps the final log output unchanged.
        drop_cnt = 0
        samples = []
        is_test = False
        for index, row in enumerate(rows):
            if "id" not in row:
                row["id"] = index

            if index == 0:
                # The first record decides how the whole file is treated.
                is_test = "label" not in row
                if is_test:
                    print_rank_0(
                        ' reading {}, {} and {} columns and setting '
                        'labels to {}'.format(
                            row["id"], row["sentence"].strip(),
                            None, self.test_label))
                else:
                    print_rank_0(' reading {} , {}, {}, and {} columns '
                                 '...'.format(
                                     row["id"], row["sentence"].strip(),
                                     None, row["label"].strip()))

            text_a = clean_text(row["sentence"].strip())
            text_b = None
            unique_id = int(row["id"])

            if is_test:
                label = self.test_label
            else:
                label = row["label"].strip()

            assert len(text_a) > 0, "found empty sentence {}".format(row)
            assert label in LABELS, "found label {} {}".format(label, row)
            assert unique_id >= 0, "found negative id {}".format(row)

            samples.append({'text_a': text_a,
                            'text_b': text_b,
                            'label': LABELS[label],
                            'uid': unique_id})

            if len(samples) % 5000 == 0:
                print_rank_0(' > processed {} so far ...'.format(
                    len(samples)))

        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        print_rank_0(' >> drop {} samples.'.format(drop_cnt))

        return samples
|
|