# Uploaded by Bailan-Alex via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 4f2b2f4 (verified).
import re
from transformers import GPT2TokenizerFast, GPTNeoXTokenizerFast
from datasets import Dataset, load_dataset
from typing import Literal, List
# Names of the raw-text datasets that get_text_dataset can load.
# NOTE(review): setup_tokeniser_from_dataset also knows "dclm", but there is no
# loader for it here — confirm whether it should be listed.
TEXT_DATASETS = ["wikitext2", "openwebtext"]
# Minimum token count: when get_text_dataset runs with filter_max_length=True,
# documents with fewer tokens than this are dropped.
MIN_LEN: int = 50
def wt_detokeniser(string):
    """Undo WikiText's tokenisation spacing so the text reads like ordinary prose.

    Applies a fixed, ordered sequence of string and regex substitutions:
    re-attaches contractions, collapses ``@-@``/``@,@``/``@.@`` number
    separators, removes the space before punctuation, trims whitespace inside
    brackets and quotes, and re-joins spaced heading markers
    (``= = Title = =`` -> ``== Title ==``).

    Args:
        string: WikiText-formatted text.

    Returns:
        The detokenised text.
    """
    # contractions
    string = string.replace("s '", "s'")
    # Re-attach an apostrophe to a following digit, e.g. "' 90s" -> "'90s".
    # The previous pattern r"/' [0-9]/" used JavaScript-style slash delimiters,
    # so it only matched literal slashes and never changed real text.
    string = re.sub(r"' ([0-9])", r"'\1", string)
    # number separators
    string = string.replace(" @-@ ", "-")
    string = string.replace(" @,@ ", ",")
    string = string.replace(" @.@ ", ".")
    # punctuation
    string = string.replace(" : ", ": ")
    string = string.replace(" ; ", "; ")
    string = string.replace(" . ", ". ")
    string = string.replace(" ! ", "! ")
    string = string.replace(" ? ", "? ")
    string = string.replace(" , ", ", ")
    # double brackets: strip whitespace just inside paired delimiters
    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
    # miscellaneous; heading markers longest-first so "= = = =" is not
    # half-consumed by the shorter patterns
    string = string.replace("= = = =", "====")
    string = string.replace("= = =", "===")
    string = string.replace("= =", "==")
    string = string.replace(" " + chr(176) + " ", chr(176))  # degree sign
    string = string.replace(" \n", "\n")
    string = string.replace("\n ", "\n")
    string = string.replace(" N ", " 1 ")
    string = string.replace(" 's", "'s")
    return string
def setup_tokeniser(tokeniser_name: str) -> GPT2TokenizerFast:
match tokeniser_name:
case "gpt2":
tokeniser = GPT2TokenizerFast.from_pretrained("gpt2")
case "gpt-neo":
tokeniser = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b")
case _:
raise ValueError(f"Tokeniser {tokeniser_name} not supported")
tokeniser.add_special_tokens(
{
"pad_token": "[PAD]",
"mask_token": "[MASK]",
}
)
return tokeniser
def find_delimiter_positions(tokens, delimiter_tokens):
    """Return the start indices where the delimiter occurs in the token sequence.

    Matches may overlap: delimiter ``[2, 2]`` in ``[2, 2, 2]`` gives ``[0, 1]``.
    Either argument may be any sequence (list, tuple, ...). An empty delimiter
    matches nowhere and yields ``[]``.
    """
    delimiter = list(delimiter_tokens)
    n = len(delimiter)
    if n == 0:
        # An empty pattern would "match" at every index, which is useless as a
        # split point.
        return []
    first = delimiter[0]
    # Compare the first element before building a slice, so most positions
    # are rejected without allocating.
    return [
        i
        for i in range(len(tokens) - n + 1)
        if tokens[i] == first and list(tokens[i : i + n]) == delimiter
    ]


def recursive_split(tokens, max_length, delimiter_tokens):
    """Split ``tokens`` into non-empty chunks of at most ``max_length`` tokens.

    If the span is too long, it is cut at the ``delimiter_tokens`` occurrence
    nearest the middle and both halves are processed recursively. The
    delimiter tokens themselves are dropped. A span with no delimiter is
    sliced into consecutive ``max_length`` pieces.

    Args:
        tokens: Token ids to split.
        max_length: Maximum chunk length (must be positive).
        delimiter_tokens: Token-id sequence preferred as a cut point.

    Returns:
        The chunks in their original order. Empty input yields ``[]``. Empty
        pieces, e.g. from a delimiter at the very start or end of a span, are
        omitted.
    """
    if len(tokens) == 0:
        return []
    if len(tokens) <= max_length:
        return [tokens]
    split_candidates = find_delimiter_positions(tokens, delimiter_tokens)
    if not split_candidates:
        # Safe fallback: fixed-size slices (slicing clamps at the end).
        return [tokens[i : i + max_length] for i in range(0, len(tokens), max_length)]
    # Cut at the delimiter closest to the midpoint to keep the halves balanced.
    midpoint = len(tokens) // 2
    split_point = min(split_candidates, key=lambda x: abs(x - midpoint))
    dlen = len(delimiter_tokens)
    left = recursive_split(tokens[:split_point], max_length, delimiter_tokens)
    right = recursive_split(tokens[split_point + dlen :], max_length, delimiter_tokens)
    return left + right
