Spaces:
Runtime error
Runtime error
| import faiss | |
| import numpy as np | |
| import torch | |
| from typing import Dict, Tuple, List, NamedTuple | |
| import os | |
| import pickle | |
| import yaml | |
| from transformers import AutoModelForCausalLM | |
class WeightInfo(NamedTuple):
    """
    Immutable record identifying one indexed weight matrix.

    Attributes:
        model_name: Name or identifier of the model the matrix came from
        param_name: Key of the parameter in that model's state dict
        dimensions: Shape of the weight matrix as (d1, d2)
    """

    # Which model the matrix belongs to.
    model_name: str
    # Which state-dict entry of that model it is.
    param_name: str
    # (rows, columns) of the original, un-flattened matrix.
    dimensions: Tuple[int, int]
class CSWSearch:
    """
    CSWSearch (Cosine Similarity of Weights Search) using FAISS for efficient similarity search.

    Each weight matrix is flattened in row-major order, L2-normalised and added
    to an inner-product FAISS index, so the returned scores are cosine
    similarities. One index is kept per matrix shape, which guarantees that only
    matrices of identical dimensions are ever compared. Every index lives in its
    own file under ``index_dir``; only the most recently used index is held in
    memory.

    Invariant: for every shape, row ``i`` of the FAISS index is described by
    ``metadata[shape][i]``.
    """

    def __init__(self):
        # Keep track of what each index position corresponds to
        self.metadata: Dict[Tuple[int, int], List[WeightInfo]] = {}
        # Track dimensions and index file locations (file names, relative to index_dir)
        self.index_files: Dict[Tuple[int, int], str] = {}
        # Directory where indices are stored
        self.index_dir: str = "indexes"
        # Currently loaded index as (dim_key, index); None until one is first loaded
        self.current_index: Tuple[Tuple[int, int], faiss.Index] = None

    @staticmethod
    def _flatten_weights(weight_matrix) -> np.ndarray:
        """
        Flatten a weight matrix into a fresh (1, d1*d2) float32 row vector.

        Args:
            weight_matrix: A torch tensor or anything numpy can convert to an array

        Returns:
            A C-contiguous float32 array of shape (1, N) that owns its memory
        """
        if isinstance(weight_matrix, torch.Tensor):
            weight_matrix = weight_matrix.detach().cpu().to(dtype=torch.float32).numpy()
        # np.array copies: the caller normalises the result in place and that
        # must never write through to the model's own weights.
        return np.array(weight_matrix, dtype=np.float32, order="C").reshape(1, -1)

    def add_weight_matrix(
        self, model_name: str, param_name: str, weight_matrix: "torch.Tensor | np.ndarray"
    ) -> None:
        """
        Add a weight matrix to the appropriate index based on its dimensions.

        Args:
            model_name: Name or identifier of the model
            param_name: Name of the parameter in the model's state dict
            weight_matrix: The 2D weight matrix (torch tensor or numpy array) to index

        Returns:
            None

        Raises:
            ValueError: If weight_matrix is not two-dimensional
        """
        print(f"Adding {model_name} {param_name}")
        shape = tuple(int(d) for d in weight_matrix.shape)
        if len(shape) != 2:
            raise ValueError(f"Expected a 2D weight matrix, got shape {shape}")
        d1, d2 = shape
        dim_key = (d1, d2)
        # First time seeing this dimension combination
        if dim_key not in self.index_files:
            self.metadata[dim_key] = []
            self.index_files[dim_key] = f"index_{d1}x{d2}.index"
        # Load the appropriate index
        index = self._load_index(dim_key)
        # Flatten matrix in row-major order and normalize
        flat_weights = self._flatten_weights(weight_matrix)
        faiss.normalize_L2(flat_weights)  # for cosine similarity
        # Add to appropriate index
        index.add(flat_weights)
        # Store metadata at the same position the vector got in the index
        self.metadata.setdefault(dim_key, []).append(
            WeightInfo(model_name, param_name, dim_key)
        )
        # Save the updated index.
        # NOTE: this rewrites the whole index file on every add; it is kept
        # because only one index is held in memory and unsaved vectors would be
        # lost as soon as a different shape is loaded.
        self._save_index(dim_key, index)

    def find_similar_weights(
        self, model_name: str, weight_matrix: "torch.Tensor | np.ndarray", k: int = 5
    ) -> "List[Tuple[WeightInfo, float]]":
        """
        Find similar weight matrices with matching dimensions.

        Searches for weight matrices most similar to the provided one,
        but only among those with the same dimensions. Entries whose
        model_name equals the given one are excluded from the result.

        Args:
            model_name: Name or identifier of the model (used to exclude self-matches)
            weight_matrix: The 2D weight matrix (torch tensor or numpy array) to search for
            k: Number of similar matrices to return (default: 5)

        Returns:
            List of at most k tuples containing (WeightInfo, similarity_score),
            best match first

        Raises:
            ValueError: If weight_matrix is not two-dimensional, or if no weight
                matrices with matching dimensions are found
        """
        shape = tuple(int(d) for d in weight_matrix.shape)
        if len(shape) != 2:
            raise ValueError(f"Expected a 2D weight matrix, got shape {shape}")
        dim_key = shape
        if dim_key not in self.index_files:
            raise ValueError(f"No weight matrices found with dimensions {dim_key}")
        # Load the appropriate index
        index = self._load_index(dim_key)
        total = index.ntotal
        if k <= 0 or total == 0:
            return []
        # Prepare query in same way as stored matrices
        query = self._flatten_weights(weight_matrix)
        faiss.normalize_L2(query)
        entries = self.metadata.get(dim_key, [])
        # Start with k + 1 neighbours (one spare for a self-match) and widen the
        # search while the model filter leaves fewer than k hits, because one
        # model can contribute many matrices of the same shape.
        fetch = min(k + 1, total)
        while True:
            distances, indices = index.search(query, fetch)
            results = []
            for idx, sim in zip(indices[0], distances[0]):
                # FAISS pads missing neighbours with label -1; a negative list
                # index would silently return the wrong metadata entry.
                if idx < 0 or idx >= len(entries):
                    continue
                info = entries[idx]
                if info.model_name != model_name:  # Skip self-match
                    results.append((info, float(sim)))
            if len(results) >= k or fetch >= total:
                return results[:k]
            fetch = min(fetch * 2, total)

    def _load_index(self, dim_key: Tuple[int, int]) -> "faiss.Index":
        """
        Load or create the FAISS index for a specific dimension.

        An on-disk index is only reused when its vector size and vector count
        agree with the metadata held for dim_key. Whenever a fresh index has to
        be created, the metadata list for dim_key is reset so that index rows
        and metadata entries stay aligned.

        Args:
            dim_key: Tuple of dimensions (d1, d2)

        Returns:
            faiss.Index: The loaded or newly created index
        """
        if self.current_index is not None and self.current_index[0] == dim_key:
            return self.current_index[1]
        d1, d2 = dim_key
        flat_dim = d1 * d2
        index_path = os.path.join(self.index_dir, self.index_files[dim_key])
        expected = len(self.metadata.get(dim_key, []))
        index = None
        if os.path.exists(index_path):
            try:
                index = faiss.read_index(index_path)
            except RuntimeError:
                print(f"Error reading index file {index_path}. Creating a new index.")
            else:
                # A file left over from another run (or another metadata set)
                # would map row ids to the wrong WeightInfo entries.
                if index.d != flat_dim or index.ntotal != expected:
                    print(
                        f"Index file {index_path} does not match the stored metadata. "
                        "Creating a new index."
                    )
                    index = None
        else:
            print(f"Index file {index_path} not found. Creating a new index.")
        if index is None:
            index = faiss.IndexFlatIP(flat_dim)
            if expected:
                print(f"Discarding {expected} metadata entries without vectors for {dim_key}.")
            self.metadata[dim_key] = []
        self.current_index = (dim_key, index)
        return index

    def _save_index(self, dim_key: Tuple[int, int], index: "faiss.Index"):
        """
        Save the index for the given dimensions to disk.

        Args:
            dim_key: Tuple of dimensions (d1, d2)
            index: The FAISS index to save

        Returns:
            None
        """
        # faiss.write_index does not create missing directories.
        os.makedirs(self.index_dir, exist_ok=True)
        index_path = os.path.join(self.index_dir, self.index_files[dim_key])
        faiss.write_index(index, index_path)

    def save(self, directory: str):
        """
        Save the entire search system (metadata and indexes) to disk.

        If directory differs from the current index directory, the index files
        already written there are copied across, so the saved system is complete.

        Args:
            directory: Directory where indices and metadata will be stored

        Returns:
            None
        """
        import shutil  # only needed here; kept local so the file's import block is untouched

        previous_dir = self.index_dir
        os.makedirs(directory, exist_ok=True)
        self.index_dir = directory
        current_key = None if self.current_index is None else self.current_index[0]
        if os.path.abspath(previous_dir) != os.path.abspath(directory):
            for dim_key, filename in self.index_files.items():
                if dim_key == current_key:
                    continue  # written from memory below
                source_path = os.path.join(previous_dir, filename)
                if os.path.exists(source_path):
                    shutil.copy2(source_path, os.path.join(directory, filename))
        if self.current_index is not None:
            self._save_index(self.current_index[0], self.current_index[1])
        metadata_path = os.path.join(directory, "metadata.pkl")
        with open(metadata_path, "wb") as f:
            pickle.dump(self.metadata, f)
        index_files_path = os.path.join(directory, "index_files.pkl")
        with open(index_files_path, "wb") as f:
            pickle.dump(self.index_files, f)

    @classmethod
    def load(cls, directory: str) -> "CSWSearch":
        """
        Load a previously saved search system from disk.

        Only metadata and the index file table are read here; the FAISS indexes
        themselves are loaded lazily on first use.

        Args:
            directory: Directory where indices and metadata are stored

        Returns:
            CSWSearch: The loaded search system
        """
        csw_search = cls()
        csw_search.index_dir = directory
        # SECURITY: pickle.load can execute arbitrary code. Only load
        # directories that were written by save() and come from a trusted source.
        metadata_path = os.path.join(directory, "metadata.pkl")
        with open(metadata_path, "rb") as f:
            csw_search.metadata = pickle.load(f)
        index_files_path = os.path.join(directory, "index_files.pkl")
        with open(index_files_path, "rb") as f:
            csw_search.index_files = pickle.load(f)
        return csw_search
