import itertools
import logging
import re
from collections import Counter, namedtuple

from torch.utils.data import DataLoader

from Nested.data.datasets import Token
from Nested.utils.helpers import load_object

logger = logging.getLogger(__name__)


class Vocab:
    """Maps strings to integer indices and back, built from a Counter."""

    def __init__(self, counter, specials=None) -> None:
        # Avoid a mutable default argument; special symbols (e.g. "UNK")
        # are appended after the regular entries.
        specials = specials if specials is not None else []
        self.itos = list(counter.keys()) + specials
        self.stoi = {s: i for i, s in enumerate(self.itos)}
        self.word_count = counter

    def get_itos(self) -> list[str]:
        return self.itos

    def get_stoi(self) -> dict[str, int]:
        return self.stoi

    def __len__(self):
        return len(self.itos)
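
# Illustrative usage of Vocab (a minimal sketch, not part of the module):
#   vocab = Vocab(Counter(["the", "cat", "the"]), specials=["UNK"])
#   vocab.get_stoi()  # {"the": 0, "cat": 1, "UNK": 2}
#   len(vocab)        # 3
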
def conll_to_segments(filename):
    """
    Convert a CoNLL file to segments. Returns a list of segments, where each
    segment is a list of Token objects (one per non-blank line)
    :param filename: Path
    :return: list[list[Token]] - [[Token, Token, ...], [Token, ...]]
    """
    segments, segment = list(), list()
    with open(filename, "r") as fh:
        for line in fh.read().splitlines():
            if not line.strip():
                # A blank line ends the current segment; guard against empty
                # segments produced by consecutive blank lines
                if segment:
                    segments.append(segment)
                segment = list()
            else:
                # First column is the token text, remaining columns its tags
                parts = line.split()
                segment.append(Token(text=parts[0], gold_tag=parts[1:]))
    if segment:
        segments.append(segment)
    return segments
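
# Illustrative usage (the filename is hypothetical); for a file whose lines
# look like "Paris B-LOC", with blank lines separating sentences:
#   segments = conll_to_segments("train.txt")
#   segments[0][0].text      # "Paris"
#   segments[0][0].gold_tag  # ["B-LOC"]
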
def parse_conll_files(data_paths):
    """
    Parse CoNLL-formatted files, returning the segments of each file along
    with token and tag vocabularies indexed across all data_paths
    :param data_paths: tuple(Path) - tuple of filenames
    :return: tuple(datasets, vocabs) where datasets[i] is the list of
             segments parsed from data_paths[i] (each segment is a list of
             Token objects), and vocabs is a namedtuple holding a `tokens`
             vocabulary and a list of `tags` vocabularies
    """
    # Use a distinct name so the namedtuple class does not shadow the
    # Vocab class above
    Vocabs = namedtuple("Vocabs", ["tags", "tokens"])
datasets, tags, tokens = list(), list(), list()
for data_path in data_paths:
dataset = conll_to_segments(data_path)
datasets.append(dataset)
tokens += [token.text for segment in dataset for token in segment]
tags += [token.gold_tag for segment in dataset for token in segment]
# Flatten list of tags
tags = list(itertools.chain(*tags))
# Generate vocabs for tags and tokens
tag_vocabs = tag_vocab_by_type(tags)
tag_vocabs.insert(0, Vocab(Counter(tags)))
    vocabs = Vocabs(tokens=Vocab(Counter(tokens), specials=["UNK"]), tags=tag_vocabs)
return tuple(datasets), vocabs
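
# Illustrative usage (file paths are hypothetical):
#   (train, dev), vocabs = parse_conll_files(("train.txt", "dev.txt"))
#   vocabs.tokens is a Vocab over all tokens (plus "UNK"); vocabs.tags is a
#   list of Vocabs: the vocabulary over all tags first, then one per tag type
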
def tag_vocab_by_type(tags):
    """
    Build one Vocab per tag type, where the type is the part after the IOB
    prefix (e.g. "PER" in "B-PER"); each vocabulary also contains "O"
    """
    vocabs = list()
    tag_names = Counter(tags).keys()
    tag_types = sorted(set(tag.split("-", 1)[1] for tag in tag_names if "-" in tag))
    for tag_type in tag_types:
        # re.escape guards against tag types containing regex metacharacters
        r = re.compile(".*-" + re.escape(tag_type) + "$")
        t = list(filter(r.match, tags)) + ["O"]
        vocabs.append(Vocab(Counter(t)))
    return vocabs
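
# Illustrative usage (a minimal sketch):
#   tag_vocab_by_type(["B-PER", "I-PER", "B-LOC", "O"])
#   -> two Vocabs, one over {"B-LOC", "O"} and one over {"B-PER", "I-PER", "O"}
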
def text2segments(text):
"""
Convert text to a datasets and index the tokens
"""
dataset = [[Token(text=token, gold_tag=["O"]) for token in text.split()]]
tokens = [token.text for segment in dataset for token in segment]
# Generate vocabs for the tokens
segment_vocab = Vocab(Counter(tokens), specials=["UNK"])
return dataset, segment_vocab
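
# Illustrative usage:
#   dataset, vocab = text2segments("Paris is nice")
#   dataset           # one segment of three Tokens, each tagged ["O"]
#   vocab.get_itos()  # ["Paris", "is", "nice", "UNK"]
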
def get_dataloaders(
datasets, vocab, data_config, batch_size=32, num_workers=0, shuffle=(True, False, False)
):
"""
From the datasets generate the dataloaders
:param datasets: list - list of the datasets, list of list of segments and tokens
:param batch_size: int
:param num_workers: int
:param shuffle: boolean - to shuffle the data or not
:return: List[torch.utils.data.DataLoader]
"""
dataloaders = list()
    # Unwrap the nested section holding the dataset class path and kwargs
    data_config = data_config["data_config"]
for i, examples in enumerate(datasets):
data_config["kwargs"].update({"examples": examples, "vocab": vocab})
dataset = load_object(data_config["fn"], data_config["kwargs"])
dataloader = DataLoader(
dataset=dataset,
shuffle=shuffle[i],
batch_size=batch_size,
num_workers=num_workers,
collate_fn=dataset.collate_fn,
)
logger.info("%s batches found", len(dataloader))
dataloaders.append(dataloader)
return dataloaders
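
# Illustrative usage (a sketch: the dataset class path and config below are
# hypothetical, mirroring the "fn"/"kwargs" structure read above):
#   config = {"data_config": {"fn": "Nested.data.datasets.NestedDataset",
#                             "kwargs": {}}}
#   train_dl, dev_dl = get_dataloaders((train, dev), vocabs, config, batch_size=16)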