Spaces:
Runtime error
import logging
from functools import lru_cache, partial
from pathlib import Path
from typing import Iterable

import faiss
import numpy as np
from transformers import AutoConfig

from utils.f import memoize
@lru_cache(maxsize=None)
def get_config(model_name):
    """Load the Hugging Face model config for ``model_name``, caching the result.

    ``AutoConfig.from_pretrained`` may read from disk or download from the hub,
    and every ``Indexes`` instance calls this, so the config is memoized per
    model name. The same config object is returned for repeated calls with the
    same name; callers should treat it as read-only.

    Parameters:
    -----------
    model_name: Name/path understood by ``AutoConfig.from_pretrained``
        (must be hashable because of the cache).
    """
    return AutoConfig.from_pretrained(model_name)
# Glob that matches every per-layer faiss index file inside an index folder.
FAISS_LAYER_PATTERN = "layer_*.faiss"
# File name for one layer's index; the layer number is zero-padded to two digits.
LAYER_TEMPLATE = "layer_{:02d}.faiss"
def create_mask(head_size: int, n_heads: int, selected_heads: Iterable[int]):
    """Create a head-selection mask of length ``head_size * n_heads``.

    Each head owns a contiguous run of ``head_size`` entries. Entries of a
    selected head are 1 (we care about that head's contribution); all other
    entries are 0 (that head is zeroed out).

    Parameters:
    -----------
    head_size: Hidden dimension of each head
    n_heads: Number of heads the model has
    selected_heads: 0-based indices of the heads we don't want to zero out;
        each value is passed through ``int()``. Duplicates are harmless.

    Returns:
    --------
    1-D float64 numpy array of shape ``(head_size * n_heads,)``.

    Raises:
    -------
    IndexError: if any selected head is outside ``[0, n_heads)``.
    """
    mask = np.zeros(n_heads)
    for h in selected_heads:
        head = int(h)
        # numpy accepts negative indices and wraps them to a head counted from
        # the end, which would silently select the wrong head; reject them like
        # too-large indices (which numpy already rejects with IndexError).
        if not 0 <= head < n_heads:
            raise IndexError(f"head index {head} out of range for {n_heads} heads")
        mask[head] = 1
    # Expand one flag per head into head_size consecutive entries per head.
    return np.repeat(mask, head_size)
class Indexes:
    """Wrapper around per-layer faiss indices to make searching for a vector simpler.

    Assumes ``folder`` contains index files matching ``pattern`` whose stem ends
    in ``_<layer number>`` (e.g. ``layer_03.faiss``), numbered contiguously from 0.
    The model name is taken from the directory two levels above ``folder`` and
    is used to load the model config (number of heads, hidden size).
    """

    def __init__(self, folder, pattern=FAISS_LAYER_PATTERN):
        self.base_dir = Path(folder)
        self.pattern = pattern

        # Glob once; sort so files load in a deterministic order.
        files = sorted(self.base_dir.glob(pattern))
        if not files:
            raise FileNotFoundError(
                f"No faiss index files matching {pattern!r} found in {self.base_dir}"
            )

        # One file per layer plus one extra (described by the original code as
        # the input), so the layer count is one less than the file count.
        self.n_layers = len(files) - 1
        self.indexes = [None] * (self.n_layers + 1)
        self.__init_indexes(files)

        # The model name is the directory two levels above `folder`. Use `.name`
        # rather than `.stem`: `.stem` would drop everything after a dot in the
        # directory name (e.g. "gpt-neo-1.3B" -> "gpt-neo-1").
        self.model_name = self.base_dir.parent.parent.name
        self.config = get_config(self.model_name)
        self.nheads = self.config.num_attention_heads
        self.hidden_size = self.config.hidden_size
        # Raise rather than assert so the check survives `python -O`.
        if self.hidden_size % self.nheads != 0:
            raise ValueError(
                "Number of heads does not divide cleanly into the hidden size. Aborting"
            )
        self.head_size = self.hidden_size // self.nheads

    def __getitem__(self, v):
        """Return the faiss index for layer ``v`` (intended for integer layer numbers)."""
        return self.indexes[v]

    def __init_indexes(self, files=None):
        """Load each index file into ``self.indexes`` at the slot given by its layer number.

        ``files`` defaults to a fresh glob of ``self.base_dir`` with ``self.pattern``.
        A layer number outside ``range(len(self.indexes))`` raises IndexError.
        """
        if files is None:
            files = sorted(self.base_dir.glob(self.pattern))
        log = logging.getLogger(__name__)
        for fname in files:
            log.debug("Loading faiss index %s", fname)
            layer = int(fname.stem.split('_')[-1])
            self.indexes[layer] = faiss.read_index(str(fname))

    def search(self, layer, query, k):
        """Search ``layer``'s index for the ``k`` nearest neighbours of ``query``.

        Returns whatever the underlying faiss index's ``search`` returns
        (for standard faiss indices, a ``(distances, ids)`` pair).
        """
        return self[layer].search(query, k)
class ContextIndexes(Indexes):
    """Index wrapper that masks out unselected attention heads before searching."""

    def __init__(self, folder, pattern=FAISS_LAYER_PATTERN):
        super().__init__(folder, pattern)
        # Callable: iterable of head indices -> 0/1 mask of length head_size * nheads.
        self.head_mask = partial(create_mask, self.head_size, self.nheads)

    # Int -> [Int] -> np.Array -> Int -> (np.Array(), )
    def search(self, layer: int, heads: list, query: np.ndarray, k: int):
        """Search the context-layer embeddings, keeping only the selected heads.

        Parameters:
        -----------
        layer: Layer whose index to search
        heads: 0-based indices of the heads to keep; all others are zeroed
        query: Query vector(s). Either a single query with exactly
            ``head_size * nheads`` elements (any shape), or a batch whose last
            axis has length ``head_size * nheads`` (e.g. ``(n, hidden)``).
        k: Number of neighbours to return

        Returns whatever the underlying faiss index's ``search`` returns.

        Raises:
        -------
        ValueError: if ``heads`` is empty or contains an index outside ``[0, nheads)``.
        """
        # Materialize once so a generator isn't exhausted by the checks below.
        unique_heads = set(heads)
        # Raise rather than assert so validation survives `python -O`; an empty
        # selection already raised ValueError (from max()) before this change.
        if not unique_heads:
            raise ValueError("At least one head must be selected")
        if max(unique_heads) >= self.nheads:
            raise ValueError(
                "max of selected heads must be less than nheads. Are you indexing by 1 instead of 0?"
            )
        if min(unique_heads) < 0:
            raise ValueError("What is a negative head?")

        mask_vector = self.head_mask(list(unique_heads))
        if query.shape[-1] == mask_vector.size:
            # Broadcast the mask across any leading (batch) dimensions.
            masked = query * mask_vector
        else:
            # Fall back to the original behaviour: the mask has as many elements
            # as the query, in some other layout.
            masked = query * mask_vector.reshape(query.shape)
        # The faiss Python API expects float32 queries.
        new_query = masked.astype(np.float32)
        return self[layer].search(new_query, k)