# NOTE(review): removed extraction artifacts ("Spaces:", "Runtime error" x2)
import argparse
from pathlib import Path

import faiss
import numpy as np

from data_processing.corpus_data_wrapper import CorpusDataWrapper
from data_processing.index_wrapper import LAYER_TEMPLATE
# TODO: Get the model from base_dir, use it to load the model's configuration,
# and from that derive the model's special tokens. Add a flag to allow the
# model's special tokens to be ignored. Test what items match 'bert-base-cased'.
def parse_args():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace with a single ``directory`` attribute naming the
        folder that holds the 'embeddings' and 'headContext' data.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d",
        "--directory",
        help="Path to the directory that contains the 'embeddings' and 'headContext' folders",
    )
    return parser.parse_args()
def train_indexes(ce:CorpusDataWrapper, stepsize=100, drop_null=True):
    """Build one inner-product faiss index per layer for embeddings and contexts.

    Parameters:
    ===========
    - ce: Wrapper around HDF5 file for easy access to data
    - stepsize: How many sentences to train with at once
    - drop_null: Don't index the embeddings of special tokens (e.g., [CLS] and
      [SEP]) whose spacy POS are null

    Returns a pair ``(embedding_indexes, context_indexes)`` — parallel lists of
    ``faiss.IndexFlatIP``, one per layer.
    """
    # Account for the input layer too, which for attentions + contexts is all value 0.
    n_layers = ce.n_layers
    embedding_indexes = [faiss.IndexFlatIP(ce.embedding_dim) for _ in range(n_layers)]
    context_indexes = [faiss.IndexFlatIP(ce.embedding_dim) for _ in range(n_layers)]

    # Process the corpus in batches of `stepsize` sentences.
    for start in range(0, len(ce), stepsize):
        batch = ce[start:start + stepsize]
        if drop_null:
            emb_parts = [c.zero_special_embeddings for c in batch]
            ctx_parts = [c.zero_special_contexts for c in batch]
        else:
            emb_parts = [c.embeddings for c in batch]
            ctx_parts = [c.contexts for c in batch]
        # axis=1 joins sentences along the token axis, keeping layers on axis 0
        # (assumed layout: (layer, token, dim) — matches the per-layer add below).
        embeddings = np.concatenate(emb_parts, axis=1)
        contexts = np.concatenate(ctx_parts, axis=1)
        for layer, (emb_idx, ctx_idx) in enumerate(zip(embedding_indexes, context_indexes)):
            emb_idx.add(embeddings[layer])
            ctx_idx.add(contexts[layer])

    return embedding_indexes, context_indexes
def save_indexes(idxs, outdir, base_name=LAYER_TEMPLATE):
    """Save the faiss index into a file for each index in idxs.

    Parameters:
    ===========
    - idxs: Iterable of faiss indexes, one per layer
    - outdir: Directory to write the index files into (created if missing)
    - base_name: Filename template containing one ``{}`` slot for the layer number
    """
    base_dir = Path(outdir)
    # mkdir with exist_ok=True already tolerates a pre-existing directory, so the
    # former exists()-then-mkdir guard was redundant and a TOCTOU race besides.
    base_dir.mkdir(exist_ok=True, parents=True)
    out_name = str(base_dir / base_name)
    for i, idx in enumerate(idxs):
        name = out_name.format(i)
        print(f"Saving to {name}")
        faiss.write_index(idx, name)
def main(basedir):
    """Train per-layer faiss indexes from ``basedir/data.hdf5`` and persist them
    under ``basedir/embedding_faiss`` and ``basedir/context_faiss``."""
    base = Path(basedir)
    corpus = CorpusDataWrapper(base / 'data.hdf5')
    embedding_faiss, context_faiss = train_indexes(corpus)
    save_indexes(embedding_faiss, base / "embedding_faiss")
    save_indexes(context_faiss, base / "context_faiss")
| if __name__ == "__main__": | |
| # Creating the indices for both the context and embeddings | |
| args = parse_args() | |
| main(args.directory) | |