aaljabari commited on
Commit
ef877e8
·
verified ·
1 Parent(s): a80b8f3

Create data.py

Browse files
Files changed (1) hide show
  1. Nested/utils/data.py +137 -0
Nested/utils/data.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import DataLoader
2
+ from collections import Counter, namedtuple
3
+ import logging
4
+ import re
5
+ import itertools
6
+ from Nested.utils.helpers import load_object
7
+ from Nested.data.datasets import Token
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ class Vocab:
13
+ def __init__(self, counter, specials=[]) -> None:
14
+ self.itos = list(counter.keys()) + specials
15
+ self.stoi = {s: i for i, s in enumerate(self.itos)}
16
+ self.word_count = counter
17
+
18
+ def get_itos(self) -> list[str]:
19
+ return self.itos
20
+
21
+ def get_stoi(self) -> dict[str, int]:
22
+ return self.stoi
23
+
24
+ def __len__(self):
25
+ return len(self.itos)
26
+
27
+
28
+ def conll_to_segments(filename):
29
+ """
30
+ Convert CoNLL files to segments. This return list of segments and each segment is
31
+ a list of tuples (token, tag)
32
+ :param filename: Path
33
+ :return: list[[tuple]] - [[(token, tag), (token, tag), ...], [(token, tag), ...]]
34
+ """
35
+ segments, segment = list(), list()
36
+
37
+ with open(filename, "r") as fh:
38
+ for token in fh.read().splitlines():
39
+ if not token.strip():
40
+ segments.append(segment)
41
+ segment = list()
42
+ else:
43
+ parts = token.split()
44
+ token = Token(text=parts[0], gold_tag=parts[1:])
45
+ segment.append(token)
46
+
47
+ segments.append(segment)
48
+
49
+ return segments
50
+
51
+
52
+ def parse_conll_files(data_paths):
53
+ """
54
+ Parse CoNLL formatted files and return list of segments for each file and index
55
+ the vocabs and tags across all data_paths
56
+ :param data_paths: tuple(Path) - tuple of filenames
57
+ :return: tuple( [[(token, tag), ...], [(token, tag), ...]], -> segments for data_paths[i]
58
+ [[(token, tag), ...], [(token, tag), ...]], -> segments for data_paths[i+1],
59
+ ...
60
+ )
61
+ List of segments for each dataset and each segment has list of (tokens, tags)
62
+ """
63
+ vocabs = namedtuple("Vocab", ["tags", "tokens"])
64
+ datasets, tags, tokens = list(), list(), list()
65
+
66
+ for data_path in data_paths:
67
+ dataset = conll_to_segments(data_path)
68
+ datasets.append(dataset)
69
+ tokens += [token.text for segment in dataset for token in segment]
70
+ tags += [token.gold_tag for segment in dataset for token in segment]
71
+
72
+ # Flatten list of tags
73
+ tags = list(itertools.chain(*tags))
74
+
75
+ # Generate vocabs for tags and tokens
76
+ tag_vocabs = tag_vocab_by_type(tags)
77
+ tag_vocabs.insert(0, Vocab(Counter(tags)))
78
+ vocabs = vocabs(tokens=Vocab(Counter(tokens), specials=["UNK"]), tags=tag_vocabs)
79
+ return tuple(datasets), vocabs
80
+
81
+
82
+ def tag_vocab_by_type(tags):
83
+ vocabs = list()
84
+ c = Counter(tags)
85
+ tag_names = c.keys()
86
+ tag_types = sorted(list(set([tag.split("-", 1)[1] for tag in tag_names if "-" in tag])))
87
+
88
+ for tag_type in tag_types:
89
+ r = re.compile(".*-" + tag_type + "$")
90
+ t = list(filter(r.match, tags)) + ["O"]
91
+ vocabs.append(Vocab(Counter(t)))
92
+
93
+ return vocabs
94
+
95
+
96
+ def text2segments(text):
97
+ """
98
+ Convert text to a datasets and index the tokens
99
+ """
100
+ dataset = [[Token(text=token, gold_tag=["O"]) for token in text.split()]]
101
+ tokens = [token.text for segment in dataset for token in segment]
102
+
103
+ # Generate vocabs for the tokens
104
+ segment_vocab = Vocab(Counter(tokens), specials=["UNK"])
105
+ return dataset, segment_vocab
106
+
107
+
108
+ def get_dataloaders(
109
+ datasets, vocab, data_config, batch_size=32, num_workers=0, shuffle=(True, False, False)
110
+ ):
111
+ """
112
+ From the datasets generate the dataloaders
113
+ :param datasets: list - list of the datasets, list of list of segments and tokens
114
+ :param batch_size: int
115
+ :param num_workers: int
116
+ :param shuffle: boolean - to shuffle the data or not
117
+ :return: List[torch.utils.data.DataLoader]
118
+ """
119
+ dataloaders = list()
120
+
121
+ data_config = data_config["data_config"]
122
+ for i, examples in enumerate(datasets):
123
+ data_config["kwargs"].update({"examples": examples, "vocab": vocab})
124
+ dataset = load_object(data_config["fn"], data_config["kwargs"])
125
+
126
+ dataloader = DataLoader(
127
+ dataset=dataset,
128
+ shuffle=shuffle[i],
129
+ batch_size=batch_size,
130
+ num_workers=num_workers,
131
+ collate_fn=dataset.collate_fn,
132
+ )
133
+
134
+ logger.info("%s batches found", len(dataloader))
135
+ dataloaders.append(dataloader)
136
+
137
+ return dataloaders