import os
import json
import traceback

import numpy as np
from FlagEmbedding import BGEM3FlagModel

# Determine the directory of the script to load files relative to it.
script_dir = os.path.dirname(os.path.abspath(__file__))

# --- Configuration ---
# IMPORTANT: Adjust MODEL_PATH to your model's actual local path.
MODEL_PATH = '../../../../Downloads/bge-m3'
# Path to the input JSON file for GA resolutions.
GA_RESOLUTIONS_JSON_PATH = os.path.join(script_dir, '..', '..', 'parsed_ga_resolutions.json')
# Output directory for the generated files.
OUTPUT_DIR = os.path.join(script_dir, '..', '..')

# --- Output and Cache File Paths ---
DENSE_OUTPUT_PATH = os.path.join(OUTPUT_DIR, 'ns_ga_resolutions_semantic_bge-m3.npy')
SPARSE_OUTPUT_PATH = os.path.join(OUTPUT_DIR, 'ns_ga_resolutions_loose_bge-m3.npy')
# Manifest maps resolution id -> row index in the saved embedding arrays.
MANIFEST_PATH = os.path.join(script_dir, 'embeddings_manifest.json')


def _load_resolutions():
    """Load the source resolutions JSON and drop entries without a usable body.

    Returns:
        The filtered list of resolution dicts, or None if the file could not
        be read or parsed (fatal for the whole run).
    """
    print(f"Loading all GA resolutions from: {GA_RESOLUTIONS_JSON_PATH}")
    try:
        with open(GA_RESOLUTIONS_JSON_PATH, 'r', encoding='utf-8') as file:
            data = json.load(file)
    except (OSError, json.JSONDecodeError) as e:
        print(f"Fatal Error: Could not load or parse the source resolutions file. Cannot proceed. Error: {e}")
        return None
    # Keep only entries with an id and a non-empty *string* body; a non-string
    # body (e.g. JSON null) would otherwise crash on .strip().
    return [
        r for r in data
        if 'id' in r and isinstance(r.get('body'), str) and r['body'].strip()
    ]


def _load_cache():
    """Load the cached manifest and both embedding arrays, if all three exist.

    Returns:
        (manifest, dense, sparse). On a missing or corrupt cache the manifest
        is {} and both arrays are None, so the caller re-embeds everything.
        All three are reset together: keeping an array that happened to load
        before a later failure would desynchronize it from the manifest.
    """
    if not (os.path.exists(MANIFEST_PATH)
            and os.path.exists(DENSE_OUTPUT_PATH)
            and os.path.exists(SPARSE_OUTPUT_PATH)):
        print("No existing cache found. Will generate embeddings for all resolutions.")
        return {}, None, None

    print("Found existing cache. Loading manifest and embeddings.")
    try:
        with open(MANIFEST_PATH, 'r', encoding='utf-8') as f:
            manifest = json.load(f)
        # json.dump stringifies integer keys; convert them back to ints so
        # they compare equal to the ids found in the resolutions file.
        manifest = {int(k): v for k, v in manifest.items()}
        dense = np.load(DENSE_OUTPUT_PATH)
        # Sparse lexical weights are per-document dicts stored as an object array.
        sparse = np.load(SPARSE_OUTPUT_PATH, allow_pickle=True)
        print(f"Successfully loaded cache for {len(manifest)} resolutions.")
        return manifest, dense, sparse
    except Exception as e:
        print(f"Warning: Could not load cache files correctly: {e}. Re-embedding all resolutions.")
        return {}, None, None


def encode_ga_resolutions_with_caching():
    """Incrementally embed GA resolutions with BGE-M3, reusing cached results.

    Loads all resolutions, determines which ids are not yet in the manifest,
    encodes only those (dense + sparse), appends them to the cached arrays,
    and writes the combined arrays plus an updated id->row-index manifest.
    Errors at each stage are reported and abort the run without writing.
    """
    # 1. --- Load the source of truth: all resolutions ---
    all_resolutions_data = _load_resolutions()
    if all_resolutions_data is None:
        return

    # 2. --- Load existing cache (manifest and embeddings) ---
    cached_manifest, old_dense_embeddings, old_sparse_embeddings = _load_cache()

    # 3. --- Identify new resolutions to be encoded ---
    all_res_ids = {r['id'] for r in all_resolutions_data}
    new_res_ids = all_res_ids - set(cached_manifest.keys())
    if not new_res_ids:
        print("All resolutions are already embedded. Nothing to do. Exiting.")
        return
    print(f"Found {len(new_res_ids)} new resolutions to embed.")
    # Sort by id to ensure a consistent, reproducible append order.
    resolutions_to_encode = sorted(
        (r for r in all_resolutions_data if r['id'] in new_res_ids),
        key=lambda r: r['id'],
    )
    new_texts = [r['body'] for r in resolutions_to_encode]

    # 4. --- Initialize model and encode ONLY the new data ---
    print("Initializing BGEM3FlagModel...")
    try:
        model = BGEM3FlagModel(MODEL_PATH, use_fp16=True)
        print("Model loaded.")
    except Exception as e:
        print(f"Error loading model from {MODEL_PATH}: {e}")
        return

    print(f"Encoding {len(new_texts)} new resolutions (dense, sparse)...")
    try:
        new_embeddings = model.encode(
            new_texts,
            batch_size=8,
            max_length=8192,
            return_dense=True,
            return_sparse=True,
            return_colbert_vecs=False,
        )
    except Exception as e:
        print(f"An error occurred during embedding generation: {e}")
        traceback.print_exc()
        return

    # 5. --- Combine old and new embeddings ---
    new_dense_vecs = new_embeddings['dense_vecs']
    new_sparse_vecs = np.array(new_embeddings['lexical_weights'], dtype=object)
    if old_dense_embeddings is not None and old_sparse_embeddings is not None:
        print("Combining new embeddings with cached ones...")
        combined_dense_embeddings = np.vstack([old_dense_embeddings, new_dense_vecs])
        combined_sparse_embeddings = np.concatenate([old_sparse_embeddings, new_sparse_vecs])
    else:
        # First run: no cache exists, the new embeddings are the whole set.
        combined_dense_embeddings = new_dense_vecs
        combined_sparse_embeddings = new_sparse_vecs

    # 6. --- Update manifest and save everything ---
    # New rows are appended after the cached ones, so the first new row index
    # is the current cache size.
    print("Updating manifest file...")
    start_index = len(cached_manifest)
    updated_manifest = dict(cached_manifest)
    for i, res in enumerate(resolutions_to_encode):
        updated_manifest[res['id']] = start_index + i

    try:
        # Ensure output directory exists.
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        # Save combined embeddings.
        np.save(DENSE_OUTPUT_PATH, combined_dense_embeddings)
        print(f"Saved combined semantic embeddings to {DENSE_OUTPUT_PATH} (Shape: {combined_dense_embeddings.shape})")
        np.save(SPARSE_OUTPUT_PATH, combined_sparse_embeddings, allow_pickle=True)
        print(
            f"Saved combined loose embeddings to {SPARSE_OUTPUT_PATH} (Total objects: {len(combined_sparse_embeddings)})")
        # Save the updated manifest.
        with open(MANIFEST_PATH, 'w', encoding='utf-8') as f:
            json.dump(updated_manifest, f, indent=2)
        print(f"Saved updated manifest to {MANIFEST_PATH}")
        print("\nGA Resolution embedding process complete!")
    except Exception as e:
        print(f"An error occurred while saving the files: {e}")


# Call the function to start the embedding process.
if __name__ == "__main__":
    encode_ga_resolutions_with_caching()