from torch.utils.data import DataLoader
from collections import Counter, namedtuple
import logging
import re
import itertools
from Nested.utils.helpers import load_object
from Nested.data.datasets import Token

logger = logging.getLogger(__name__)

class Vocab:
    """Minimal vocabulary: maps tokens to indices (stoi) and back (itos)."""

    def __init__(self, counter, specials=None) -> None:
        # Avoid a mutable default argument; specials (e.g. "UNK") are appended
        # after the regular tokens, so their indices come last.
        specials = list(specials) if specials else []
        self.itos = list(counter.keys()) + specials
        self.stoi = {s: i for i, s in enumerate(self.itos)}
        self.word_count = counter

    def get_itos(self) -> list[str]:
        return self.itos

    def get_stoi(self) -> dict[str, int]:
        return self.stoi

    def __len__(self):
        return len(self.itos)
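
# A minimal usage sketch of Vocab; the tokens below are illustrative only:
#
#   >>> v = Vocab(Counter(["the", "cat", "the"]), specials=["UNK"])
#   >>> v.get_stoi()["UNK"]
#   2
#   >>> len(v)
#   3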

def conll_to_segments(filename):
    """
    Convert a CoNLL file to segments. Returns a list of segments, where each
    segment is a list of Token objects carrying the token text and its tags.
    :param filename: Path
    :return: list[list[Token]] - [[Token, Token, ...], [Token, ...]]
    """
    segments, segment = list(), list()
    with open(filename, "r") as fh:
        for line in fh.read().splitlines():
            if not line.strip():
                # A blank line ends the current segment; skip empty segments
                # produced by consecutive blank lines.
                if segment:
                    segments.append(segment)
                segment = list()
            else:
                parts = line.split()
                token = Token(text=parts[0], gold_tag=parts[1:])
                segment.append(token)
    # Keep a trailing segment when the file does not end with a blank line
    if segment:
        segments.append(segment)
    return segments
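
# Sketch of the expected round trip; the two-column layout below is an
# assumption (real files may carry several tag columns, hence parts[1:]):
#
#   file contents:            conll_to_segments(path) returns:
#     Alice B-PER               [[Token(text="Alice", gold_tag=["B-PER"]),
#     runs  O                     Token(text="runs",  gold_tag=["O"])]]
#   (a blank line between sentences starts a new segment)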

def parse_conll_files(data_paths):
    """
    Parse CoNLL-formatted files, return the list of segments for each file,
    and index the vocabs and tags across all data_paths.
    :param data_paths: tuple(Path) - tuple of filenames
    :return: tuple(datasets, vocabs) where datasets[i] holds the segments of
             data_paths[i] (each segment is a list of Tokens) and vocabs is a
             namedtuple with a token Vocab and a list of tag Vocabs
    """
    Vocabs = namedtuple("Vocab", ["tags", "tokens"])
    datasets, tags, tokens = list(), list(), list()
    for data_path in data_paths:
        dataset = conll_to_segments(data_path)
        datasets.append(dataset)
        tokens += [token.text for segment in dataset for token in segment]
        tags += [token.gold_tag for segment in dataset for token in segment]
    # Flatten the list of per-token tag lists
    tags = list(itertools.chain(*tags))
    # Generate vocabs for tags and tokens; the combined tag vocab comes first,
    # followed by one vocab per tag type
    tag_vocabs = tag_vocab_by_type(tags)
    tag_vocabs.insert(0, Vocab(Counter(tags)))
    vocabs = Vocabs(tokens=Vocab(Counter(tokens), specials=["UNK"]), tags=tag_vocabs)
    return tuple(datasets), vocabs
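
# Hedged usage sketch; the file names are hypothetical:
#
#   >>> (train, dev), vocabs = parse_conll_files(("train.txt", "dev.txt"))
#   >>> vocabs.tokens.get_itos()[-1]   # "UNK" is appended after real tokens
#   'UNK'
#   >>> vocabs.tags[0]   # combined tag vocab; per-type vocabs follow it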

def tag_vocab_by_type(tags):
    """
    Build one Vocab per tag type. A type is the part after the BIO prefix
    (e.g. "PER" in "B-PER"); each per-type vocab holds that type's tags
    plus the outside tag "O".
    """
    vocabs = list()
    c = Counter(tags)
    tag_names = c.keys()
    tag_types = sorted(set(tag.split("-", 1)[1] for tag in tag_names if "-" in tag))
    for tag_type in tag_types:
        # re.escape guards against tag types containing regex metacharacters
        r = re.compile(".*-" + re.escape(tag_type) + "$")
        t = list(filter(r.match, tags)) + ["O"]
        vocabs.append(Vocab(Counter(t)))
    return vocabs
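
# Illustrative walk-through (the tags are made up): for
# tags = ["B-PER", "I-PER", "B-LOC", "O"], tag_types is ["LOC", "PER"], so two
# vocabs are built: one over ["B-LOC", "O"] and one over ["B-PER", "I-PER", "O"].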

def text2segments(text):
    """
    Convert raw text to a single-segment dataset (every token tagged "O")
    and index the tokens.
    """
    dataset = [[Token(text=token, gold_tag=["O"]) for token in text.split()]]
    tokens = [token.text for segment in dataset for token in segment]
    # Generate a vocab for the tokens
    segment_vocab = Vocab(Counter(tokens), specials=["UNK"])
    return dataset, segment_vocab
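
# Usage sketch; the sentence is illustrative:
#
#   >>> dataset, vocab = text2segments("Alice lives in Paris")
#   >>> [t.text for t in dataset[0]]
#   ['Alice', 'lives', 'in', 'Paris']
#   >>> len(vocab)   # four tokens plus "UNK"
#   5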

def get_dataloaders(
    datasets, vocab, data_config, batch_size=32, num_workers=0, shuffle=(True, False, False)
):
    """
    Generate one DataLoader per dataset.
    :param datasets: list - list of datasets, each a list of segments
    :param vocab: vocabs namedtuple produced by parse_conll_files
    :param data_config: dict - holds "data_config" with the dataset class
                        path ("fn") and its keyword arguments ("kwargs")
    :param batch_size: int
    :param num_workers: int
    :param shuffle: tuple(bool) - whether to shuffle the i-th dataset
    :return: list[torch.utils.data.DataLoader]
    """
    dataloaders = list()
    data_config = data_config["data_config"]
    for i, examples in enumerate(datasets):
        data_config["kwargs"].update({"examples": examples, "vocab": vocab})
        dataset = load_object(data_config["fn"], data_config["kwargs"])
        dataloader = DataLoader(
            dataset=dataset,
            shuffle=shuffle[i],
            batch_size=batch_size,
            num_workers=num_workers,
            collate_fn=dataset.collate_fn,
        )
        logger.info("%s batches found", len(dataloader))
        dataloaders.append(dataloader)
    return dataloaders
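
# Hedged usage sketch; the dataset class path in "fn" is hypothetical, and the
# config keys mirror only what this function reads:
#
#   >>> config = {"data_config": {"fn": "Nested.data.datasets.DefaultDataset",
#   ...                           "kwargs": {}}}
#   >>> train_loader, dev_loader = get_dataloaders(
#   ...     (train, dev), vocabs, config, batch_size=16, shuffle=(True, False)
#   ... )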