import itertools
import logging
import re
from collections import Counter, namedtuple

from torch.utils.data import DataLoader

from Nested.utils.helpers import load_object
from Nested.data.datasets import Token

logger = logging.getLogger(__name__)


class Vocab:
    """Simple vocabulary that maps tokens/tags to integer ids and back."""

    def __init__(self, counter, specials=None) -> None:
        specials = specials or []
        self.itos = list(counter.keys()) + specials
        self.stoi = {s: i for i, s in enumerate(self.itos)}
        self.word_count = counter

    def get_itos(self) -> list[str]:
        return self.itos

    def get_stoi(self) -> dict[str, int]:
        return self.stoi

    def __len__(self):
        return len(self.itos)
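
# Illustrative sketch of how Vocab is used (the tokens and counts below are made
# up, not taken from any dataset); specials are appended after the counter keys:
#
#     vocab = Vocab(Counter({"the": 10, "cat": 2}), specials=["UNK"])
#     vocab.get_stoi()   # {"the": 0, "cat": 1, "UNK": 2}
#     len(vocab)         # 3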


def conll_to_segments(filename):
    """
    Convert a CoNLL file to segments. Returns a list of segments, where each
    segment is a list of Token objects (one per non-blank line).
    :param filename: Path - path to the CoNLL file
    :return: list[list[Token]] - [[Token, Token, ...], [Token, ...], ...]
    """
    segments, segment = list(), list()

    with open(filename, "r") as fh:
        for line in fh.read().splitlines():
            if not line.strip():
                segments.append(segment)
                segment = list()
            else:
                parts = line.split()
                token = Token(text=parts[0], gold_tag=parts[1:])
                segment.append(token)

        segments.append(segment)

    return segments
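
# A minimal sketch of the expected input/output (the file content, path, and tags
# are hypothetical). Given a whitespace-separated CoNLL file such as:
#
#     Obama B-PERS
#     visited O
#
#     Paris B-LOC
#
# conll_to_segments("train.txt") would return two segments, e.g.
# segments[0][0].text == "Obama" and segments[0][0].gold_tag == ["B-PERS"].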


def parse_conll_files(data_paths):
    """
    Parse CoNLL formatted files, returning the list of segments for each file and
    the token and tag vocabularies indexed across all data_paths.
    :param data_paths: tuple(Path) - tuple of filenames
    :return: tuple(datasets, vocabs) where
             datasets - tuple with one dataset per path; each dataset is a list of
                        segments and each segment is a list of Token objects
             vocabs - namedtuple with the "tokens" and "tags" vocabularies
    """
    vocabs = namedtuple("Vocab", ["tags", "tokens"])
    datasets, tags, tokens = list(), list(), list()

    for data_path in data_paths:
        dataset = conll_to_segments(data_path)
        datasets.append(dataset)
        tokens += [token.text for segment in dataset for token in segment]
        tags += [token.gold_tag for segment in dataset for token in segment]

    # Flatten list of tags
    tags = list(itertools.chain(*tags))

    # Generate vocabs for tags and tokens
    tag_vocabs = tag_vocab_by_type(tags)
    tag_vocabs.insert(0, Vocab(Counter(tags)))
    vocabs = vocabs(tokens=Vocab(Counter(tokens), specials=["UNK"]), tags=tag_vocabs)
    return tuple(datasets), vocabs
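
# A usage sketch (the file names are hypothetical):
#
#     (train, val), vocabs = parse_conll_files(("train.txt", "val.txt"))
#     vocabs.tags[0]    # Vocab over all tags; vocabs.tags[1:] are per-type vocabs
#     vocabs.tokens     # Vocab over all tokens, with the "UNK" special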


def tag_vocab_by_type(tags):
    """
    Build a separate Vocab for each entity type, e.g. all "B-PER"/"I-PER" tags
    plus the "O" tag form the vocabulary of the PER type.
    :param tags: list[str] - flat list of tags
    :return: list[Vocab] - one Vocab per entity type, sorted by type name
    """
    vocabs = list()
    c = Counter(tags)
    tag_names = c.keys()
    tag_types = sorted(set(tag.split("-", 1)[1] for tag in tag_names if "-" in tag))

    for tag_type in tag_types:
        r = re.compile(".*-" + re.escape(tag_type) + "$")
        t = list(filter(r.match, tags)) + ["O"]
        vocabs.append(Vocab(Counter(t)))

    return vocabs
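
# For example (illustrative tags), tag_vocab_by_type(["B-PERS", "I-PERS", "B-LOC", "O"])
# would return two Vocab objects: one over {"B-LOC", "O"} and one over
# {"B-PERS", "I-PERS", "O"}, ordered by entity type name (LOC before PERS).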


def text2segments(text):
    """
    Convert raw text into a dataset with a single segment and index its tokens.
    :param text: str - whitespace tokenized text
    :return: tuple(dataset, Vocab) - the dataset and the token vocabulary
    """
    dataset = [[Token(text=token, gold_tag=["O"]) for token in text.split()]]
    tokens = [token.text for segment in dataset for token in segment]

    # Generate vocabs for the tokens
    segment_vocab = Vocab(Counter(tokens), specials=["UNK"])
    return dataset, segment_vocab
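
# A usage sketch (the input sentence is arbitrary):
#
#     dataset, vocab = text2segments("Obama visited Paris")
#     dataset[0][1].text       # "visited", with gold_tag ["O"]
#     vocab.get_itos()         # ["Obama", "visited", "Paris", "UNK"]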


def get_dataloaders(
    datasets, vocab, data_config, batch_size=32, num_workers=0, shuffle=(True, False, False)
):
    """
    Generate a DataLoader for each of the given datasets
    :param datasets: list - list of datasets; each dataset is a list of segments of Token objects
    :param vocab: namedtuple - the "tokens" and "tags" vocabularies
    :param data_config: dict - config holding the dataset class ("fn") and its "kwargs"
    :param batch_size: int
    :param num_workers: int
    :param shuffle: tuple(bool) - whether to shuffle the i-th dataset
    :return: list[torch.utils.data.DataLoader]
    """
    dataloaders = list()
    data_config = data_config["data_config"]

    for i, examples in enumerate(datasets):
        data_config["kwargs"].update({"examples": examples, "vocab": vocab})
        dataset = load_object(data_config["fn"], data_config["kwargs"])

        dataloader = DataLoader(
            dataset=dataset,
            shuffle=shuffle[i],
            batch_size=batch_size,
            num_workers=num_workers,
            collate_fn=dataset.collate_fn,
        )

        logger.info("%s batches found", len(dataloader))
        dataloaders.append(dataloader)

    return dataloaders
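
# A minimal sketch of the expected config shape (the dataset class path and its
# kwargs below are hypothetical; "fn" must resolve to a class that load_object can
# instantiate and whose instances expose a collate_fn):
#
#     config = {
#         "data_config": {
#             "fn": "Nested.data.datasets.DefaultDataset",
#             "kwargs": {"max_seq_len": 512},
#         }
#     }
#     train_dl, val_dl = get_dataloaders((train, val), vocabs, config,
#                                        batch_size=16, shuffle=(True, False))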