Spaces:
Running
Running
File size: 6,354 Bytes
3a373f3 7392937 3a373f3 7392937 3a373f3 7392937 3a373f3 7392937 3a373f3 7392937 3a373f3 7392937 3a373f3 7392937 3a373f3 7392937 3a373f3 7392937 3a373f3 7392937 3a373f3 7392937 3a373f3 7392937 3a373f3 7392937 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import os
import json
import numpy as np
from FlagEmbedding import BGEM3FlagModel
# Directory containing this script; every project-relative path hangs off it.
script_dir = os.path.split(os.path.abspath(__file__))[0]

# --- Configuration ---
# IMPORTANT: Adjust MODEL_PATH to your model's actual local path.
# NOTE: unlike the other paths below, this one is not anchored to script_dir.
MODEL_PATH = '../../../../Downloads/bge-m3'

# Output directory for the generated files (two levels above this script).
OUTPUT_DIR = os.path.join(script_dir, '..', '..')

# Path to the input JSON file for GA resolutions (lives in OUTPUT_DIR).
GA_RESOLUTIONS_JSON_PATH = os.path.join(OUTPUT_DIR, 'parsed_ga_resolutions.json')

# --- Output and Cache File Paths ---
# Dense ("semantic") and sparse ("loose") embedding arrays.
DENSE_OUTPUT_PATH = os.path.join(OUTPUT_DIR, 'ns_ga_resolutions_semantic_bge-m3.npy')
SPARSE_OUTPUT_PATH = os.path.join(OUTPUT_DIR, 'ns_ga_resolutions_loose_bge-m3.npy')
# Manifest mapping resolution ID -> row index; stored next to this script.
MANIFEST_PATH = os.path.join(script_dir, 'embeddings_manifest.json')
# --- Main Embedding Function ---
def _load_resolutions(path):
    """Read the resolutions JSON at *path* and index the usable records by ID.

    A record is usable when it is a JSON object with an ``id`` key and a
    ``body`` that is a non-blank string. If an ID occurs more than once the
    last record wins, so each ID is embedded only once.

    Raises:
        OSError: the file cannot be opened or read.
        ValueError: the content is not valid UTF-8/JSON, or the top level is
            not a list.
        TypeError: a record's ``id`` is unhashable.
    """
    with open(path, 'r', encoding='utf-8') as file:
        records = json.load(file)
    if not isinstance(records, list):
        raise ValueError(
            f"expected a JSON list of resolutions, got {type(records).__name__}"
        )
    return {
        record['id']: record
        for record in records
        if isinstance(record, dict)
        and 'id' in record
        and isinstance(record.get('body'), str)
        and record['body'].strip()
    }


def _load_cache(source_ids):
    """Load the manifest and both embedding arrays if a consistent cache exists.

    Args:
        source_ids: iterable of the resolution IDs present in the source file.
            Used to map the manifest's JSON keys (always strings) back to the
            IDs' original type.

    Returns:
        ``(manifest, dense, sparse)``. ``manifest`` maps resolution ID to row
        index. When there is no cache, or it is unreadable or inconsistent,
        returns ``({}, None, None)`` so the caller re-embeds everything.
    """
    cache_paths = (MANIFEST_PATH, DENSE_OUTPUT_PATH, SPARSE_OUTPUT_PATH)
    if not all(os.path.exists(path) for path in cache_paths):
        print("No existing cache found. Will generate embeddings for all resolutions.")
        return {}, None, None

    print("Found existing cache. Loading manifest and embeddings.")
    try:
        with open(MANIFEST_PATH, 'r', encoding='utf-8') as f:
            raw_manifest = json.load(f)
        if not isinstance(raw_manifest, dict) or not raw_manifest:
            raise ValueError("manifest is empty or not a JSON object")

        dense = np.load(DENSE_OUTPUT_PATH)
        # NOTE(security): allow_pickle=True unpickles arbitrary objects. That
        # is required for the object array written by this script, but this
        # file must never come from an untrusted source.
        sparse = np.load(SPARSE_OUTPUT_PATH, allow_pickle=True)

        # The three files are written separately, so verify they still agree
        # before trusting any row index.
        row_count = dense.shape[0]
        if sparse.shape[0] != row_count:
            raise ValueError(
                f"dense ({row_count}) and sparse ({sparse.shape[0]}) row counts differ"
            )
        if not all(
            isinstance(row, int) and 0 <= row < row_count
            for row in raw_manifest.values()
        ):
            raise ValueError("manifest refers to rows outside the cached arrays")
    except Exception as e:
        # Deliberately broad: any failure here only costs a full re-embed.
        print(f"Warning: Could not load cache files correctly: {e}. Re-embedding all resolutions.")
        return {}, None, None

    # JSON object keys are always strings; restore each key to the type used
    # by the matching source ID. Keys with no match in the source stay as-is.
    key_to_id = {str(res_id): res_id for res_id in source_ids}
    manifest = {key_to_id.get(key, key): row for key, row in raw_manifest.items()}
    print(f"Successfully loaded cache for {len(manifest)} resolutions.")
    return manifest, dense, sparse


