from typing import Dict, Iterable, List, Optional, Tuple, Union

import collections
import glob
import hashlib
import itertools
import json
import logging
import multiprocessing
import os
import pickle
import random
import sys
import zipfile

import datasets
import numpy as np
import requests
import torch
import tqdm
import transformers

from cde.lib.dist import get_num_proc, get_rank


def get_cde_cache_dir() -> str:
    script_directory = os.path.normpath(
        os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            os.pardir, os.pardir,
        )
    )
    return os.path.join(script_directory, "data")


def get_cache_location_from_kwargs(**kwargs) -> str:
    cache_location = os.path.join(
        get_cde_cache_dir(), "cluster"
    )
    os.makedirs(cache_location, exist_ok=True)
    return os.path.join(cache_location, md5_hash_kwargs(**kwargs))
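

# Usage sketch (illustrative, not part of the original module): any set of
# JSON-serializable kwargs maps deterministically to one cache file path, so
# the same arguments always resolve to the same cluster cache. The kwarg
# names below are hypothetical.
#
#     path = get_cache_location_from_kwargs(
#         dataset="nomic-unsupervised", num_clusters=1024
#     )
#     # e.g. ".../data/cluster/3f2a9c..." -- stable across runs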


def process_qrels_uncached(
    corpus: datasets.Dataset, qrels: datasets.Dataset
) -> Tuple[Dict[str, List[int]], Dict[str, List[float]]]:
    qrels_idxs = collections.defaultdict(list)
    qrels_scores = collections.defaultdict(list)
    corpus_ids = np.array(corpus['_id'])
    skipped_qrels = 0

    for ex in tqdm.tqdm(qrels, desc='processing qrels', colour='#964B00', leave=False):
        # Map each (query-id, corpus-id) pair to the row index of the
        # corresponding corpus document.
        q_id = str(ex['query-id'])
        c_idxs = (corpus_ids == str(ex['corpus-id'])).nonzero()[0]

        assert len(c_idxs) <= 1, f"error - duplicate corpus ID? (found {len(c_idxs)} matches)"

        if len(c_idxs):
            qrels_idxs[q_id].append(c_idxs[0])
            qrels_scores[q_id].append(ex['score'])
        else:
            skipped_qrels += 1

    if skipped_qrels > 0:
        logging.warning(f'Skipped {skipped_qrels}/{len(qrels)} qrels with no matching corpus document.')

    return qrels_idxs, qrels_scores


def process_qrels(
    corpus: datasets.Dataset, qrels: datasets.Dataset,
    use_cache: bool = True
) -> Tuple[Dict[str, List[int]], Dict[str, List[float]]]:
    dataset_cache_file = '_'.join(
        (corpus.cache_files[0]['filename'], qrels.cache_files[0]['filename'])
    )
    cache_file = strip_extension(dataset_cache_file) + '_processed_qrels.p'
    os.makedirs(os.path.dirname(cache_file), exist_ok=True)

    if not (use_cache and os.path.exists(cache_file)):
        qrels_idxs, qrels_scores = process_qrels_uncached(
            corpus=corpus, qrels=qrels
        )
        if use_cache:
            with open(cache_file, 'wb') as f:
                pickle.dump((qrels_idxs, qrels_scores), f)
    else:
        with open(cache_file, 'rb') as f:
            qrels_idxs, qrels_scores = pickle.load(f)

    return qrels_idxs, qrels_scores
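

# Usage sketch (assumes a BEIR-style dataset with `_id`, `query-id`,
# `corpus-id`, and `score` columns; the dataset names are illustrative):
#
#     corpus = datasets.load_dataset("BeIR/scifact", "corpus")["corpus"]
#     qrels = datasets.load_dataset("BeIR/scifact-qrels")["train"]
#     qrels_idxs, qrels_scores = process_qrels(corpus, qrels)
#     # qrels_idxs["4983"]   -> corpus row indices relevant to query "4983"
#     # qrels_scores["4983"] -> the matching relevance scores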


def strip_extension(filename: str) -> str:
    """Strips the file extension.

    Ex:
        >>> strip_extension('/root/dir/sub/file.ext')
        '/root/dir/sub/file'
    """
    return os.path.splitext(filename)[0]


def md5_hash(t: Tuple[str, ...]) -> str:
    return hashlib.md5('__'.join(t).encode()).hexdigest()


def md5_hash_kwargs(**kwargs) -> str:
    # Skip private kwargs (leading underscore) so they don't affect the hash.
    safe_kwargs = {k: str(v) for k, v in kwargs.items() if not k.startswith('_')}
    s = json.dumps(safe_kwargs, sort_keys=True)
    return hashlib.md5(s.encode()).hexdigest()
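

# A quick sanity check of the hashing helper (values are illustrative):
# kwarg order doesn't matter, because keys are JSON-serialized with
# sort_keys=True, and underscore-prefixed kwargs are ignored entirely.
#
#     assert md5_hash_kwargs(a=1, b=2) == md5_hash_kwargs(b=2, a=1)
#     assert md5_hash_kwargs(a=1, _tmp=3) == md5_hash_kwargs(a=1)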


def download_url(url: str, save_path: str, chunk_size: int = 1024):
    """Downloads a URL with a tqdm progress bar.

    https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads

    Args:
        url (str): downloadable url
        save_path (str): local path to save the downloaded file
        chunk_size (int, optional): chunking of files. Defaults to 1024.
    """
    r = requests.get(url, stream=True)
    total = int(r.headers.get('Content-Length', 0))
    with open(save_path, 'wb') as fd, tqdm.tqdm(
        desc=save_path,
        total=total,
        unit='iB',
        unit_scale=True,
        unit_divisor=chunk_size,
    ) as bar:
        for data in r.iter_content(chunk_size=chunk_size):
            size = fd.write(data)
            bar.update(size)


def unzip(zip_file: str, out_dir: str):
    print("unzipping =>", zip_file)
    with zipfile.ZipFile(zip_file, "r") as zip_:
        zip_.extractall(path=out_dir)


def download_url_and_unzip(url: str, out_dir: str, chunk_size: int = 1024) -> str:
    os.makedirs(out_dir, exist_ok=True)
    dataset = url.split("/")[-1]
    zip_file = os.path.join(out_dir, dataset)

    if not os.path.isfile(zip_file):
        logging.info("Downloading {} ...".format(dataset))
        download_url(url, zip_file, chunk_size)

    if not os.path.isdir(zip_file.replace(".zip", "")):
        logging.info("Unzipping {} ...".format(dataset))
        unzip(zip_file, out_dir)

    return os.path.join(out_dir, dataset.replace(".zip", ""))
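

# Usage sketch (the URL is illustrative; any .zip archive works): downloads
# into `out_dir`, unzips alongside, and returns the extracted directory.
# Both steps are skipped when their outputs already exist, so repeated calls
# are idempotent.
#
#     data_dir = download_url_and_unzip(
#         "https://example.com/datasets/scifact.zip", out_dir="/tmp/beir"
#     )
#     # data_dir == "/tmp/beir/scifact"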


def tqdm_if_main_worker(iterable: Iterable, **kwargs) -> Iterable:
    if get_rank() == 0:
        return tqdm.tqdm(iterable, **kwargs)
    else:
        return iterable


class ContextualModelConfig(transformers.configuration_utils.PretrainedConfig):
    """We create a dummy configuration class that will just set properties
    based on whatever kwargs we pass in.

    When this class is initialized (see experiments.py) we pass in the
    union of all data, model, and training args, all of which should
    get saved to the config json.
    """

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            try:
                json.dumps(value)
                setattr(self, key, value)
            except TypeError:
                # Skip any value that isn't JSON-serializable, since it
                # couldn't be saved to the config json anyway.
                continue
        super().__init__()
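

# Usage sketch (kwarg names are illustrative): non-JSON-serializable values
# like a torch.device are silently dropped, so the resulting config stays
# safely serializable.
#
#     config = ContextualModelConfig(
#         embedder="nomic-ai/nomic-bert-2048",
#         max_seq_length=512,
#         device=torch.device("cuda"),  # dropped: not JSON-serializable
#     )
#     config.save_pretrained("/tmp/my_model")  # standard HF config API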


def independent_crop(
    input_ids: torch.Tensor, pad_token_id: int,
    l1: int = 256, l2: int = 256) -> Tuple[torch.Tensor, torch.Tensor]:
    """Returns two independent crops from input_ids.

    Assumes input_ids has a beginning and end token, like
        [101, ..., 102, 0, 0, 0].

    Args:
        input_ids: tensor of IDs
        pad_token_id: ID of pad tokens in input_ids
        l1: length of span 1, cropped
        l2: length of span 2, cropped
    Returns:
        span1: first crop (of length l1)
        span2: second crop (of length l2)
    """
    # Find the effective sequence length N: the position of the first pad
    # token, or the full length if there is no padding.
    if (input_ids == pad_token_id).sum() == 0:
        N = len(input_ids)
    else:
        N = (input_ids == pad_token_id).int().argmax().item()

    # Clamp each crop length so that both crops fit within the sequence.
    nl1 = min(N // 2, l1)
    nl2 = min(N // 2, l2)

    # Sample a random start offset for each crop (position 0 is reserved
    # for the beginning-of-sequence token).
    s1_start = random.randint(1, N - nl1)
    s2_start = random.randint(1, N - nl2)

    # Re-attach the beginning and end tokens around each cropped span.
    s1_idxs = itertools.chain(
        [0], range(s1_start, s1_start + nl1), [N - 1]
    )
    s1 = input_ids[torch.tensor(list(s1_idxs))]
    s2_idxs = itertools.chain(
        [0], range(s2_start, s2_start + nl2), [N - 1]
    )
    s2 = input_ids[torch.tensor(list(s2_idxs))]
    return (s1, s2)
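

# Usage sketch (a minimal example with a hypothetical tokenized sequence):
# two random crops of the same document, each re-wrapped with the original
# begin/end tokens, e.g. to form positive pairs for contrastive pretraining.
#
#     ids = torch.tensor([101, 5, 6, 7, 8, 9, 102, 0, 0])  # padded with 0s
#     span1, span2 = independent_crop(ids, pad_token_id=0, l1=3, l2=3)
#     # each span looks roughly like [101, <crop of 3 tokens>, 102]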


def load_dataset_tables(
        files: List[str], num_workers: int = 16
    ) -> List[datasets.table.MemoryMappedTable]:
    import concurrent.futures
    from multiprocessing import Pool

    # Override the `num_workers` argument: one worker per file, capped at 32.
    num_workers = min(32, len(files))

    use_threads = True
    if use_threads:
        pool_cls = concurrent.futures.ThreadPoolExecutor
        pool_kwargs = {"max_workers": num_workers}
    else:
        pool_cls = Pool
        pool_kwargs = {"processes": num_workers}

    with pool_cls(**pool_kwargs) as pool:
        if len(files) > 10:
            files = tqdm_if_main_worker(
                files,
                desc=f"Loading {len(files)} files with {num_workers} workers",
                total=len(files),
                colour="#ffbd88",
            )
        result = list(
            pool.map(datasets.table.MemoryMappedTable.from_file, files)
        )
    return result


def datasets_fast_load_from_disk(cache_path: str) -> datasets.Dataset:
    logging.info("fast_load_from_disk called with path: %s", cache_path)
    dataset_info_path = os.path.join(cache_path, "dataset_info.json")
    with open(dataset_info_path, encoding="utf-8") as dataset_info_file:
        dataset_info = datasets.DatasetInfo.from_dict(json.load(dataset_info_file))

    dataset_state_path = os.path.join(cache_path, "state.json")
    with open(dataset_state_path, encoding="utf-8") as state_file:
        state = json.load(state_file)

    files = glob.glob(os.path.join(cache_path, "data-*.arrow"))
    files = sorted(files)
    num_workers = get_num_proc()
    ds_tables = load_dataset_tables(
        files=files,
        num_workers=num_workers
    )
    arrow_table = datasets.table.concat_tables(ds_tables)

    split = state["_split"]
    split = datasets.splits.Split(split) if split is not None else split

    # Reconstruct the dataset directly from the concatenated arrow table.
    return datasets.Dataset(
        arrow_table=arrow_table,
        info=dataset_info,
        split=split,
        fingerprint=state["_fingerprint"],
    )
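

# Usage sketch: intended as a faster alternative to `datasets.load_from_disk`
# for datasets saved with `Dataset.save_to_disk`, memory-mapping the
# data-*.arrow shards across multiple workers. The path is illustrative.
#
#     ds = datasets_fast_load_from_disk("/data/my_corpus")
#     # comparable to datasets.load_from_disk("/data/my_corpus")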


def tokenize_dataset(
    dataset: datasets.Dataset,
    tokenizer: transformers.PreTrainedTokenizer,
    max_length: int,
    text_key: str,
    padding_strategy: str
) -> datasets.Dataset:
    def tokenize_text(ex: Dict) -> Dict:
        tt = tokenizer(
            ex[text_key],
            max_length=max_length,
            truncation=True,
            padding=padding_strategy,
        )
        for k, v in tt.items():
            ex[f"{text_key}_{k}"] = v
        ex["length"] = [len(ids) for ids in ex[f"{text_key}_input_ids"]]
        return ex

    # Hash the vocab (in token-ID order) so the dataset fingerprint changes
    # whenever the tokenizer does.
    vocab = tokenizer.vocab
    vocab_words = tuple(sorted(vocab.keys(), key=lambda word: vocab[word]))
    vocab_hash = md5_hash(vocab_words)

    data_fingerprint = '__'.join((
        dataset._fingerprint, str(vocab_hash), str(max_length),
        text_key, padding_strategy
    ))
    data_fingerprint = md5_hash(data_fingerprint)
    dataset = dataset.map(
        tokenize_text,
        new_fingerprint=data_fingerprint,
        batched=True,
        load_from_cache_file=True,
    )
    return dataset
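

# Usage sketch (model and column names are illustrative): tokenizes a "text"
# column into "text_input_ids" / "text_attention_mask" columns plus a
# "length" column, with a cache fingerprint tied to the tokenizer's vocab.
#
#     tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
#     ds = datasets.Dataset.from_dict({"text": ["hello world", "foo bar"]})
#     ds = tokenize_dataset(
#         ds, tokenizer, max_length=32, text_key="text",
#         padding_strategy="max_length",
#     )
#     # ds[0] now contains text, text_input_ids, text_attention_mask, length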


class TensorRunningAverages:
    _store_sum: Dict[str, torch.Tensor]
    _store_total: Dict[str, torch.Tensor]

    def __init__(self):
        self._store_sum = {}
        self._store_total = {}

    def __iter__(self) -> Iterable[str]:
        return iter(self._store_sum.keys())

    def update(self, key: str, val: Union[int, float, torch.Tensor]) -> None:
        if key not in self._store_sum:
            self.clear(key)
        if isinstance(val, torch.Tensor):
            val = val.item()
        self._store_sum[key] += val
        self._store_total[key] += 1

    def get(self, key: str) -> float:
        # Guard against unseen keys and division by zero.
        total = max(self._store_total.get(key, torch.tensor(0)).item(), 1.0)
        return (self._store_sum.get(key, torch.tensor(0.0)) / float(total)).item()

    def clear(self, key: str) -> None:
        self._store_sum[key] = torch.tensor(0.0, dtype=torch.float32)
        self._store_total[key] = torch.tensor(0, dtype=torch.int32)

    def clear_all(self) -> None:
        for key in self._store_sum:
            self.clear(key)

    def get_and_clear_all(self) -> Dict[str, float]:
        metrics = {}
        for key in self:
            metrics[key] = self.get(key)
            self.clear(key)
        return metrics
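

# Usage sketch: accumulate per-step scalars (ints, floats, or 0-d tensors)
# and read back running means, e.g. for logging averaged losses.
#
#     avgs = TensorRunningAverages()
#     avgs.update("loss", torch.tensor(0.5))
#     avgs.update("loss", 0.3)
#     avgs.get("loss")           # -> 0.4 (running mean)
#     avgs.get_and_clear_all()   # -> {"loss": 0.4}, then counters reset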


def load_embedder_and_tokenizer(name: str) -> Tuple[
    transformers.PreTrainedModel,
    transformers.PreTrainedTokenizer
]:
    if name.startswith("nomic") or (name == "bert-base-uncased"):
        from cde.lib.nomic_bert import NomicBertModel
        if name.endswith("--from-scratch"):
            name = name.replace("--from-scratch", "")
            config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)
            model = NomicBertModel._from_config(config)
        else:
            model = NomicBertModel.from_pretrained(
                name, add_pooling_layer=False
            )
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
    elif name in ["gtr-base", "gtr_base"]:
        model = transformers.AutoModel.from_pretrained(
            "sentence-transformers/gtr-t5-base"
        ).encoder
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            "sentence-transformers/gtr-t5-base"
        )
    elif name == "pile-t5-base-encoder":
        model = transformers.AutoModel.from_pretrained(
            "EleutherAI/pile-t5-base"
        ).encoder
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            "EleutherAI/pile-t5-base"
        )
        tokenizer.pad_token = tokenizer.eos_token
    elif name == "pile-t5-base-decoder":
        model = transformers.AutoModel.from_pretrained(
            "EleutherAI/pile-t5-base"
        ).decoder
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            "EleutherAI/pile-t5-base"
        )
        tokenizer.pad_token = tokenizer.eos_token
    elif name.startswith("gpt2") or name.startswith("meta-llama") or ("Llama" in name):
        model = transformers.AutoModelForCausalLM.from_pretrained(
            name,
            # Use flash attention when a GPU is available.
            attn_implementation="flash_attention_2" if torch.cuda.is_available() else "sdpa",
            low_cpu_mem_usage=True,
        )
        model.padding_side = "right"
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.add_eos_token = True
    elif "Modern" in name:
        print("special loading for ModernBERT!")
        # Disable the compiled reference implementation
        # (reference_compile=False), which can cause problems downstream.
        model = transformers.AutoModel.from_pretrained(name, trust_remote_code=True, reference_compile=False)
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
    else:
        model = transformers.AutoModel.from_pretrained(name, trust_remote_code=True)
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
    return model, tokenizer
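

# Usage sketch (each name below is an example of one branch above):
#
#     model, tokenizer = load_embedder_and_tokenizer("nomic-ai/nomic-bert-2048")
#     model, tokenizer = load_embedder_and_tokenizer("gtr-base")
#     model, tokenizer = load_embedder_and_tokenizer("gpt2")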


def inputs_for_key(inputs: Dict[str, torch.Tensor], key: str) -> Dict[str, torch.Tensor]:
    # Strip the `{key}_` prefix (rather than using str.replace, which would
    # also rewrite later occurrences of the prefix inside a key).
    prefix = key + "_"
    return {k[len(prefix):]: v for k, v in inputs.items() if k.startswith(prefix)}
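

# A quick example: strips the "{key}_" prefix to recover per-field batches.
#
#     batch = {
#         "query_input_ids": torch.tensor([[1, 2]]),
#         "document_input_ids": torch.tensor([[3, 4]]),
#     }
#     inputs_for_key(batch, "query")  # -> {"input_ids": tensor([[1, 2]])}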


def count_cpus() -> int:
    try:
        return len(os.sched_getaffinity(0))
    except AttributeError:
        # os.sched_getaffinity is unavailable on some platforms (e.g. macOS).
        return multiprocessing.cpu_count()


def shuffle_batches(g: torch.Generator, list_of_tensors: List[torch.Tensor]) -> List[int]:
    # Shuffle indices within each batch while preserving batch membership.
    all_indices = []
    for batch_tensor in tqdm_if_main_worker(list_of_tensors, colour="green", desc="Sampler shuffling per-batch"):
        rand_perm = torch.randperm(len(batch_tensor), generator=g)
        batch_list = batch_tensor[rand_perm].tolist()
        all_indices.extend(batch_list)
    return all_indices
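

# Usage sketch: shuffles indices *within* each batch but keeps batch
# membership intact, e.g. to randomize order inside pre-clustered batches.
#
#     g = torch.Generator().manual_seed(42)
#     batches = [torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5])]
#     shuffle_batches(g, batches)  # e.g. [2, 0, 1, 4, 5, 3]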


def exit_if_running_or_finished_wandb(
    project_name: str,
    exp_group: str, exp_name: str
) -> None:
    print("Checking if experiment is already running...")
    import wandb

    api = wandb.Api()
    running_runs = api.runs(
        path=project_name,
        filters={
            "display_name": exp_name,
            "state": {"$regex": "Running|Finished"},
            "config.exp_group": exp_group,
        }
    )
    print("Found", len(running_runs), f"runs with name {exp_name} and group {exp_group} in {project_name}.")

    if len(running_runs) > 0:
        print("Exiting because experiment is already running or completed.")
        sys.exit(0)


HN_FILTER_TOKENIZER_MAP = {
    "nomic": "nomic-ai/nomic-embed-text-v1",
    "stella": "dunzhang/stella_en_400M_v5",
    "sbert": "sentence-transformers/all-MiniLM-L6-v2",
    "sentence_t5": "sentence-transformers/sentence-t5-base",
    "gte": "Alibaba-NLP/gte-large-en-v1.5",
}


def load_hn_filter_tokenizer(tokenizer_name: str) -> Optional[transformers.PreTrainedTokenizer]:
    if tokenizer_name in HN_FILTER_TOKENIZER_MAP:
        return transformers.AutoTokenizer.from_pretrained(HN_FILTER_TOKENIZER_MAP[tokenizer_name])
    else:
        return None