import re
from typing import List, Literal

from datasets import Dataset, load_dataset
from transformers import (
    GPT2TokenizerFast,
    GPTNeoXTokenizerFast,
    PreTrainedTokenizerFast,
)

TEXT_DATASETS = ["wikitext2", "openwebtext"]
MIN_LEN = 50  # minimum token count for a sample to be kept
def wt_detokeniser(string: str) -> str:
    """Undo raw-WikiText tokenisation artefacts (spaces around punctuation,
    `@-@`-style number separators, spread-out heading markers, etc.)."""
    # contractions
    string = string.replace("s '", "s'")
    string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
    # number separators
    string = string.replace(" @-@ ", "-")
    string = string.replace(" @,@ ", ",")
    string = string.replace(" @.@ ", ".")
    # punctuation
    string = string.replace(" : ", ": ")
    string = string.replace(" ; ", "; ")
    string = string.replace(" . ", ". ")
    string = string.replace(" ! ", "! ")
    string = string.replace(" ? ", "? ")
    string = string.replace(" , ", ", ")
    # brackets and quotes
    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
    # miscellaneous
    string = string.replace("= = = =", "====")
    string = string.replace("= = =", "===")
    string = string.replace("= =", "==")
    string = string.replace(" " + chr(176) + " ", chr(176))
    string = string.replace(" \n", "\n")
    string = string.replace("\n ", "\n")
    string = string.replace(" N ", " 1 ")
    string = string.replace(" 's", "'s")
    return string
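
# A quick sanity check of the detokeniser (illustrative input, not a real
# dataset sample):
#
#     >>> wt_detokeniser("= = History = = \n The city , in 1850 , grew")
#     '== History ==\nThe city, in 1850, grew'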
def setup_tokeniser(tokeniser_name: str) -> PreTrainedTokenizerFast:
    match tokeniser_name:
        case "gpt2":
            tokeniser = GPT2TokenizerFast.from_pretrained("gpt2")
        case "gpt-neo":
            tokeniser = GPTNeoXTokenizerFast.from_pretrained(
                "EleutherAI/gpt-neox-20b"
            )
        case _:
            raise ValueError(f"Tokeniser {tokeniser_name} not supported")
    # Add explicit pad/mask tokens; neither GPT-2 nor GPT-NeoX defines them.
    tokeniser.add_special_tokens(
        {
            "pad_token": "[PAD]",
            "mask_token": "[MASK]",
        }
    )
    return tokeniser
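
# Note: adding [PAD]/[MASK] grows the vocabulary, so any model built on the
# base checkpoint needs its embeddings resized, e.g. (sketch; `model` is
# assumed, not defined in this module):
#
#     tok = setup_tokeniser("gpt2")
#     model.resize_token_embeddings(len(tok))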
def find_delimiter_positions(tokens, delimiter_tokens):
    """Return the start indices where the delimiter occurs in the token sequence."""
    positions = []
    n = len(delimiter_tokens)
    for i in range(len(tokens) - n + 1):
        if tokens[i : i + n] == delimiter_tokens:
            positions.append(i)
    return positions
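
# For example, with the GPT-2 newline token (id 198) doubled as a delimiter:
#
#     >>> find_delimiter_positions([5, 198, 198, 7, 198, 198, 9], [198, 198])
#     [1, 4]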
def recursive_split(tokens, max_length, delimiter_tokens):
    """Split `tokens` into chunks of at most `max_length`, preferring to cut
    at the delimiter occurrence closest to the midpoint; falls back to a
    naive fixed-size split when no delimiter is present."""
    if len(tokens) <= max_length:
        return [tokens]
    # Find all positions where the delimiter sequence occurs
    split_candidates = find_delimiter_positions(tokens, delimiter_tokens)
    if not split_candidates:
        # Safe fallback: naive split into fixed-size blocks
        return [
            tokens[i : i + max_length] for i in range(0, len(tokens), max_length)
        ]
    # Split at the delimiter closest to the midpoint
    midpoint = len(tokens) // 2
    split_point = min(split_candidates, key=lambda x: abs(x - midpoint))
    # Recurse on both sides, dropping the delimiter itself
    dlen = len(delimiter_tokens)
    left = recursive_split(tokens[:split_point], max_length, delimiter_tokens)
    right = recursive_split(tokens[split_point + dlen :], max_length, delimiter_tokens)
    return left + right
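
# Illustration with toy token ids, max_length=4 and [0, 0] as the delimiter:
#
#     >>> recursive_split([1, 2, 3, 0, 0, 4, 5, 6, 7], 4, [0, 0])
#     [[1, 2, 3], [4, 5, 6, 7]]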
def preprocess_batch(batch, pad_token, max_length, delimiter, detokeniser, tokeniser):
    """Tokenise a batch of texts and split each one into chunks of at most
    `max_length` tokens, right-padding every chunk with `pad_token`."""
    all_input_ids = []
    all_lengths = []
    for text in batch["text"]:
        if detokeniser is not None:
            text = detokeniser(text)
        tokens = tokeniser.encode(text, add_special_tokens=False)
        chunks = recursive_split(tokens, max_length, delimiter)
        # Drop empty chunks (produced when a delimiter sits at a boundary)
        chunks = [c for c in chunks if c]
        all_input_ids.extend(
            [c + [pad_token] * (max_length - len(c)) for c in chunks]
        )
        all_lengths.extend([len(chunk) for chunk in chunks])
    return {
        "input_ids": all_input_ids,
        "length": all_lengths,
    }
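
# Intended to be driven through `datasets.map(..., batched=True)`; a sketch,
# assuming `tok` was built with setup_tokeniser and `ds` has a "text" column:
#
#     ds.map(
#         lambda b: preprocess_batch(
#             b, pad_token=tok.pad_token_id, max_length=1024,
#             delimiter=[198, 198], detokeniser=None, tokeniser=tok,
#         ),
#         batched=True, remove_columns=["text"],
#     )
#
# In the GPT-2 BPE vocabulary, token id 198 is "\n", so [198, 198] cuts at
# blank lines, i.e. paragraph boundaries.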
def setup_tokeniser_from_dataset(dataset_name: str) -> PreTrainedTokenizerFast:
    match dataset_name:
        case "wikitext2" | "openwebtext":
            return setup_tokeniser("gpt2")
        case "dclm":
            return setup_tokeniser("gpt-neo")
        case _:
            raise ValueError(f"Tokeniser for dataset {dataset_name} not supported")
def decode_sequence_with_mask(
    seqs: List[List[int]],
    tokeniser: PreTrainedTokenizerFast,
    pad_token: int,
    mask_token: int,
) -> List[str]:
    """Decode sequences to text, dropping padding and rendering mask tokens
    as a visible "[MASK]" literal."""
    decoded = []
    for seq in seqs:
        tokens = tokeniser.convert_ids_to_tokens(seq)
        filtered = []
        for tok, tok_id in zip(tokens, seq):
            if tok_id == pad_token:
                continue
            if tok_id == mask_token:
                filtered.append("[MASK]")
            else:
                filtered.append(tok)
        text = tokeniser.convert_tokens_to_string(filtered)
        decoded.append(text)
    return decoded
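
# A sketch of typical use, assuming `tok` comes from setup_tokeniser("gpt2")
# and `batch` is a tokenised batch with torch tensors:
#
#     texts = decode_sequence_with_mask(
#         batch["input_ids"].tolist(), tok,
#         pad_token=tok.pad_token_id, mask_token=tok.mask_token_id,
#     )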
def get_text_dataset(
    name: str,
    split: Literal["train", "validation", "test"],
    cache_dir=None,
    max_length=1024,
    num_proc=64,
    filter_max_length=True,
) -> Dataset:
    """Load and tokenise a text dataset. With `filter_max_length=True`,
    samples shorter than MIN_LEN or longer than `max_length` tokens are
    dropped; otherwise long samples are split into chunks at paragraph
    boundaries via `preprocess_batch`."""
    match name:
        case "wikitext2":
            dataset = load_dataset(
                "wikitext", "wikitext-2-raw-v1", cache_dir=cache_dir, split=split
            )
            detokeniser = wt_detokeniser
        case "openwebtext":
            # openwebtext ships a single "train" split; carve out a
            # deterministic 2% validation slice
            ds_all = load_dataset(name, cache_dir=cache_dir)
            train_ds = ds_all["train"]
            if split in ["train", "validation"]:
                split_data = train_ds.train_test_split(test_size=0.02, seed=42)
                dataset = (
                    split_data["train"] if split == "train" else split_data["test"]
                )
            else:
                raise ValueError(f"Dataset {name} does not support split {split}")
            detokeniser = None
        case _:
            raise ValueError(f"Dataset {name} not supported")
    tokeniser = setup_tokeniser_from_dataset(name)
    pad_token = tokeniser.pad_token_id
    if filter_max_length:

        def preprocess(sample):
            text = sample["text"]
            if detokeniser is not None:
                text = detokeniser(text)
            text = tokeniser(text, return_attention_mask=False)
            # Flag too-short samples with an empty sequence; filtered below
            if len(text["input_ids"]) < MIN_LEN:
                return {"input_ids": []}
            # Right-pad up to max_length; over-long samples are filtered below
            text["input_ids"] += max(0, max_length - len(text["input_ids"])) * [
                pad_token
            ]
            return text

        tokenised_dataset = dataset.map(
            preprocess,
            num_proc=num_proc,
            load_from_cache_file=True,
            remove_columns=["text"],
        )
        tokenised_dataset = tokenised_dataset.filter(
            lambda x: 0 < len(x["input_ids"]) <= max_length,
            num_proc=num_proc,
            load_from_cache_file=True,
        )
    else:
        tokenised_dataset = dataset.map(
            lambda batch: preprocess_batch(
                batch,
                pad_token=pad_token,
                max_length=max_length,
                detokeniser=detokeniser,
                tokeniser=tokeniser,
                # GPT-2 token 198 is "\n"; split at blank lines
                delimiter=[198, 198],
            ),
            batched=True,
            num_proc=num_proc,
            remove_columns=["text"],
        )
    return tokenised_dataset.with_format("torch")
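
if __name__ == "__main__":
    # Minimal smoke test (sketch): downloads wikitext-2 on first run, so it
    # assumes network access and a working `datasets` cache.
    ds = get_text_dataset("wikitext2", split="validation", num_proc=4)
    tok = setup_tokeniser_from_dataset("wikitext2")
    print(ds)
    print(
        decode_sequence_with_mask(
            [ds[0]["input_ids"].tolist()],
            tok,
            pad_token=tok.pad_token_id,
            mask_token=tok.mask_token_id,
        )[0][:200]
    )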