def encode_ga_resolutions_with_caching():
    """Embed GA resolutions that are not yet cached and append them to the cache.

    Reads the resolutions from ``GA_RESOLUTIONS_JSON_PATH``, loads any existing
    cache (manifest plus dense and sparse arrays), encodes only the resolutions
    whose IDs are missing from the manifest, appends the new vectors to the
    cached arrays, and rewrites all three files.

    Returns:
        None. Progress and errors are reported with ``print``; on any failure
        the function returns early instead of raising.
    """
    # 1. --- Load the source of truth: all resolutions ---
    print(f"Loading all GA resolutions from: {GA_RESOLUTIONS_JSON_PATH}")
    try:
        resolutions_by_id = _load_resolutions(GA_RESOLUTIONS_JSON_PATH)
    except (OSError, ValueError, TypeError) as e:
        print(f"Fatal Error: Could not load or parse the source resolutions file. Cannot proceed. Error: {e}")
        return

    # 2. --- Load existing cache (manifest and embeddings) ---
    cached_manifest, old_dense_embeddings, old_sparse_embeddings = _load_cache(resolutions_by_id)

    # 3. --- Identify new resolutions to be encoded ---
    new_res_ids = resolutions_by_id.keys() - cached_manifest.keys()
    if not new_res_ids:
        print("All resolutions are already embedded. Nothing to do. Exiting.")
        return
    print(f"Found {len(new_res_ids)} new resolutions to embed.")

    # Sort by ID so that row order is reproducible between runs.
    ids_to_encode = sorted(new_res_ids)
    new_texts = [resolutions_by_id[res_id]['body'] for res_id in ids_to_encode]

    # 4. --- Initialize model and encode ONLY the new data ---
    print("Initializing BGEM3FlagModel...")
    try:
        model = BGEM3FlagModel(MODEL_PATH, use_fp16=True)
        print("Model loaded.")
    except Exception as e:
        # Third-party loader: the exception types it raises are not known here.
        print(f"Error loading model from {MODEL_PATH}: {e}")
        return

    print(f"Encoding {len(new_texts)} new resolutions (dense, sparse)...")
    try:
        new_embeddings = model.encode(
            new_texts,
            batch_size=8,
            max_length=8192,
            return_dense=True,
            return_sparse=True,
            return_colbert_vecs=False,
        )
    except Exception as e:
        print(f"An error occurred during embedding generation: {e}")
        import traceback
        traceback.print_exc()
        return

    # 5. --- Combine old and new embeddings ---
    new_dense_vecs = new_embeddings['dense_vecs']
    # Stored as an object array with one entry per text; presumably each entry
    # is a token -> weight mapping -- confirm against the FlagEmbedding docs.
    new_sparse_vecs = np.array(new_embeddings['lexical_weights'], dtype=object)

    if old_dense_embeddings is not None and old_sparse_embeddings is not None:
        print("Combining new embeddings with cached ones...")
        # New rows are appended after the cached rows, so their indices start
        # at the cached ROW count. The manifest length can differ from it.
        start_index = old_dense_embeddings.shape[0]
        try:
            combined_dense_embeddings = np.vstack([old_dense_embeddings, new_dense_vecs])
            combined_sparse_embeddings = np.concatenate([old_sparse_embeddings, new_sparse_vecs])
        except ValueError as e:
            print(f"Error: cached embeddings are incompatible with the new ones: {e}. "
                  "Delete the cache files to rebuild from scratch.")
            return
    else:
        # First run, or the cache was rejected: the new vectors are everything.
        start_index = 0
        combined_dense_embeddings = new_dense_vecs
        combined_sparse_embeddings = new_sparse_vecs

    # 6. --- Update manifest and save everything ---
    print("Updating manifest file...")
    updated_manifest = dict(cached_manifest)
    for row, res_id in enumerate(ids_to_encode, start=start_index):
        updated_manifest[res_id] = row

    try:
        # Ensure output directory exists
        os.makedirs(OUTPUT_DIR, exist_ok=True)

        # Arrays are written before the manifest: if a write fails part-way,
        # the old manifest still points at valid rows in the longer arrays.
        np.save(DENSE_OUTPUT_PATH, combined_dense_embeddings)
        print(f"Saved combined semantic embeddings to {DENSE_OUTPUT_PATH} (Shape: {combined_dense_embeddings.shape})")

        np.save(SPARSE_OUTPUT_PATH, combined_sparse_embeddings, allow_pickle=True)
        print(
            f"Saved combined loose embeddings to {SPARSE_OUTPUT_PATH} (Total objects: {len(combined_sparse_embeddings)})")

        # Save the updated manifest
        with open(MANIFEST_PATH, 'w', encoding='utf-8') as f:
            json.dump(updated_manifest, f, indent=2)
        print(f"Saved updated manifest to {MANIFEST_PATH}")

        print("\nGA Resolution embedding process complete!")
    except Exception as e:
        print(f"An error occurred while saving the files: {e}")
# Run the embedding pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    encode_ga_resolutions_with_caching()