# Module-level search instance shared by add_params(), get_similar_param() and main().
csw = CSWSearch()
def add_params(model_list):
    """
    Index weight matrices from a list of HuggingFace model IDs.

    Loads each model in bfloat16, walks its state dict and adds every
    two-dimensional tensor to the global CSWSearch index for later similarity
    search. Tensors of any other rank are skipped.

    Args:
        model_list: List of HuggingFace model IDs to index; None or an empty
            list indexes nothing

    Returns:
        None: Updates the global csw search index
    """
    for model_id in model_list or []:
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
        for param_name, tensor in model.state_dict().items():
            # The index only handles matrices: skip 1D tensors (like bias terms
            # or layer norms) as well as scalars and higher-rank tensors.
            if len(tensor.shape) != 2:
                continue
            csw.add_weight_matrix(model_id, param_name=param_name, weight_matrix=tensor)
        # Drop the reference before the next from_pretrained call so two full
        # models are not held in memory at once.
        del model
def get_similar_param(param, k=5):
    """
    Look up the indexed weight matrices most similar to the given one.

    Delegates to the global csw search instance.

    Args:
        param: Weight matrix tensor to search for
        k: Number of similar matrices to return (default: 5)

    Returns:
        List of tuples containing (WeightInfo, similarity_score)
    """
    # The search drops hits whose model name equals the one passed in.
    # NOTE(review): "--" presumably matches no indexed model, so nothing is
    # filtered out; confirm against the model IDs in use.
    placeholder_model = "--"
    matches = csw.find_similar_weights(placeholder_model, param, k=k)
    return matches
def main():
    """
    Build the index from the configured model list and run one example query.

    Reads the model IDs from config/llama7b.yaml, indexes their weight
    matrices, saves the search system to "indexes", then prints the matrices
    most similar to the first layer's attention output projection of
    meta-llama/Llama-2-7b-hf.
    """
    # Model list to add from yaml
    with open("config/llama7b.yaml", "r") as config_file:
        model_list = yaml.safe_load(config_file)
    add_params(model_list)
    csw.save("indexes")
    # Weight matrix to search for
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16
    )
    weights = model.state_dict()
    attn_name = "model.layers.0.self_attn.o_proj.weight"
    print(get_similar_param(weights[attn_name]))


if __name__ == "__main__":
    main()