import json
import os
from typing import IO, Iterator, Optional

import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from utils.poet_utils import StropheParams, SyllableMaker, TextAnalysis, TextManipulation


class CorpusDatasetPytorch:
    """Dataset class responsible for data loading.

    On construction it either loads the processed train/validation/test splits from
    the JSON cache files in ``cache_dir`` or, when any of those files is missing,
    processes the corpus JSON files found in ``data_dir`` and writes the caches.

    Attributes set by ``__init__``:
        pytorch_dataset_body, val_pytorch_dataset_body, test_pytorch_dataset_body:
            strophe-level ``BodyDataset`` splits.
        pytorch_dataset_text, val_pytorch_dataset_text, test_pytorch_dataset_text:
            verse-level ``TextDataset`` splits.
        raw_dataset: ``RawDataset`` over the unprocessed corpus files.
    """

    # Punctuation that is re-attached after a verse rendered as a syllable sequence.
    _LINE_END_PUNCTUATION = (',', '.', '!', '?')

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _final_punctuation(text: str) -> str:
        """Return the last character of ``text`` if it is end-of-verse punctuation, else ``''``.

        Empty ``text`` yields ``''`` instead of raising ``IndexError``.
        """
        if text and text[-1] in CorpusDatasetPytorch._LINE_END_PUNCTUATION:
            return text[-1]
        return ''

    @staticmethod
    def _open_files(paths) -> Iterator[IO[str]]:
        """Yield every path in ``paths`` opened for reading as UTF-8 text.

        Each file lives inside a ``with`` block, so it is closed as soon as the consumer
        asks for the next one (or the generator is closed) instead of lingering until
        garbage collection. UTF-8 is requested explicitly because JSON is UTF-8 by
        specification, while the platform default encoding (a legacy code page on
        Windows) would mis-decode non-ASCII text.
        """
        for path in paths:
            with open(path, 'r', encoding='utf-8') as file:
                yield file

    @staticmethod
    def _parse_json_files(files) -> Iterator[object]:
        """Yield ``json.load`` of every opened file, printing progress every 500 files."""
        for step, file in enumerate(files):
            if step % 500 == 0:
                print(f"Processing file {step}")
            yield json.load(file)

    @staticmethod
    def _route_sample(dataset, sample: dict) -> None:
        """Append ``sample`` to the training, test or validation list of ``dataset`` at random.

        With ``r = np.random.rand()``: ``r > val_data_rate + test_data_rate`` goes to
        ``dataset.data``, ``r < test_data_rate`` to ``dataset.test_data`` and anything
        in between to ``dataset.validation_data``.
        """
        rand_split = np.random.rand()
        if rand_split > dataset.val_data_rate + dataset.test_data_rate:
            dataset.data.append(sample)
        elif rand_split < dataset.test_data_rate:
            dataset.test_data.append(sample)
        else:
            dataset.validation_data.append(sample)

    @staticmethod
    def _prompted_line(raw_text: str, metre: str, prompt_length: bool, prompt_ending: bool,
                       as_syllables: bool) -> str:
        """Build ``"<metre> # [<syllable count> # ][<last syllable> # ]<verse>"``.

        ``<verse>`` is ``raw_text`` itself, or (``as_syllables``) its syllables joined by
        spaces followed by the verse's final punctuation.
        """
        syllables = SyllableMaker.syllabify(raw_text)
        num_str = f"{len(syllables)} # " if prompt_length else ""
        verse_end = f"{syllables[-1]} # " if prompt_ending else ""
        if as_syllables:
            verse = " ".join(syllables) + CorpusDatasetPytorch._final_punctuation(raw_text)
        else:
            verse = raw_text
        return f"{metre} # " + num_str + verse_end + verse

    @staticmethod
    def _verse_text(line: str, to_syllables: bool) -> str:
        """Strip the ``... # `` prompt prefix from a stored verse line.

        Keeps the text after the last ``#``; when ``to_syllables`` is set and the line is
        non-empty, that text is re-syllabified and the line's final punctuation re-attached.
        """
        verse = line.split('#')[-1]
        if to_syllables and line:
            return " ".join(SyllableMaker.syllabify(verse)) + CorpusDatasetPytorch._final_punctuation(line)
        return verse

    @staticmethod
    def _reformat_strophe(text: str, layout: str) -> str:
        """Rewrite a stored strophe sample for the ``'BASIC'`` or ``'VERSE_PAR'`` layout.

        The header (first) line gets `` # <first token of the second line>`` appended,
        i.e. the metre field of the first verse. Every following line keeps only the text
        after its last ``#`` (``'BASIC'``) or everything after its first ``#``
        (``'VERSE_PAR'``).
        """
        lines = text.splitlines()
        out = []
        for i, line in enumerate(lines):
            if i == 0:
                out.append(line + f" # {lines[1].split()[0]}")
            elif layout == "BASIC":
                out.append(line.split('#')[-1])
            else:
                out.append("#".join(line.split('#')[1:]))
        return "\n".join(out)

    @staticmethod
    def _to_float_tensor(rows) -> torch.Tensor:
        """Convert ``rows`` through an int32 numpy array into a float32 tensor."""
        return torch.tensor(np.asarray(rows, dtype=np.int32), dtype=torch.float32)

    @staticmethod
    def _read_json(path: str) -> list:
        """Load a JSON cache file as a list, closing the file afterwards."""
        with open(path, 'r', encoding='utf-8') as file:
            return list(json.load(file))

    @staticmethod
    def _write_json(obj, path: str) -> None:
        """Write ``obj`` to a JSON cache file and close (and thereby flush) it immediately."""
        with open(path, 'w', encoding='utf-8') as file:
            json.dump(obj, file, indent=6)

    # ----------------------------------------------------------- nested datasets

    class RawDataset:
        """Dataset distributing raw string data with no preprocessing."""

        def __init__(self, data_file_paths, lower_case: bool = True):
            """Construct the frame around raw data generation.

            Args:
                data_file_paths (list[str]): list of paths to data files
                lower_case (bool, optional): if resulting data should be in lowercase. Defaults to True.
            """
            self._data_file_paths = data_file_paths
            self.lower_case = lower_case

        def gen_files(self) -> Iterator[IO[str]]:
            """Get individual opened files.

            Yields:
                IO[str]: open UTF-8 text file; it is closed when the next one is requested.
            """
            yield from CorpusDatasetPytorch._open_files(self._data_file_paths)

        def _apply_case(self, text: str) -> str:
            """Lowercase ``text`` when ``self.lower_case`` is set."""
            return text.lower() if self.lower_case else text

        def get_text(self) -> Iterator[str]:
            """Get lines of text of poetry.

            Yields:
                str: individual verse line
            """
            for datum in CorpusDatasetPytorch._parse_json_files(self.gen_files()):
                for data_line in datum:
                    for part_line in data_line['body']:
                        for text_line in part_line:
                            yield self._apply_case(text_line['text'])

        def get_part(self) -> Iterator[str]:
            """Get strophe of poetry.

            Yields:
                str: 1 strophe of poetry, verses separated by newlines
            """
            for datum in CorpusDatasetPytorch._parse_json_files(self.gen_files()):
                for data_line in datum:
                    for part_line in data_line['body']:
                        yield self._apply_case("\n".join(text_line['text'] for text_line in part_line))

        def get_body(self) -> Iterator[str]:
            """Get whole poem.

            Yields:
                str: 1 whole poem; verses are newline-separated and an extra ``"\\n"``
                entry after each strophe produces a blank gap between strophes.
            """
            for datum in CorpusDatasetPytorch._parse_json_files(self.gen_files()):
                for data_line in datum:
                    body = []
                    for part_line in data_line['body']:
                        body.extend(text_line['text'] for text_line in part_line)
                        body.append("\n")
                    yield self._apply_case("\n".join(body))

    class TextDataset(Dataset):
        """Dataset of preprocessed verse lines.

        Subclasses ``torch.utils.data.Dataset`` for integration with torch and huggingface.
        Indexing and ``len`` operate on the training split (``self.data``).
        """

        def __init__(self, data_file_paths, prompt_length=True, prompt_ending=True, lower_case=True,
                     val_data_rate: float = 0.05, test_data_rate: float = 0.05):
            """Construct the class from the given data file paths and store settings.

            Args:
                data_file_paths (list[str]): list of paths to data files
                prompt_length (bool, optional): If to prompt the syllable count. Defaults to True.
                prompt_ending (bool, optional): If to prompt verse ending. Defaults to True.
                lower_case (bool, optional): If the string should be in lowercase. Defaults to True.
                val_data_rate (float, optional): Fraction of data left for validation. Defaults to 0.05.
                test_data_rate (float, optional): Fraction of data left for testing. Defaults to 0.05.
            """
            self._data_file_paths = data_file_paths
            self.prompt_length = prompt_length
            self.prompt_ending = prompt_ending
            self.lower_case = lower_case

            self.val_data_rate = val_data_rate
            self.test_data_rate = test_data_rate

            self.data = []
            self.validation_data = []
            self.test_data = []

        def gen_files(self) -> Iterator[IO[str]]:
            """Get individual opened files.

            Yields:
                IO[str]: open UTF-8 text file; it is closed when the next one is requested.
            """
            yield from CorpusDatasetPytorch._open_files(self._data_file_paths)

        @staticmethod
        def _vowels_and_endings(raw_text):
            """Get the number of syllables in a verse and its ending syllable.

            Args:
                raw_text (str): raw verse to analyze

            Returns:
                tuple: number of syllables, ending syllable
            """
            syllabs = SyllableMaker.syllabify(raw_text)
            return len(syllabs), syllabs[-1]

        @staticmethod
        def _ending_vector(end):
            """Construct a one-hot encoded vector for the ending syllable.

            Args:
                end (str): Ending syllable

            Returns:
                numpy.ndarray: One-hot vector over ``StropheParams.ENDS``; the last slot is
                used for every ending not found among the earlier entries.
            """
            verse_end_vector = np.zeros(len(StropheParams.ENDS))
            if end in StropheParams.ENDS[:-1]:
                verse_end_vector[StropheParams.ENDS.index(end)] = 1
            else:
                verse_end_vector[-1] = 1
            return verse_end_vector

        @staticmethod
        def _syllable_line(raw_text):
            """Construct a verse as a sequence of syllables.

            Args:
                raw_text (str): raw verse line

            Returns:
                str: syllables joined by spaces plus the verse's final ``,.!?`` if present
            """
            return " ".join(SyllableMaker.syllabify(raw_text)) + CorpusDatasetPytorch._final_punctuation(raw_text)

        def _construct_line(self, raw_text, metre):
            """Construct an individual content line.

            Args:
                raw_text (str): raw verse line
                metre (str): metre label prepended to the line

            Returns:
                str: ``"<metre> # [<count> # ][<last syllable> # ]<raw_text>"``
            """
            return CorpusDatasetPytorch._prompted_line(
                raw_text, metre, self.prompt_length, self.prompt_ending, as_syllables=False)

        def _introduce_phonetics(self, raw_text: str, phonetics) -> str:
            """Replace each word of ``phonetics['words']`` in ``raw_text`` by its ``"phoebe"`` form.

            The lookup key is ``"token_lc"`` when lowercasing, otherwise ``"token"``.
            """
            key = "token_lc" if self.lower_case else "token"
            phonetic_text = raw_text
            for word in phonetics['words']:
                phonetic_text = phonetic_text.replace(str(word[key]), str(word["phoebe"]))
            return phonetic_text

        def _construct_syllable_line(self, raw_text, metre):
            """Construct an individual content line as a sequence of syllables.

            Args:
                raw_text (str): raw verse line
                metre (str): metre label prepended to the line

            Returns:
                str: ``"<metre> # [<count> # ][<last syllable> # ]<syllables><punctuation>"``
            """
            return CorpusDatasetPytorch._prompted_line(
                raw_text, metre, self.prompt_length, self.prompt_ending, as_syllables=True)

        def data_text_line_gen(self):
            """Read the corpus and turn every verse into a sample of one of the splits."""
            for datum in CorpusDatasetPytorch._parse_json_files(self.gen_files()):
                for data_line in datum:
                    for part_line in data_line['body']:
                        for text_line in part_line:
                            metre = StropheParams.METER_TRANSLATE.get(text_line["metre"][0]["type"], "N")
                            scanned_text = TextManipulation._remove_most_nonchar(text_line['text'], self.lower_case)

                            text_line_scanned = self._construct_line(scanned_text, metre)
                            syllable_line = self._construct_syllable_line(scanned_text, metre)
                            num_vowels, verse_end = self._vowels_and_endings(scanned_text)

                            CorpusDatasetPytorch._route_sample(self, {
                                "input_ids": [text_line_scanned, syllable_line],
                                "nums": [num_vowels],
                                "verse_end": verse_end,
                                "metre": metre,
                            })

        def __len__(self):
            """Return the length of the training data."""
            return len(self.data)

        def __getitem__(self, index):
            """Return the training sample at ``index``."""
            return self.data[index]

    class BodyDataset(Dataset):
        """Dataset of preprocessed strophes.

        Subclasses ``torch.utils.data.Dataset`` for integration with torch and huggingface.
        Indexing and ``len`` operate on the training split (``self.data``).
        """

        def __init__(self, data_file_paths,
                     prompt_length=True, prompt_ending=True, prompt_verse=True, verse_len=(4, 6),
                     lower_case=True, val_data_rate: float = 0.05, test_data_rate: float = 0.05):
            """Construct the class from the given data file paths and store settings.

            Args:
                data_file_paths (list[str]): list of paths to data files
                prompt_length (bool, optional): If to prompt the syllable count. Defaults to True.
                prompt_ending (bool, optional): If to prompt verse ending. Defaults to True.
                prompt_verse (bool, optional): If to prompt rhyme schema (stored; not read
                    by this class). Defaults to True.
                verse_len (Sequence[int], optional): Strophe lengths to emit samples for. Defaults to (4, 6).
                lower_case (bool, optional): If the string should be in lowercase. Defaults to True.
                val_data_rate (float, optional): Fraction of data left for validation. Defaults to 0.05.
                test_data_rate (float, optional): Fraction of data left for testing. Defaults to 0.05.
            """
            self._data_file_paths = data_file_paths
            self.prompt_length = prompt_length
            self.prompt_ending = prompt_ending
            self.prompt_verse = prompt_verse
            self.verse_len = verse_len
            self.lower_case = lower_case

            self.val_data_rate = val_data_rate
            self.test_data_rate = test_data_rate

            self.data = []
            self.validation_data = []
            self.test_data = []

        def gen_files(self) -> Iterator[IO[str]]:
            """Get individual opened files.

            Yields:
                IO[str]: open UTF-8 text file; it is closed when the next one is requested.
            """
            yield from CorpusDatasetPytorch._open_files(self._data_file_paths)

        def _construct_line(self, raw_text, metre):
            """Construct an individual content line.

            Returns:
                str: ``"<metre> # [<count> # ][<last syllable> # ]<raw_text>"``
            """
            return CorpusDatasetPytorch._prompted_line(
                raw_text, metre, self.prompt_length, self.prompt_ending, as_syllables=False)

        def _construct_syllable_line(self, raw_text, metre):
            """Construct an individual content line as a sequence of syllables.

            Returns:
                str: ``"<metre> # [<count> # ][<last syllable> # ]<syllables><punctuation>"``
            """
            return CorpusDatasetPytorch._prompted_line(
                raw_text, metre, self.prompt_length, self.prompt_ending, as_syllables=True)

        def data_body_gen(self):
            """Read the corpus and split it into strophe samples.

            Verses of each strophe are accumulated; whenever the number of accumulated
            verses is in ``verse_len`` a sample is emitted (so with ``(4, 6)`` a six-verse
            run yields both a four-verse and a six-verse sample) and after
            ``max(verse_len)`` verses accumulation restarts. Verses that never reach a
            length in ``verse_len`` are not emitted. Every sample is routed at random to
            the training, validation or test split.
            """
            for datum in CorpusDatasetPytorch._parse_json_files(self.gen_files()):
                for data_line in datum:
                    year = data_line["biblio"]["year"]
                    publish_year_text = TextManipulation._year_bucketor(year)
                    publish_year_true = year if TextAnalysis._is_year(year) else 'NaN'
                    # No surrounding context is extracted; every sample gets this placeholder.
                    context_text = "NO CONTEXT"

                    for part_line in data_line['body']:
                        body, body_syllabs, rhyme, metres = [], [], [], []
                        for text_line in part_line:
                            metre = StropheParams.METER_TRANSLATE.get(text_line["metre"][0]["type"], "J")
                            metres.append(metre)
                            rhyme.append(text_line["rhyme"])

                            scanned_text = TextManipulation._remove_most_nonchar(text_line["text"], self.lower_case)
                            body.append(self._construct_line(scanned_text, metre))
                            body_syllabs.append(self._construct_syllable_line(scanned_text, metre))

                            if len(body) in self.verse_len:
                                rhyme_str = TextManipulation._rhyme_string(rhyme)
                                header = f"# {rhyme_str} # {publish_year_text}\n"
                                CorpusDatasetPytorch._route_sample(self, {
                                    "input_ids": [header + "\n".join(body) + "\n",
                                                  header + "\n".join(body_syllabs) + "\n"],
                                    "context_ids": context_text,
                                    "year": publish_year_true,
                                    "rhyme": rhyme_str,
                                    # copy: ``metres`` keeps growing toward max(verse_len)
                                    "metre_ids": metres.copy(),
                                })

                            if len(body) == max(self.verse_len):
                                body, body_syllabs, rhyme, metres = [], [], [], []

        def __len__(self):
            """Return the length of the training data."""
            return len(self.data)

        def __getitem__(self, index):
            """Return the training sample at ``index``."""
            return self.data[index]

    # ---------------------------------------------------------------- loading

    def get_filenames(self):
        """Get paths of the data files.

        Returns:
            list[str]: Paths of the regular files in ``self.data_dir``, sorted by name so
            the processing order does not depend on the platform's ``listdir`` order.
            Subdirectories are skipped instead of being passed to ``open``.
        """
        paths = (os.path.join(self.data_dir, name) for name in sorted(os.listdir(self.data_dir)))
        return [path for path in paths if os.path.isfile(path)]

    def load_raw_(self):
        """Load the raw dataset with unprocessed string data."""
        filenames = self.get_filenames()
        self.raw_dataset = CorpusDatasetPytorch.RawDataset(filenames, self.lower_case)

    def load_json_filenames(self, prompt_length, prompt_ending, prompt_verse, verse_len=(4, 6),
                            val_data_rate=0.05, test_data_rate=0.05):
        """Load verse and strophe datasets and separate their validation/test splits.

        Args:
            prompt_length (bool): If to prompt the syllable count.
            prompt_ending (bool): If to prompt verse ending.
            prompt_verse (bool): If to prompt rhyme schema.
            verse_len (Sequence[int], optional): Considered strophe lengths. Defaults to (4, 6).
            val_data_rate (float, optional): Fraction of data left for validation. Defaults to 0.05.
            test_data_rate (float, optional): Fraction of data left for testing. Defaults to 0.05.
        """
        filenames = self.get_filenames()

        self.pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset(
            filenames, prompt_ending=prompt_ending, prompt_length=prompt_length,
            prompt_verse=prompt_verse, verse_len=verse_len, lower_case=self.lower_case,
            val_data_rate=val_data_rate, test_data_rate=test_data_rate)
        self.pytorch_dataset_body.data_body_gen()

        self.pytorch_dataset_text = CorpusDatasetPytorch.TextDataset(
            filenames, prompt_ending=prompt_ending, prompt_length=prompt_length,
            lower_case=self.lower_case, val_data_rate=val_data_rate, test_data_rate=test_data_rate)
        self.pytorch_dataset_text.data_text_line_gen()

        # Move the held-out samples into their own dataset objects.
        self.val_pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset([])
        self.val_pytorch_dataset_text = CorpusDatasetPytorch.TextDataset([])
        self.val_pytorch_dataset_body.data = self.pytorch_dataset_body.validation_data
        self.val_pytorch_dataset_text.data = self.pytorch_dataset_text.validation_data
        self.pytorch_dataset_text.validation_data = []
        self.pytorch_dataset_body.validation_data = []

        self.test_pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset([])
        self.test_pytorch_dataset_text = CorpusDatasetPytorch.TextDataset([])
        self.test_pytorch_dataset_body.data = self.pytorch_dataset_body.test_data
        self.test_pytorch_dataset_text.data = self.pytorch_dataset_text.test_data
        self.pytorch_dataset_text.test_data = []
        self.pytorch_dataset_body.test_data = []

    def create_empty(self):
        """Create empty holders for a possible load of processed data from file."""
        self.pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset([])
        self.pytorch_dataset_text = CorpusDatasetPytorch.TextDataset([])
        self.val_pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset([])
        self.val_pytorch_dataset_text = CorpusDatasetPytorch.TextDataset([])
        self.test_pytorch_dataset_body = CorpusDatasetPytorch.BodyDataset([])
        self.test_pytorch_dataset_text = CorpusDatasetPytorch.TextDataset([])

    # --------------------------------------------------------------- collation

    @staticmethod
    def collate(batch, tokenizer: PreTrainedTokenizerBase, max_len=1024, max_context=1024, mask_rate=0.0,
                syllables: bool = False, format: str = 'METER_VERSE'):
        """Process data for usage in LM.

        Args:
            batch (list[dict]): Samples produced by ``TextDataset`` or ``BodyDataset``.
            tokenizer (PreTrainedTokenizerBase): tokenizer to tokenize input text. Its
                ``model_max_length`` is overwritten with ``max_len`` (and with
                ``max_context`` when the batch carries context).
            max_len (int, optional): Maximum length of tokenization. Defaults to 1024.
            max_context (int, optional): Maximum length of tokenization of context. Defaults to 1024.
            mask_rate (float, optional): Currently unused. Defaults to 0.0.
            syllables (bool, optional): If to use the syllable-sequence text. Defaults to False.
            format (str, optional): Layout of strophe samples. ``'BASIC'`` keeps only the
                verse text after the last ``#``, ``'VERSE_PAR'`` drops only the leading
                metre field; both append the first verse's metre to the header line. Any
                other value (default ``'METER_VERSE'``) keeps lines as stored. Ignored for
                verse-level samples.

        Returns:
            dict: tokenized and processed tensors; entries whose key is absent from the
            samples are ``None``.
        """
        index = 1 if syllables else 0
        tokenizer.model_max_length = max_len

        texts = [sample['input_ids'][index] for sample in batch]
        # Strophe samples (BodyDataset) start with the "# <rhyme> # <year>" header line.
        if batch[0]['input_ids'][0].startswith("#") and format in ("BASIC", "VERSE_PAR"):
            texts = [CorpusDatasetPytorch._reformat_strophe(text, format) for text in texts]
        tokenized = tokenizer([text + tokenizer.eos_token for text in texts],
                              return_tensors='pt', truncation=True, padding=True)
        input_ids = tokenized['input_ids']
        attention = tokenized["attention_mask"]

        first = batch[0]
        to_tensor = CorpusDatasetPytorch._to_float_tensor

        nums = None
        if "nums" in first:
            nums = to_tensor([sample['nums'] for sample in batch])

        rhyme = None
        if "rhyme" in first:
            rhyme = to_tensor([TextAnalysis._rhyme_vector(sample["rhyme"]) for sample in batch])

        verse_end = None
        if "verse_end" in first:
            verse_end = to_tensor([CorpusDatasetPytorch.TextDataset._ending_vector(sample["verse_end"])
                                   for sample in batch])

        year = None
        if "year" in first:
            year = to_tensor([TextAnalysis._publish_year_vector(sample["year"]) for sample in batch])

        metre = None
        if "metre" in first:
            metre = to_tensor([TextAnalysis._metre_vector(sample["metre"]) for sample in batch])

        context_ids = None
        context_attention_mask = None
        if "context_ids" in first:
            tokenizer.model_max_length = max_context
            tokenized_context = tokenizer([sample['context_ids'] + tokenizer.eos_token for sample in batch],
                                          return_tensors='pt', truncation=True, padding=True)
            context_ids = tokenized_context['input_ids']
            context_attention_mask = tokenized_context['attention_mask']

        return {
            "input_ids": input_ids,
            "labels": input_ids.type(torch.LongTensor),
            "attention_mask": attention,
            "context_ids": context_ids,
            "context_attention_mask": context_attention_mask,
            "nums": nums,
            "rhyme": rhyme,
            "verse_end": verse_end,
            "year": year,
            "metre": metre}

    @staticmethod
    def collate_distil(batch, tokenizer: PreTrainedTokenizerBase, surrogate_model: Optional[PreTrainedModel] = None,
                       surrogate_model_device=None, max_len=1024):
        """Tokenize verse text and attach a surrogate (teacher) model's hidden states.

        Args:
            batch (list[dict]): Samples whose ``input_ids[0]`` is the verse text.
            tokenizer (PreTrainedTokenizerBase): tokenizer; its ``model_max_length`` is set to ``max_len``.
            surrogate_model (PreTrainedModel): model run without gradients; its output is
                indexed with ``'hidden_states'`` (presumably it must be configured to return
                hidden states — confirm against the caller).
            surrogate_model_device: device the inputs are moved to for the surrogate model.
            max_len (int, optional): Maximum length of tokenization. Defaults to 1024.

        Returns:
            dict: ``input_ids``, ``labels``, ``attention_mask`` and ``to_replicate_states``
            (list of hidden-state tensors moved to CPU).

        Raises:
            TypeError: if ``surrogate_model`` is not given.
        """
        if surrogate_model is None:
            raise TypeError("collate_distil requires a surrogate_model")
        tokenizer.model_max_length = max_len
        tokenized = tokenizer([sample['input_ids'][0] + tokenizer.eos_token for sample in batch],
                              return_tensors='pt', truncation=True, padding=True)
        input_ids = tokenized['input_ids']
        attention = tokenized["attention_mask"]

        with torch.no_grad():
            outputs = surrogate_model(input_ids=input_ids.to(surrogate_model_device),
                                      attention_mask=attention.to(surrogate_model_device),
                                      labels=input_ids.type(torch.LongTensor).to(surrogate_model_device))
            model_hidden_states = [hidden.cpu().detach() for hidden in outputs['hidden_states']]

        return {
            "input_ids": input_ids,
            "labels": input_ids.type(torch.LongTensor),
            "attention_mask": attention,
            "to_replicate_states": model_hidden_states
        }

    @staticmethod
    def collate_validator(batch, tokenizer: PreTrainedTokenizerBase, syllables: bool, is_syllable: bool = False,
                          max_len=512):
        """Process strophe data for use in an LM predicting rhyme and year.

        Args:
            batch (list[dict]): Batch with selected strophe samples.
            tokenizer (PreTrainedTokenizerBase): tokenizer; its ``model_max_length`` is set to ``max_len``.
            syllables (bool): If to use a sequence of syllables as input text.
            is_syllable (bool, optional): Signal that the stored inputs already contain
                syllable data (then the stored syllable text is used as-is). Defaults to False.
            max_len (int, optional): Maximum length of tokenization. Defaults to 512.

        Returns:
            dict: tokenized and processed tensors; ``metre_ids`` is always ``None``.
        """
        index = 1 if syllables and is_syllable else 0
        tokenizer.model_max_length = max_len
        to_syllables = syllables and not is_syllable
        # Drop the header line and the per-verse prompt prefixes.
        data_ids = [
            "\n".join(CorpusDatasetPytorch._verse_text(line, to_syllables)
                      for line in sample['input_ids'][index].splitlines()[1:])
            for sample in batch
        ]

        tokenized = tokenizer(data_ids, return_tensors='pt', truncation=True, padding=True)
        input_ids = tokenized['input_ids']
        attention = tokenized["attention_mask"]

        to_tensor = CorpusDatasetPytorch._to_float_tensor

        rhyme = None
        if "rhyme" in batch[0]:
            rhyme = to_tensor([TextAnalysis._rhyme_vector(sample["rhyme"]) for sample in batch])

        year_bucket = None
        year = None
        if "year" in batch[0]:
            year_bucket = to_tensor([TextAnalysis._publish_year_vector(sample["year"]) for sample in batch])
            # Unknown years are stored as the string 'NaN' and encoded as 0.
            year = to_tensor([[int(sample['year'])] if sample['year'] != 'NaN' else [0] for sample in batch])

        return {
            "input_ids": input_ids,
            "attention_mask": attention,
            "rhyme": rhyme,
            "metre_ids": None,
            "year_bucket": year_bucket,
            'year': year}

    @staticmethod
    def collate_meter(batch, tokenizer: PreTrainedTokenizerBase, syllables: bool, is_syllable: bool = False,
                      max_len=512):
        """Flatten strophe samples into individual verses for metre prediction.

        Args:
            batch (list[dict]): Batch with selected strophe samples.
            tokenizer (PreTrainedTokenizerBase): tokenizer; its ``model_max_length`` is set to ``max_len``.
            syllables (bool): If to use a sequence of syllables as input text.
            is_syllable (bool, optional): Signal that the stored inputs already contain
                syllable data. Defaults to False.
            max_len (int, optional): Maximum length of tokenization. Defaults to 512.

        Returns:
            dict: one tokenized row per verse and, when samples carry ``metre_ids``, one
            metre vector per verse; the other prediction targets are ``None``.
        """
        index = 1 if syllables and is_syllable else 0
        tokenizer.model_max_length = max_len
        to_syllables = syllables and not is_syllable
        has_metre = bool(batch) and "metre_ids" in batch[0]

        data_ids = []
        metre = []
        for datum in batch:
            data_ids.extend(CorpusDatasetPytorch._verse_text(line, to_syllables)
                            for line in datum['input_ids'][index].splitlines()[1:])
            if has_metre:
                metre.extend(TextAnalysis._metre_vector(one_metre) for one_metre in datum['metre_ids'])

        tokenized = tokenizer(data_ids, return_tensors='pt', truncation=True, padding=True)
        input_ids = tokenized['input_ids']
        attention = tokenized["attention_mask"]

        metre_ids = CorpusDatasetPytorch._to_float_tensor(metre) if metre else None

        return {
            "input_ids": input_ids,
            "attention_mask": attention,
            "rhyme": None,
            "metre_ids": metre_ids,
            "year_bucket": None,
            "year": None}

    # ----------------------------------------------------------- construction

    def __init__(self, data_dir=r"PoetGen\corpusCzechVerse-master\ccv", cache_dir='./',
                 prompt_length=True, prompt_ending=True, prompt_verse=True, verse_len=(4, 6),
                 lower_case=True, val_data_rate=0.05, test_data_rate=0.05):
        r"""Construct the data loader and create datasets.

        Args:
            data_dir (str, optional): Path to data. Defaults to "PoetGen\corpusCzechVerse-master\ccv".
            cache_dir (str, optional): Path where to store processed data. Defaults to './'.
            prompt_length (bool, optional): If to prompt the syllable count. Defaults to True.
            prompt_ending (bool, optional): If to prompt verse ending. Defaults to True.
            prompt_verse (bool, optional): If to prompt rhyme schema. Defaults to True.
            verse_len (Sequence[int], optional): Considered strophe lengths. Defaults to (4, 6).
            lower_case (bool, optional): If the string should be in lowercase. Defaults to True.
            val_data_rate (float, optional): Fraction of data left for validation. Defaults to 0.05.
            test_data_rate (float, optional): Fraction of data left for testing. Defaults to 0.05.
        """
        self.lower_case = lower_case
        self.data_dir = data_dir

        # (dataset attribute, cache file name) pairs, in the order the caches are written.
        cache_layout = [
            ("pytorch_dataset_body", "body_poet_data.json"),
            ("pytorch_dataset_text", "text_poet_data.json"),
            ("val_pytorch_dataset_body", "val_body_poet_data.json"),
            ("val_pytorch_dataset_text", "val_text_poet_data.json"),
            ("test_pytorch_dataset_body", "test_body_poet_data.json"),
            ("test_pytorch_dataset_text", "test_text_poet_data.json"),
        ]
        cache_paths = [(attr, os.path.join(cache_dir, name)) for attr, name in cache_layout]

        if all(os.path.isfile(path) for _, path in cache_paths):
            self.create_empty()
            for attr, path in cache_paths:
                getattr(self, attr).data = self._read_json(path)
        else:
            self.load_json_filenames(prompt_length, prompt_ending, prompt_verse, verse_len=verse_len,
                                     val_data_rate=val_data_rate, test_data_rate=test_data_rate)
            for attr, path in cache_paths:
                self._write_json(getattr(self, attr).data, path)

        self.load_raw_()