from collections import Counter
import torchvision.datasets as dset
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import glob
import os
from shutil import copyfile
import subprocess
import youtokentome as yttm
import re
import time
from tqdm import trange, tqdm
import numpy as np
import matplotlib.pyplot as plt
import inspect


DEVICE = "cpu"


class BPEModelManager:
    def __init__(self, root_dir, vocab_size=5000):
        self.root_dir = root_dir
        self.vocab_size = vocab_size
        self.model_path = os.path.join(root_dir, "bpe_model.model")

        try:
            self.bpe = yttm.BPE(model=self.model_path)
            if self.bpe.vocab_size() != vocab_size:
                print(
                    f"Vocab size mismatch: expected {vocab_size}, got {self.bpe.vocab_size()}. Retraining model."
                )
                self._backup_model()
                raise ValueError
        except ValueError:
            # Model file missing or vocab size mismatch: (re)train and reload.
            self._train_bpe_model()
            self.bpe = yttm.BPE(model=self.model_path)

    def _backup_model(self):
        backup_path = os.path.join(self.root_dir, "bpe_model.model.old")
        copyfile(self.model_path, backup_path)

    def _train_bpe_model(self):
        data_path = os.path.join(self.root_dir, "data/corpus.txt")
        processed_path = os.path.join(self.root_dir, "data/corpus_processed.txt")

        with open(data_path, "r", errors="ignore") as reader:
            raw_text = reader.read()

        processed_text = self.preprocess_text(raw_text)

        with open(processed_path, "w") as writer:
            writer.write(processed_text)

        yttm.BPE.train(
            data=processed_path,
            vocab_size=self.vocab_size,
            model=self.model_path,
            coverage=0.9999,
        )

    def preprocess_text(self, text):
        return text.lower()

    def encode(self, text: str):
        # Note: returns a batch (a list containing one list of ids).
        return self.bpe.encode([text], output_type=yttm.OutputType.ID)

    def decode(self, ids):
        return self.bpe.decode(ids.tolist())[0]

    @staticmethod
    def attention_mask(encoded_sequence, mask_token_ids=[0, 1, 2, 3]):
        # 1 for ordinary tokens, 0 for special/padding token ids.
        mask_token_tensor = torch.tensor(mask_token_ids, dtype=torch.int).to(
            encoded_sequence.device
        )
        return (encoded_sequence.unsqueeze(1) != mask_token_tensor).all(dim=1).int()
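

# A small usage sketch for attention_mask (illustrative only): with the default
# special-token ids [0, 1, 2, 3], positions holding special tokens are zeroed
# out of the mask while ordinary token ids stay active, e.g.
#
#   seq = torch.tensor([5, 9, 0, 2], dtype=torch.int)
#   BPEModelManager.attention_mask(seq)  # -> tensor([1, 1, 0, 0], dtype=torch.int32)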


class CodeBPEModelManager(BPEModelManager):
    # One indentation level is assumed to be four spaces.
    mapping_dict = {
        "    ": " <INDENT> ",
        "\n": " <NEWLINE> ",
    }

    def __init__(self, root_dir, vocab_size=5000):
        super().__init__(root_dir, vocab_size)

    def preprocess_text(self, text):
        print("Formatting....")
        processed_text = self.format_code(text)

        for key, value in CodeBPEModelManager.mapping_dict.items():
            processed_text = processed_text.replace(key, value)

        return processed_text

    def encode(self, text: str):
        processed_text = text
        for key, value in CodeBPEModelManager.mapping_dict.items():
            processed_text = processed_text.replace(key, value)

        return self.bpe.encode([processed_text], output_type=yttm.OutputType.ID)[0]

    def decode(self, ids):
        l = ids
        if isinstance(l, torch.Tensor):
            l = ids.tolist()
        if isinstance(l, int):
            l = [l]

        result = self.bpe.decode(l)[0]

        for key, value in CodeBPEModelManager.mapping_dict.items():
            result = result.replace(value.strip(), key)

        return result

    def raw_decode(self, id: int):
        return self.bpe.decode([id])[0]

    def _train_bpe_model(self):
        print("Training (1)....")
        data_path = os.path.join(self.root_dir, "data/corpus.txt")
        processed_path = os.path.join(self.root_dir, "data/corpus_processed.txt")

        if input("Reformat? Will take time [y/N] ") == "y":
            with open(data_path, "r", errors="ignore", encoding="utf-8") as reader:
                raw_text = reader.read()

            processed_text = self.preprocess_text(raw_text)

            with open(processed_path, "w", encoding="utf-8") as writer:
                writer.write(processed_text)

            print("removing temp file...")
            temp_file = os.path.join(self.root_dir, "temp_code.py")
            os.remove(temp_file)

        print("Training....")
        yttm.BPE.train(
            data=processed_path,
            vocab_size=self.vocab_size,
            model=self.model_path,
            coverage=1,
        )

    def format_code(self, code):
        try:
            temp_file = os.path.join(self.root_dir, "temp_code.py")
            with open(temp_file, "w") as file:
                # Tabs are converted to spaces (assumed four-space indents).
                file.write(code.replace("\t", "    "))

            subprocess.run(["black", temp_file, "--quiet"], check=True)
            subprocess.run(
                ["autopep8", "--in-place", "--ignore=E402", temp_file], check=True
            )

            with open(temp_file, "r") as file:
                formatted_code = file.read()

            return formatted_code
        except Exception as e:
            print(f"Error during code formatting: {e}.")
            return code
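

# Round-trip sketch for CodeBPEModelManager (assumes a corpus and trained
# bpe_model.model under root_dir; the path and snippet below are illustrative):
#
#   mgr = CodeBPEModelManager(root_dir="/path/to/corpus_dir", vocab_size=5000)
#   ids = mgr.encode("def f(x):\n    return x")   # indents/newlines become <INDENT>/<NEWLINE>
#   text = mgr.decode(torch.tensor([ids]))        # markers are mapped back to whitespace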


class CodeCustomTokenizerManager(BPEModelManager):
    reserved_keywords = [
        "false", "await", "else", "import", "pass", "none", "break", "except",
        "in", "raise", "true", "class", "finally", "is", "return", "and",
        "continue", "for", "lambda", "try", "as", "def", "from", "nonlocal",
        "while", "assert", "del", "global", "not", "with", "async", "elif",
        "if", "or", "yield",
    ]
    symbols = [
        "(", ")", "[", "]", "{", "}", ".", ",", ":", ";", "+", "-", "*", "/",
        "%", "=", "<", ">", "&", "|", "^", "~", "!", "==", "!=", "<=", ">=",
        "**", "//", "@", "#", "\\", "'", '"', "`",
        "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "0x", "0d", "0o",
    ]

    def __init__(
        self,
        root_dir,
        vocab_size=5000,
        cutoff_thresh=0.1,
        use_vocab_size_instead=False,
        use_whitespace=True,
    ):
        self.root_dir = root_dir

        self.token_to_id = {"<PAD>": 0}
        self.id_to_token = None

        self._token_freqs = {}
        self.total_num_tokens = 0
        print("This is CodeCustomTokenizerManager; vocab_size will be disregarded.")

        print(f"Cutoff threshold: {cutoff_thresh}")
        self.cutoff_thresh = cutoff_thresh

        self.use_whitespace = use_whitespace
        if not use_whitespace:
            print("Not using whitespace tokens (<TAB>/<NEWLINE> are dropped).")

        if use_vocab_size_instead:
            print("Using vocab size instead of the cutoff threshold.")
        self.use_vocab_size_instead = use_vocab_size_instead

        self.vocab_size = vocab_size

        vocab_path = os.path.join(self.root_dir, "custom_tokens_vocab.txt")
        try:
            self.load_vocab(vocab_path)
        except FileNotFoundError:
            print("Making vocab!")
            self.make_vocab()
            self.save_vocab(vocab_path)

        print(f"Vocab size: {len(self.token_to_id)}")

    def make_vocab(self):
        data_path = os.path.join(self.root_dir, "data/corpus.txt")
        processed_path = os.path.join(self.root_dir, "data/corpus_processed.txt")

        with open(data_path, "r", errors="ignore") as reader:
            raw_text = reader.read()

        processed_text = self.preprocess_text(raw_text)

        with open(processed_path, "w") as writer:
            writer.write(" ".join(processed_text))

        for token in processed_text:
            if token not in self.token_to_id:
                self.token_to_id[token] = len(self.token_to_id)

        print(f"Number of tokens: {len(self.token_to_id)}")

    def make_token_freqs(self):
        processed_path = os.path.join(self.root_dir, "data/corpus_processed.txt")
        with open(processed_path, "r", errors="ignore") as reader:
            raw_text = reader.read()
        tokens = raw_text.split(" ")

        token_freqs = {"<PAD>": 0}
        for token in tqdm(tokens, leave=False):
            if token not in token_freqs:
                token_freqs[token] = 1
            else:
                token_freqs[token] += 1

        self._token_freqs = token_freqs
        self.total_num_tokens = len(tokens)

    def preprocess_text(self, code):
        print("Preprocessing text...", code[:20])

        # Strip comments and docstrings, but keep <FILESEP> markers.
        code = code.replace("# <FILESEP>", "<FILESEP>")
        code = re.sub(r"#.*", "", code)
        code = re.sub(r'"""(.*?)"""', "", code, flags=re.DOTALL)
        code = re.sub(r"'''(.*?)'''", "", code, flags=re.DOTALL)

        code = re.sub(r" ", " ", code)

        print("Filtered comments")

        # Drop non-printable / non-ASCII characters.
        code = re.sub(r"[^ -~\s]+", "", code)
        print("Filtered non-ascii")

        # Pad reserved keywords and symbols with spaces so each becomes its own token.
        for word in self.reserved_keywords:
            code = re.sub(rf"\b{word}\b", f" {word} ", code)

        print("Reserved words")
        for symbol in self.symbols:
            code = code.replace(symbol, f" {symbol} ")

        print("Symbols")

        def split_token(token):
            # Keep special markers like <NEWLINE> intact; otherwise split on
            # camelCase boundaries, underscores/dashes, and remaining non-letters.
            if token.startswith("<") and token.endswith(">"):
                return [token.lower()]
            result = re.sub(r"([a-z])([A-Z])", r"\1 \2", token)
            result = re.sub(r"([_-])", r" \1 ", result)
            result = re.sub(r"([^a-zA-Z])", r" \1 ", result)
            return [part.lower() for part in result.split() if part.strip()]

        # Indentation width is assumed to be four spaces here.
        code = code.replace("    ", " <TAB> ").replace("\n", " <NEWLINE> ")
        if not self.use_whitespace:
            code = code.replace("<TAB>", "").replace("<NEWLINE>", "")
        print("Tabs + newlines")

        tokens = []
        for token in tqdm(code.split(" "), leave=False):
            if token.strip():
                tokens.extend(split_token(token))

        tokens = [tok.lower() for tok in tokens if tok.strip()]

        print("Split tokens")
        token_freqs = {"<PAD>": 0}
        for token in tqdm(tokens, leave=False):
            if token not in token_freqs:
                token_freqs[token] = 1
            else:
                token_freqs[token] += 1
        print("Counted freqs")

        total_num_tokens = len(tokens)

        counter = Counter(list(token_freqs.values()))
        num_ones = counter[1]
        print(
            f"Number of tokens that appear only once: {num_ones}. Percentage: {num_ones / total_num_tokens}"
        )

        print(f"Mean token count: {np.mean(list(token_freqs.values()))}")
        print(f"Median token count: {np.median(list(token_freqs.values()))}")
        print(
            f"Standard deviation of token count: {np.std(list(token_freqs.values()))}"
        )
        print(f"Min token count: {np.min(list(token_freqs.values()))}")
        print(f"Max token count: {np.max(list(token_freqs.values()))}")

        print("Top 30 most frequent tokens:")
        sorted_tokens = sorted(token_freqs.items(), key=lambda x: x[1], reverse=True)
        for token, freq in sorted_tokens[:30]:
            print(f"{token}: {freq}")

        print("Bottom 30 least frequent tokens:")
        for token, freq in sorted_tokens[-30:]:
            print(f"{token}: {freq}")

        self._token_freqs = token_freqs
        self.total_num_tokens = total_num_tokens

        cutoff_thresh = self.cutoff_thresh
        if self.use_vocab_size_instead:
            # Deprecated path: abort instead of truncating to vocab_size.
            print("Using vocab size instead (deprecated)")
            exit()
            sorted_tokens = sorted(
                token_freqs.items(), key=lambda x: x[1], reverse=True
            )
            allowed_tokens = set(
                token for token, _ in sorted_tokens[: self.vocab_size - 1]
            )
            for i in range(len(tokens)):
                if tokens[i] not in allowed_tokens and tokens[i] != "<PAD>":
                    print(f"Replacing token with UNK: {tokens[i]}")
                    tokens[i] = "<UNK>"
        else:
            # Replace tokens seen fewer than cutoff_amt times with <UNK>.
            # Note: this hard-coded value overrides cutoff_thresh above.
            cutoff_amt = 10
            print(f"Cutoff amount: {cutoff_amt}")

            low_freq_tokens = [
                token
                for token, freq in token_freqs.items()
                if freq < cutoff_amt and token != "<PAD>"
            ]
            low_freq_tokens_set = set(low_freq_tokens)
            tokens = [
                "<UNK>" if token in low_freq_tokens_set else token
                for token in tqdm(tokens)
            ]

        print(tokens[500:700])
        print("500-700")

        return [tok for tok in tokens if tok.strip()]

    def encode(self, code):
        tokens = code.split(" ")
        ids = []

        for token in tokens:
            # Tokens not yet in the vocabulary are added on the fly.
            if token not in self.token_to_id:
                self.token_to_id[token] = len(self.token_to_id)
            ids.append(self.token_to_id[token])

        return ids

    def decode(self, ids):
        id_to_token = {v: k for k, v in self.token_to_id.items()}
        result = ""
        for token_id in ids.tolist():
            token = id_to_token.get(token_id)
            if token is not None:
                result += token
            result += " "

        return result

    def raw_decode(self, id: int):
        id_to_token = {v: k for k, v in self.token_to_id.items()}
        return id_to_token.get(id)

    def format_code(self, code):
        try:
            temp_file = os.path.join(self.root_dir, "temp_code.py")
            with open(temp_file, "w") as file:
                # Tabs are converted to spaces (assumed four-space indents).
                file.write(code.replace("\t", "    "))

            subprocess.run(["black", temp_file, "--quiet"], check=True)
            subprocess.run(
                ["autopep8", "--in-place", "--ignore=E402", temp_file], check=True
            )

            with open(temp_file, "r") as file:
                formatted_code = file.read()

            return formatted_code
        except Exception as e:
            print(f"Error during code formatting: {e}.")
            return code

    def save_vocab(self, file_path):
        with open(file_path, "w") as file:
            for token, id in self.token_to_id.items():
                file.write(f"{token}\t{id}\n")

    def load_vocab(self, file_path):
        self.token_to_id = {}
        with open(file_path, "r") as file:
            for line in file.read().split("\n"):
                try:
                    token, id = line.strip().split("\t")
                    self.token_to_id[token] = int(id)
                except ValueError:
                    # Skip blank or malformed lines.
                    pass

    @staticmethod
    def attention_mask(encoded_sequence, mask_token_ids=[0]):
        # Only <PAD> (id 0) is masked out for the custom tokenizer.
        mask_token_tensor = torch.tensor(mask_token_ids, dtype=torch.int)
        return (encoded_sequence.unsqueeze(1) != mask_token_tensor).all(dim=1).int()
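
    # The vocabulary handled by save_vocab/load_vocab above is a plain text file
    # with one tab-separated "token<TAB>id" pair per line, e.g. "<PAD>\t0" or
    # "def\t42" (the ids shown are illustrative; they depend on corpus order).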

    def get_rarity_score(self, sequence):
        # Median of (corpus size / token frequency) over the chunk; higher
        # values mean the chunk contains rarer tokens.
        if self._token_freqs == {}:
            self.make_token_freqs()
        if not self.id_to_token:
            self.id_to_token = {v: k for k, v in self.token_to_id.items()}

        # Use a float buffer so the scores are not truncated to integers.
        scores = np.zeros(len(sequence), dtype=np.float32)
        for idx, token in enumerate(sequence):
            token_count = self._token_freqs.get(self.id_to_token[token.item()], 0)
            rarity_score = self.total_num_tokens / token_count if token_count > 0 else 0
            scores[idx] = rarity_score

        return np.float32(np.median(scores))

    def get_entropy_score(self, sequence):
        # Shannon entropy of the chunk's token distribution, normalized to [0, 1].
        if len(sequence) == 0:
            return 0.0

        unique, counts = np.unique(sequence, return_counts=True)

        probs = counts / counts.sum()
        entropy = -np.sum(probs * np.log2(probs))

        if len(unique) > 1:
            entropy /= np.log2(len(unique))

        return np.float32(entropy)
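

# Sketch of the chunk-scoring helpers (values are illustrative; assumes a built
# vocabulary and corpus_processed.txt on disk under root_dir):
#
#   mgr = CodeCustomTokenizerManager(root_dir="/path/to/corpus_dir")
#   chunk = torch.tensor(mgr.encode("def main ( ) :"), dtype=torch.int)
#   r = mgr.get_rarity_score(chunk)     # larger -> rarer tokens in the chunk
#   h = mgr.get_entropy_score(chunk)    # closer to 1 -> more varied chunk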


class DummySequentialDataManager:
    def __init__(self, root_dir, vocab_size=5000):
        print("init")
        self.root_dir = root_dir
        self.vocab_size = vocab_size
        with open(os.path.join(root_dir, "data/corpus_processed.txt"), "w+") as f:
            f.write("dummy")

    def encode(self, text: str):
        return [list(range(50))]

    def decode(self, ids):
        l = ids
        if isinstance(l, torch.Tensor):
            l = ids.tolist()
        if isinstance(l, int):
            l = [l]

        return " ".join([str(id) for id in l])

    @staticmethod
    def attention_mask(encoded_sequence, mask_token_ids=[]):
        mask_token_tensor = torch.tensor(mask_token_ids, dtype=torch.int).to(
            encoded_sequence.device
        )
        return (encoded_sequence.unsqueeze(1) != mask_token_tensor).all(dim=1).int()


class TextCorpusDataset(Dataset):
    def __init__(
        self,
        root_dir="./test-data",
        train=False,
        max_length=512,
        vocab_size=10000,
        IS_DUMMY=False,
        IS_CODE=False,
        IS_CUSTOM=False,
        sliding_window=False,
        stride=1,
        get_rarity_score=False,
        get_entropy_score=False,
    ):
        print(root_dir)

        print("[TextCorpusDataset]")
        frame = inspect.currentframe()
        args, _, _, values = inspect.getargvalues(frame)
        print("Arguments passed:")
        for arg in args[1:]:
            print(f" {arg} = {values[arg]}")

        self.root = root_dir
        self.sliding_window = sliding_window
        self.window_size = max_length
        self.stride = stride
        self.get_rarity_score = get_rarity_score
        self.get_entropy_score = get_entropy_score

        if IS_DUMMY:
            self.manager = DummySequentialDataManager(root_dir=root_dir)
        elif IS_CODE:
            if IS_CUSTOM:
                self.manager = CodeCustomTokenizerManager(root_dir=root_dir)
            else:
                self.manager = CodeBPEModelManager(
                    root_dir=root_dir, vocab_size=vocab_size
                )
        else:
            self.manager = BPEModelManager(root_dir=root_dir, vocab_size=vocab_size)

        self.max_length = max_length
        self.cache_file = os.path.join(root_dir, "encoded_chunked.pt")
        self.rarity_cache_file = os.path.join(root_dir, "rarity_scores.pt")
        self.entropy_cache_file = os.path.join(root_dir, "entropy_scores.pt")

        start_t = time.time()
        if os.path.exists(self.cache_file):
            self.chunks = torch.load(self.cache_file, weights_only=True)
            if self.chunks.size(-1) != self.max_length:
                if (
                    input(
                        "Attempting to fix and re-chunk data to correct length. Continue? [y/N]: "
                    )
                    == "y"
                ):
                    self._chunk_and_save(torch.flatten(self.chunks).tolist())
                    print("Re-chunked successfully!")
                else:
                    print("Operation aborted.")
        else:
            with open(
                os.path.join(root_dir, "data/corpus_processed.txt"),
                "r",
                errors="ignore",
            ) as file:
                text = file.read()
            encoded = self.manager.encode(text)

            self._chunk_and_save(encoded)

        self._load_or_compute_scores()

        end_t = time.time()
        print(f"Dataset loading took {end_t - start_t} seconds.")

        self.chunks = self.chunks.to(DEVICE)
        if self.get_rarity_score:
            self.rarity_scores = self.rarity_scores.to(DEVICE)
        if self.get_entropy_score:
            self.entropy_scores = self.entropy_scores.to(DEVICE)
        self.dummy = torch.tensor([1], device=DEVICE)

    def _chunk_and_save(self, encoded):
        chunked_data = []
        if self.sliding_window:
            print("sliding!")
            for i in trange(
                0, len(encoded) - self.window_size + 1, self.stride, leave=False
            ):
                chunked_data.append(
                    torch.tensor(encoded[i : i + self.window_size], dtype=torch.int)
                )
        else:
            for i in trange(0, len(encoded), self.max_length, leave=False):
                chunked_data.append(
                    torch.tensor(encoded[i : i + self.max_length], dtype=torch.int)
                )

        # Zero-pad the final chunk so every chunk has length max_length.
        padded_chunk = torch.zeros(self.max_length, dtype=torch.int)
        padded_chunk[: len(chunked_data[-1])] = chunked_data[-1]
        chunked_data[-1] = padded_chunk

        self.chunks = torch.stack(chunked_data)
        torch.save(self.chunks, self.cache_file)

    def _load_or_compute_scores(self):
        """Load cached scores or compute them if not available."""
        if self.get_rarity_score:
            if os.path.exists(self.rarity_cache_file):
                print("Loading cached rarity scores...")
                self.rarity_scores = torch.load(
                    self.rarity_cache_file, weights_only=True
                )
                if len(self.rarity_scores) != len(self.chunks):
                    print("Rarity cache size mismatch, recomputing...")
                    self._compute_and_cache_rarity_scores()
            else:
                print("Computing rarity scores...")
                self._compute_and_cache_rarity_scores()

        if self.get_entropy_score:
            if os.path.exists(self.entropy_cache_file):
                print("Loading cached entropy scores...")
                self.entropy_scores = torch.load(
                    self.entropy_cache_file, weights_only=True
                )
                if len(self.entropy_scores) != len(self.chunks):
                    print("Entropy cache size mismatch, recomputing...")
                    self._compute_and_cache_entropy_scores()
            else:
                print("Computing entropy scores...")
                self._compute_and_cache_entropy_scores()

    def _compute_and_cache_rarity_scores(self):
        """Compute rarity scores for all chunks and cache them."""
        rarity_scores = []
        print("Computing rarity scores for all chunks...")
        for i in trange(len(self.chunks), desc="Computing rarity scores"):
            score = self.manager.get_rarity_score(self.chunks[i])
            rarity_scores.append(score)

        self.rarity_scores = torch.tensor(rarity_scores, dtype=torch.float32)
        torch.save(self.rarity_scores, self.rarity_cache_file)
        print(f"Cached rarity scores to {self.rarity_cache_file}")

    def _compute_and_cache_entropy_scores(self):
        """Compute entropy scores for all chunks and cache them."""
        entropy_scores = []
        print("Computing entropy scores for all chunks...")
        for i in trange(len(self.chunks), desc="Computing entropy scores"):
            score = self.manager.get_entropy_score(self.chunks[i])
            entropy_scores.append(score)

        self.entropy_scores = torch.tensor(entropy_scores, dtype=torch.float32)
        torch.save(self.entropy_scores, self.entropy_cache_file)
        print(f"Cached entropy scores to {self.entropy_cache_file}")

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, idx):
        seq = self.chunks[idx]
        if self.get_rarity_score:
            return seq, self.rarity_scores[idx]
        if self.get_entropy_score:
            return seq, self.entropy_scores[idx]
        return seq, self.dummy
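

# Each dataset item is a (sequence, target) pair: the target is the chunk's
# rarity score, its entropy score, or a dummy tensor, depending on the flags
# passed to the constructor. A sketch with assumed arguments:
#
#   ds = TextCorpusDataset(root_dir="/path/to/corpus_dir", IS_CODE=True,
#                          IS_CUSTOM=True, max_length=256, get_rarity_score=True)
#   seq, rarity = ds[0]    # seq: int tensor of shape (256,); rarity: float32 scalar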


class Datasplit_chunker(Dataset):
    def __init__(self, root, name, subset, slide=False, stride=1, length=512):
        super().__init__()

        self.root = root
        if os.path.exists(os.path.join(root, f"encoded_chunked_{name}.pt")):
            self.items = torch.load(
                os.path.join(root, f"encoded_chunked_{name}.pt"), weights_only=True
            )
        else:
            self.items = torch.cat([subset.dataset[idx][0] for idx in subset.indices])

            if slide:
                self.items = self._sliding_window(
                    self.items, window_size=length, stride=stride
                )

            torch.save(self.items, os.path.join(root, f"encoded_chunked_{name}.pt"))
            print("saved!")
        self.chunks = self.items
        self.dummy = torch.tensor([1], device=DEVICE)

    def _sliding_window(self, sequence, window_size, stride):
        num_windows = (len(sequence) - window_size) // stride + 1
        windows = torch.as_strided(
            sequence, size=(num_windows, window_size), stride=(stride, 1)
        )
        return windows

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.chunks[idx], self.dummy
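

# Hypothetical use of Datasplit_chunker: materialize a split into its own cached
# (optionally re-windowed) chunk file. The name and window arguments below are
# illustrative:
#
#   train_chunks = Datasplit_chunker(
#       root=dataset.root, name="train", subset=train_dataset,
#       slide=True, stride=10, length=256,
#   )
#   loader = get_dataloader(train_chunks, batch_size=64)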


dataset = TextCorpusDataset(
    root_dir=os.path.expanduser(
        "~/torch_datasets/github-python/mega_licensed_corpus"
    ),
    vocab_size=33819,
    IS_CODE=True,
    IS_CUSTOM=True,
    max_length=256,
    sliding_window=False,
    stride=10,
    get_rarity_score=True,
)

dset_size = int(len(dataset))
train_size = int(0.8 * dset_size)
test_size = int(dset_size - train_size)
if test_size == 2:
    print("Warning: test size is 2; restore the original split sizes.")

torch.manual_seed(3407)

train_dataset, test_dataset, _ = random_split(
    dataset, [train_size, test_size, len(dataset) - train_size - test_size]
)


def get_train_dataset():
    return train_dataset


def get_test_dataset():
    return test_dataset


def get_dataloader(dataset, batch_size=64):
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


def fromDataset(dataset):
    dset_size = int(len(dataset))
    train_size = int(0.8 * dset_size)
    test_size = int(dset_size - train_size)
    if test_size == 2:
        print("Warning: test size is 2; restore the original split sizes.")

    torch.manual_seed(3407)

    train_dataset, test_dataset, _ = random_split(
        dataset, [train_size, test_size, len(dataset) - train_size - test_size]
    )

    return train_dataset, test_dataset


if __name__ == "__main__":
    d = get_train_dataset()
    print("Number of samples: ", len(d))
    for a, b in d:
        manager = dataset.manager
        print(a)
        print(manager.decode(a))
        print("--- sep batch --- ")

        print(f"Number of tokens used: {len(dataset.manager.token_to_id)}")
        break
|