import os

# Limit OpenMP threading before faiss is imported; this prevents thread
# oversubscription between faiss and its BLAS backend.
os.environ['OMP_NUM_THREADS'] = '1'
import faiss
from sentence_transformers import SentenceTransformer
import numpy as np
import pickle
import json
from tqdm import tqdm


MODEL_DATA_DIR = "model_data_json"
INDEX_FILE = "index.faiss"
MAP_FILE = "index_to_metadata.pkl"
EMBEDDING_MODEL = 'all-mpnet-base-v2'
ENCODE_BATCH_SIZE = 32

COMMON_EXCLUDED_TAGS = {'transformers'}
EXCLUDED_TAG_PREFIXES = ('arxiv:', 'base_model:', 'dataset:', 'diffusers:', 'license:')
MODEL_EXPLANATION_KEY = "model_explanation_gemini"

def load_model_data(directory):
    """Loads model data, filters tags (by length, common words, prefixes), and combines relevant info for indexing."""
    all_texts = []
    all_metadata = []
    print(f"Loading model data from JSON files in: {directory}")
    if not os.path.isdir(directory):
        print(f"Error: Directory not found: {directory}")
        return [], []

    filenames = [f for f in os.listdir(directory) if f.endswith(".json")]
    for filename in tqdm(filenames, desc="Reading JSON files"):
        filepath = os.path.join(directory, filename)
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)

            if 'description' in data and 'model_id' in data:
                description = data['description']
                model_id = data['model_id']
                if description:
                    original_tags = data.get('tags', [])

                    # Keep only informative tags: strings longer than three
                    # characters that are neither excluded common words nor
                    # namespaced tags such as 'arxiv:...' or 'license:...'.
                    filtered_tags = [
                        tag for tag in original_tags
                        if (
                            isinstance(tag, str)
                            and len(tag) > 3
                            and tag.lower() not in COMMON_EXCLUDED_TAGS
                            and not tag.lower().startswith(EXCLUDED_TAG_PREFIXES)
                        )
                    ]
                    tag_string = " ".join(filtered_tags)
                    explanation = data.get(MODEL_EXPLANATION_KEY)

                    release_year = data.get('release_year')
                    parameter_count = data.get('parameter_count')
                    is_fine_tuned = data.get('is_fine_tuned', False)
                    category = data.get('category', 'Other')
                    model_family = data.get('model_family')
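
                    # Assemble one searchable text per model; the "Summary:",
                    # "Tags:", etc. prefixes keep the different kinds of
                    # metadata distinguishable in the embedded string.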
                    text_parts = []

                    if explanation and isinstance(explanation, str):
                        text_parts.append(f"Summary: {explanation}")

                    text_parts.append(f"Model: {model_id}")

                    if tag_string:
                        text_parts.append(f"Tags: {tag_string}")
                    if category:
                        text_parts.append(f"Category: {category}")
                    if model_family:
                        text_parts.append(f"Family: {model_family}")
                    if parameter_count:
                        text_parts.append(f"Parameters: {parameter_count}")
                    if release_year:
                        text_parts.append(f"Year: {release_year}")
                    if is_fine_tuned:
                        text_parts.append("Fine-tuned model")

                    text_parts.append(f"Description: {description}")

                    combined_text = " ".join(text_parts).strip()

                    all_texts.append(combined_text)

                    metadata_entry = {
                        "model_id": model_id,
                        "tags": original_tags,
                        "downloads": data.get('downloads', 0)
                    }
                    if explanation and isinstance(explanation, str):
                        metadata_entry[MODEL_EXPLANATION_KEY] = explanation
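
                    # Optional fields are copied into the metadata only when
                    # present, keeping the pickled mapping compact.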
                    if release_year:
                        metadata_entry["release_year"] = release_year
                    if parameter_count:
                        metadata_entry["parameter_count"] = parameter_count
                    # data.get('is_fine_tuned', False) never returns None, so
                    # the flag is always recorded.
                    metadata_entry["is_fine_tuned"] = is_fine_tuned
                    if category:
                        metadata_entry["category"] = category
                    if model_family:
                        metadata_entry["model_family"] = model_family

                    all_metadata.append(metadata_entry)
            else:
                print(f"Warning: Skipping {filename}, missing 'description' or 'model_id' key.")
        except json.JSONDecodeError:
            print(f"Warning: Skipping {filename}, invalid JSON.")
        except Exception as e:
            print(f"Warning: Could not read or process {filename}: {e}")

    print(f"Loaded data for {len(all_texts)} models with valid descriptions after tag filtering.")
    return all_texts, all_metadata

def build_and_save_index(texts_to_index, metadata_list):
    """Builds and saves the FAISS index and metadata mapping based on combined text."""
    if not texts_to_index:
        print("No text data to index.")
        return

    print(f"Loading sentence transformer model: {EMBEDDING_MODEL}")
    model = SentenceTransformer(EMBEDDING_MODEL)
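    # all-mpnet-base-v2 produces 768-dimensional embeddings; the dimension is
    # read from the computed array below rather than hard-coded.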

    print(f"Generating embeddings for combined text in batches of {ENCODE_BATCH_SIZE}...")
    all_embeddings = []
    for i in tqdm(range(0, len(texts_to_index), ENCODE_BATCH_SIZE), desc="Encoding batches"):
        batch = texts_to_index[i:i + ENCODE_BATCH_SIZE]
        batch_embeddings = model.encode(batch, convert_to_numpy=True)
        all_embeddings.append(batch_embeddings)

    if not all_embeddings:
        print("No embeddings generated. Cannot build index.")
        return

    embeddings = np.vstack(all_embeddings)

    # FAISS expects float32 vectors.
    embeddings = embeddings.astype('float32')

    print("Building FAISS index...")
    dimension = embeddings.shape[1]
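    # IndexFlatL2 does exact (brute-force) L2 search, which is fine at this
    # scale; for much larger corpora an approximate index such as
    # faiss.IndexIVFFlat would be a common substitute.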
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings)
    print(f"FAISS index built with {index.ntotal} vectors.")

    faiss.write_index(index, INDEX_FILE)
    print(f"FAISS index saved to: {INDEX_FILE}")
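
    # FAISS search returns integer row ids; persist a row-id -> metadata
    # mapping so hits can be resolved back to model records.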
    index_to_metadata = {i: metadata for i, metadata in enumerate(metadata_list)}
    with open(MAP_FILE, 'wb') as f:
        pickle.dump(index_to_metadata, f)
    print(f"Index-to-Metadata mapping saved to: {MAP_FILE}")


if __name__ == "__main__":
    combined_texts, metadata_list = load_model_data(MODEL_DATA_DIR)
    build_and_save_index(combined_texts, metadata_list)
    print("\nIndex building complete.")