from torch.utils.data import DataLoader
from collections import Counter, namedtuple
import logging
import re
import itertools
from Nested.utils.helpers import load_object
from Nested.data.datasets import Token

logger = logging.getLogger(__name__)


class Vocab:
    """A simple vocabulary mapping strings to indices and back."""

    def __init__(self, counter, specials=None) -> None:
        # Avoid a mutable default argument for specials
        specials = specials if specials is not None else []
        self.itos = list(counter.keys()) + specials
        self.stoi = {s: i for i, s in enumerate(self.itos)}
        self.word_count = counter

    def get_itos(self) -> list[str]:
        return self.itos

    def get_stoi(self) -> dict[str, int]:
        return self.stoi

    def __len__(self):
        return len(self.itos)
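
# A minimal usage sketch for Vocab (illustrative only; dict/Counter key
# order follows first insertion in Python 3.7+):
#
#     counts = Counter(["the", "cat", "the"])
#     vocab = Vocab(counts, specials=["UNK"])
#     vocab.get_itos()   # ["the", "cat", "UNK"]
#     vocab.get_stoi()   # {"the": 0, "cat": 1, "UNK": 2}
#     len(vocab)         # 3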


def conll_to_segments(filename):
    """
    Convert CoNLL files to segments. This return list of segments and each segment is
    a list of tuples (token, tag)
    :param filename: Path
    :return: list[[tuple]] - [[(token, tag), (token, tag), ...], [(token, tag), ...]]
    """
    segments, segment = list(), list()

    with open(filename, "r") as fh:
        for token in fh.read().splitlines():
            if not token.strip():
                segments.append(segment)
                segment = list()
            else:
                parts = token.split()
                token = Token(text=parts[0], gold_tag=parts[1:])
                segment.append(token)

        segments.append(segment)

    return segments
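
# A hedged sketch of the expected CoNLL layout (one token per line followed
# by its tag columns, with a blank line between segments):
#
#     CNN   B-ORG
#     aired O
#
#     Obama B-PER
#
# conll_to_segments("train.txt") would then yield two segments of Token
# objects ("train.txt" is a hypothetical path).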


def parse_conll_files(data_paths):
    """
    Parse CoNLL formatted files and return list of segments for each file and index
    the vocabs and tags across all data_paths
    :param data_paths: tuple(Path) - tuple of filenames
    :return: tuple( [[(token, tag), ...], [(token, tag), ...]], -> segments for data_paths[i]
                    [[(token, tag), ...], [(token, tag), ...]], -> segments for data_paths[i+1],
                    ...
                  )
             List of segments for each dataset and each segment has list of (tokens, tags)
    """
    VocabSet = namedtuple("Vocab", ["tags", "tokens"])
    datasets, tags, tokens = list(), list(), list()

    for data_path in data_paths:
        dataset = conll_to_segments(data_path)
        datasets.append(dataset)
        tokens += [token.text for segment in dataset for token in segment]
        tags += [token.gold_tag for segment in dataset for token in segment]

    # Flatten list of tags
    tags = list(itertools.chain(*tags))

    # Generate vocabs for tags and tokens
    # The combined vocab over all tags goes first, followed by one vocab
    # per tag type
    tag_vocabs = tag_vocab_by_type(tags)
    tag_vocabs.insert(0, Vocab(Counter(tags)))
    vocabs = VocabSet(tokens=Vocab(Counter(tokens), specials=["UNK"]), tags=tag_vocabs)
    return tuple(datasets), vocabs
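
# A minimal usage sketch (file names are hypothetical):
#
#     (train, dev), vocabs = parse_conll_files(("train.txt", "dev.txt"))
#     vocabs.tokens      # token Vocab with an "UNK" special
#     vocabs.tags[0]     # Vocab over all tags
#     vocabs.tags[1:]    # one Vocab per tag type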


def tag_vocab_by_type(tags):
    """
    Build one Vocab per tag type (e.g. PER, ORG), each including the "O" tag.
    """
    vocabs = list()
    tag_names = Counter(tags).keys()
    tag_types = sorted({tag.split("-", 1)[1] for tag in tag_names if "-" in tag})

    for tag_type in tag_types:
        # Escape the tag type in case it contains regex metacharacters
        r = re.compile(".*-" + re.escape(tag_type) + "$")
        t = list(filter(r.match, tags)) + ["O"]
        vocabs.append(Vocab(Counter(t)))

    return vocabs
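
# For instance (illustrative): tags = ["B-PER", "I-PER", "B-ORG", "O"] would
# produce two vocabs, one over {"B-ORG", "O"} and one over
# {"B-PER", "I-PER", "O"} (tag types are sorted alphabetically).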


def text2segments(text):
    """
    Convert text to a datasets and index the tokens
    """
    dataset = [[Token(text=token, gold_tag=["O"]) for token in text.split()]]
    tokens = [token.text for segment in dataset for token in segment]

    # Generate vocabs for the tokens
    segment_vocab = Vocab(Counter(tokens), specials=["UNK"])
    return dataset, segment_vocab
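
# A minimal usage sketch:
#
#     dataset, vocab = text2segments("Obama visited Cairo")
#     [t.text for t in dataset[0]]   # ["Obama", "visited", "Cairo"]
#     vocab.get_itos()               # ["Obama", "visited", "Cairo", "UNK"]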


def get_dataloaders(
    datasets, vocab, data_config, batch_size=32, num_workers=0, shuffle=(True, False, False)
):
    """
    From the datasets generate the dataloaders
    :param datasets: list - list of the datasets, list of list of segments and tokens
    :param batch_size: int
    :param num_workers: int
    :param shuffle: boolean - to shuffle the data or not
    :return: List[torch.utils.data.DataLoader]
    """
    dataloaders = list()

    data_config = data_config["data_config"]
    for i, examples in enumerate(datasets):
        data_config["kwargs"].update({"examples": examples, "vocab": vocab})
        dataset = load_object(data_config["fn"], data_config["kwargs"])

        dataloader = DataLoader(
            dataset=dataset,
            shuffle=shuffle[i],
            batch_size=batch_size,
            num_workers=num_workers,
            collate_fn=dataset.collate_fn,
        )

        logger.info("%s batches found", len(dataloader))
        dataloaders.append(dataloader)

    return dataloaders
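
# A hedged end-to-end sketch. The paths and the dataset class path are
# hypothetical, and the exact data_config schema and load_object contract
# depend on the project (see Nested.utils.helpers):
#
#     datasets, vocabs = parse_conll_files(("train.txt", "dev.txt", "test.txt"))
#     data_config = {
#         "data_config": {
#             "fn": "Nested.data.datasets.DefaultDataset",  # hypothetical class
#             "kwargs": {},
#         }
#     }
#     train_dl, dev_dl, test_dl = get_dataloaders(
#         datasets, vocabs, data_config, batch_size=32
#     )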