# Hugging Face Space status header ("Spaces: Running") — scrape artifact, not part of the source.
import os
import json
import numpy as np
from FlagEmbedding import BGEM3FlagModel

# Determine the directory of the script so data files can be located
# relative to it, independent of the current working directory.
script_dir = os.path.dirname(os.path.abspath(__file__))

# --- Configuration ---
# IMPORTANT: Adjust MODEL_PATH to your model's actual local path.
# NOTE(review): this relative path climbs out of the repo into a user
# Downloads folder and is resolved against the CWD, not script_dir —
# confirm it resolves from the intended working directory.
MODEL_PATH = '../../../../Downloads/bge-m3'
# Path to the input JSON file for GA resolutions (two levels above the script).
GA_RESOLUTIONS_JSON_PATH = os.path.join(script_dir, '..', '..', 'parsed_ga_resolutions.json')
# Output directory for the generated files (two levels above the script).
OUTPUT_DIR = os.path.join(script_dir, '..', '..')

# --- Output and Cache File Paths ---
# Dense (semantic) embedding matrix, one row per resolution.
DENSE_OUTPUT_PATH = os.path.join(OUTPUT_DIR, 'ns_ga_resolutions_semantic_bge-m3.npy')
# Sparse lexical-weight objects, saved as a pickled object array.
SPARSE_OUTPUT_PATH = os.path.join(OUTPUT_DIR, 'ns_ga_resolutions_loose_bge-m3.npy')
# Manifest mapping resolution id -> row index into the arrays above.
MANIFEST_PATH = os.path.join(script_dir, 'embeddings_manifest.json')  # New manifest file
| # --- Main Embedding Function --- | |
def encode_ga_resolutions_with_caching():
    """Incrementally embed GA resolutions with BGE-M3, caching results on disk.

    Loads the parsed resolutions JSON, determines which resolutions are not
    yet present in the on-disk cache (manifest + dense/sparse ``.npy`` files),
    embeds only those, appends them to the cached arrays, and rewrites the
    manifest mapping resolution id -> row index.

    Side effects: writes DENSE_OUTPUT_PATH, SPARSE_OUTPUT_PATH and
    MANIFEST_PATH. Returns None; failures are reported to stdout and abort
    the run without raising.
    """
    # 1. --- Load the source of truth: all resolutions ---
    print(f"Loading all GA resolutions from: {GA_RESOLUTIONS_JSON_PATH}")
    try:
        with open(GA_RESOLUTIONS_JSON_PATH, 'r', encoding='utf-8') as file:
            all_resolutions_data = json.load(file)
    except (OSError, json.JSONDecodeError) as e:
        # Narrowed from the original (FileNotFoundError, JSONDecodeError,
        # Exception) tuple — the trailing Exception swallowed everything.
        print(f"Fatal Error: Could not load or parse the source resolutions file. Cannot proceed. Error: {e}")
        return
    # Keep only records with an id and a non-empty *string* body; a
    # non-string body would otherwise crash on .strip().
    all_resolutions_data = [
        r for r in all_resolutions_data
        if 'id' in r and isinstance(r.get('body'), str) and r['body'].strip()
    ]

    # 2. --- Load existing cache (manifest and embeddings) ---
    cached_manifest = {}
    old_dense_embeddings = None
    old_sparse_embeddings = None
    if os.path.exists(MANIFEST_PATH) and os.path.exists(DENSE_OUTPUT_PATH) and os.path.exists(SPARSE_OUTPUT_PATH):
        print("Found existing cache. Loading manifest and embeddings.")
        try:
            with open(MANIFEST_PATH, 'r', encoding='utf-8') as f:
                cached_manifest = json.load(f)
            # JSON object keys are always strings; convert back to ints.
            # NOTE(review): assumes resolution ids are integers — confirm
            # against parsed_ga_resolutions.json.
            cached_manifest = {int(k): v for k, v in cached_manifest.items()}
            old_dense_embeddings = np.load(DENSE_OUTPUT_PATH)
            old_sparse_embeddings = np.load(SPARSE_OUTPUT_PATH, allow_pickle=True)
            # The append logic below relies on manifest indices covering the
            # contiguous range 0..len-1 of cached rows; a row-count mismatch
            # means the cache is inconsistent and must be rebuilt.
            if (len(old_dense_embeddings) != len(cached_manifest)
                    or len(old_sparse_embeddings) != len(cached_manifest)):
                raise ValueError("cache row counts do not match manifest size")
            print(f"Successfully loaded cache for {len(cached_manifest)} resolutions.")
        except Exception as e:
            print(f"Warning: Could not load cache files correctly: {e}. Re-embedding all resolutions.")
            # Reset ALL cache state. The original only reset the manifest;
            # a partially loaded dense array could then be combined against
            # an empty manifest and corrupt the id -> row mapping.
            cached_manifest = {}
            old_dense_embeddings = None
            old_sparse_embeddings = None
    else:
        print("No existing cache found. Will generate embeddings for all resolutions.")

    # 3. --- Identify new resolutions to be encoded ---
    all_res_ids = {r['id'] for r in all_resolutions_data}
    cached_res_ids = set(cached_manifest.keys())
    new_res_ids = all_res_ids - cached_res_ids
    if not new_res_ids:
        print("All resolutions are already embedded. Nothing to do. Exiting.")
        return
    print(f"Found {len(new_res_ids)} new resolutions to embed.")
    resolutions_to_encode = [r for r in all_resolutions_data if r['id'] in new_res_ids]
    # Sort by id so row order (and thus manifest indices) is deterministic.
    resolutions_to_encode.sort(key=lambda x: x['id'])
    new_texts = [r['body'] for r in resolutions_to_encode]

    # 4. --- Initialize model and encode ONLY the new data ---
    print("Initializing BGEM3FlagModel...")
    try:
        model = BGEM3FlagModel(MODEL_PATH, use_fp16=True)
        print("Model loaded.")
    except Exception as e:
        print(f"Error loading model from {MODEL_PATH}: {e}")
        return
    print(f"Encoding {len(new_texts)} new resolutions (dense, sparse)...")
    try:
        new_embeddings = model.encode(new_texts,
                                      batch_size=8,
                                      max_length=8192,
                                      return_dense=True,
                                      return_sparse=True,
                                      return_colbert_vecs=False)
    except Exception as e:
        print(f"An error occurred during embedding generation: {e}")
        import traceback
        traceback.print_exc()
        return

    # 5. --- Combine old and new embeddings ---
    new_dense_vecs = new_embeddings['dense_vecs']
    # lexical_weights is a per-text list of token-weight mappings; keep them
    # in an object array so they can be concatenated and np.save'd.
    new_sparse_vecs = np.array(new_embeddings['lexical_weights'], dtype=object)
    if old_dense_embeddings is not None and old_sparse_embeddings is not None:
        print("Combining new embeddings with cached ones...")
        combined_dense_embeddings = np.vstack([old_dense_embeddings, new_dense_vecs])
        combined_sparse_embeddings = np.concatenate([old_sparse_embeddings, new_sparse_vecs])
    else:
        # First run, or a rebuilt cache: the new embeddings are everything.
        combined_dense_embeddings = new_dense_vecs
        combined_sparse_embeddings = new_sparse_vecs

    # 6. --- Update manifest and save everything ---
    print("Updating manifest file...")
    # New rows were appended after the cached ones, so their indices continue
    # from len(cached_manifest) (contiguity validated in step 2).
    start_index = len(cached_manifest)
    updated_manifest = cached_manifest.copy()
    for i, res in enumerate(resolutions_to_encode):
        updated_manifest[res['id']] = start_index + i
    try:
        # Ensure output directory exists
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        # Save combined embeddings
        np.save(DENSE_OUTPUT_PATH, combined_dense_embeddings)
        print(f"Saved combined semantic embeddings to {DENSE_OUTPUT_PATH} (Shape: {combined_dense_embeddings.shape})")
        np.save(SPARSE_OUTPUT_PATH, combined_sparse_embeddings, allow_pickle=True)
        print(
            f"Saved combined loose embeddings to {SPARSE_OUTPUT_PATH} (Total objects: {len(combined_sparse_embeddings)})")
        # json.dump serializes int keys as strings; the loader above converts
        # them back with int(k).
        with open(MANIFEST_PATH, 'w', encoding='utf-8') as f:
            json.dump(updated_manifest, f, indent=2)
        print(f"Saved updated manifest to {MANIFEST_PATH}")
        print("\nGA Resolution embedding process complete!")
    except Exception as e:
        print(f"An error occurred while saving the files: {e}")
# Run the embedding pipeline only when executed as a script (not on import).
if __name__ == "__main__":
    encode_ga_resolutions_with_caching()