def preprocess_batch(batch, pad_token, max_length, delimiter, detokeniser, tokeniser):
    """Tokenise a batch of texts into right-padded chunks of ``max_length`` ids.

    Intended for ``Dataset.map(batched=True)``. Each text is processed as
    follows:

    1. Optionally run through ``detokeniser``.
    2. Encode without special tokens.
    3. Split with ``recursive_split``, preferring cuts at ``delimiter``.
    4. Right-pad each chunk with ``pad_token``.

    One text may produce several rows. Empty chunks (e.g. from an empty text)
    are skipped rather than emitted as all-padding rows.

    Args:
        batch: Mapping with a ``"text"`` column of strings.
        pad_token: Id used for right padding.
        max_length: Row length in tokens.
        delimiter: Token-id sequence preferred as a split point.
        detokeniser: Optional ``str -> str`` callable applied before encoding.
        tokeniser: Tokeniser providing ``encode``.

    Returns:
        Dict with ``"input_ids"`` (padded chunks) and ``"length"`` (each
        chunk's length before padding), aligned row-for-row.
    """
    all_input_ids = []
    all_lengths = []
    for text in batch["text"]:
        if detokeniser is not None:
            text = detokeniser(text)
        tokens = tokeniser.encode(text, add_special_tokens=False)
        for chunk in recursive_split(tokens, max_length, delimiter):
            if len(chunk) == 0:
                # Would otherwise become a row of pure padding with length 0.
                continue
            # recursive_split never returns chunks longer than max_length, so
            # this pads every row to exactly max_length.
            all_input_ids.append(chunk + [pad_token] * (max_length - len(chunk)))
            all_lengths.append(len(chunk))
    return {
        "input_ids": all_input_ids,
        "length": all_lengths,
    }
def setup_tokeniser_from_dataset(dataset_name: str):
    """Return the tokeniser registered for a dataset name.

    ``"wikitext2"`` and ``"openwebtext"`` map to ``setup_tokeniser("gpt2")``;
    ``"dclm"`` maps to ``setup_tokeniser("gpt-neo")``.

    Raises:
        ValueError: If no tokeniser is registered for ``dataset_name``.
    """
    if dataset_name in ("wikitext2", "openwebtext"):
        return setup_tokeniser("gpt2")
    if dataset_name == "dclm":
        return setup_tokeniser("gpt-neo")
    raise ValueError(f"Tokeniser for dataset {dataset_name} not supported")
def decode_sequence_with_mask(
    seqs: List[List[int]], tokeniser: GPT2TokenizerFast, pad_token: int, mask_token: int
) -> List[str]:
    """
    Decode token-id sequences to text, rendering mask positions as "[MASK]".

    For each sequence, ids equal to ``pad_token`` are dropped. Ids equal to
    ``mask_token`` become the literal token ``"[MASK]"``. All other ids are
    mapped to tokens by the tokeniser. The surviving tokens are joined with
    ``convert_tokens_to_string``.

    Returns:
        One decoded string per input sequence, in order.
    """

    def _render(seq):
        pieces = [
            "[MASK]" if tok_id == mask_token else tok
            for tok, tok_id in zip(tokeniser.convert_ids_to_tokens(seq), seq)
            if tok_id != pad_token
        ]
        return tokeniser.convert_tokens_to_string(pieces)

    return [_render(seq) for seq in seqs]
def get_text_dataset(
name: str,
split: Literal["train", "validation", "test"],
cache_dir=None,
max_length=1024,
num_proc=64,
filter_max_length=True,
) -> Dataset:
match name:
case "wikitext2":
dataset = load_dataset(
"wikitext", "wikitext-2-raw-v1", cache_dir=cache_dir, split=split
)
case "openwebtext":
ds_all = load_dataset(name, cache_dir=cache_dir)
train_ds = ds_all["train"]
if split in ["train", "validation"]:
split_data = train_ds.train_test_split(test_size=0.02, seed=42)
dataset = (
split_data["train"] if split == "train" else split_data["test"]
)
else:
raise ValueError(f"Dataset {name} does not support split {split}")
case _:
raise ValueError(f"Dataset {name} not supported")
match name:
case "wikitext2":
detokeniser = wt_detokeniser
case "openwebtext":
detokeniser = None
case "dclm":
detokeniser = None
case _:
raise ValueError(f"Dataset {name} not supported")
tokeniser = setup_tokeniser_from_dataset(name)
pad_token = tokeniser.pad_token_id
if filter_max_length:
def preprocess(sample):
text = sample["text"]
if detokeniser is not None:
text = detokeniser(text)
text = tokeniser(text, return_attention_mask=False)
if len(text["input_ids"]) < MIN_LEN:
return {"input_ids": []}
text["input_ids"] += max(0, max_length - len(text["input_ids"])) * [
pad_token
]
return text
tokenised_dataset = dataset.map(
preprocess,
num_proc=num_proc,
load_from_cache_file=True,
remove_columns=["text"],
)
tokenised_dataset = tokenised_dataset.filter(
lambda x: 0 < len(x["input_ids"]) <= max_length,
num_proc=num_proc,
load_from_cache_file=True,
)
tokenised_dataset = tokenised_dataset.with_format("torch")
return tokenised_dataset
else:
tokenised_dataset = dataset.map(
lambda batch: preprocess_batch(
batch,
pad_token=pad_token,
max_length=max_length,
detokeniser=detokeniser,
tokeniser=tokeniser,
delimiter=[198, 198],
),
batched=True,
num_proc=num_proc,
remove_columns=["text"],
)
tokenised_dataset = tokenised_dataset.with_format("torch")
return tokenised_dataset