# ns_issue_search/small_scripts/make_embedding/embedding_ga_resolutions.py
# Author: Bohaska — "update GA resolution scripts to use API" (commit 7392937)
import os
import json
import numpy as np
from FlagEmbedding import BGEM3FlagModel
# Determine the directory of the script so data files can be located
# relative to it regardless of the current working directory.
script_dir = os.path.dirname(os.path.abspath(__file__))

# --- Configuration ---
# Local path to the BGE-M3 model. Overridable via the BGE_M3_MODEL_PATH
# environment variable; defaults to the original hard-coded location for
# backward compatibility.
MODEL_PATH = os.environ.get('BGE_M3_MODEL_PATH', '../../../../Downloads/bge-m3')
# Path to the input JSON file with the parsed GA resolutions.
GA_RESOLUTIONS_JSON_PATH = os.path.join(script_dir, '..', '..', 'parsed_ga_resolutions.json')
# Output directory for the generated files.
OUTPUT_DIR = os.path.join(script_dir, '..', '..')

# --- Output and Cache File Paths ---
# Dense (semantic) embeddings: one row per resolution.
DENSE_OUTPUT_PATH = os.path.join(OUTPUT_DIR, 'ns_ga_resolutions_semantic_bge-m3.npy')
# Sparse (lexical-weight) embeddings: object array, one entry per resolution.
SPARSE_OUTPUT_PATH = os.path.join(OUTPUT_DIR, 'ns_ga_resolutions_loose_bge-m3.npy')
# Manifest mapping resolution id -> row index into the arrays above.
MANIFEST_PATH = os.path.join(script_dir, 'embeddings_manifest.json')
# --- Main Embedding Function ---
def _load_source_resolutions():
    """Load the parsed GA resolutions and drop entries without a usable body.

    Returns:
        The filtered list of resolution dicts, or ``None`` if the source
        file cannot be read or parsed (caller should abort).
    """
    print(f"Loading all GA resolutions from: {GA_RESOLUTIONS_JSON_PATH}")
    try:
        with open(GA_RESOLUTIONS_JSON_PATH, 'r', encoding='utf-8') as file:
            data = json.load(file)
        # Keep only resolutions that have an id and a non-empty body.
        return [r for r in data if 'id' in r and 'body' in r and r['body'].strip()]
    except Exception as e:
        print(f"Fatal Error: Could not load or parse the source resolutions file. Cannot proceed. Error: {e}")
        return None


def _load_cache():
    """Load the cached manifest and embedding arrays, if all three exist.

    Returns:
        A tuple ``(manifest, dense, sparse)``. On any failure — including a
        size mismatch between the manifest and the arrays — the cache is
        treated as absent and ``({}, None, None)`` is returned, which makes
        the caller re-embed everything.
    """
    if not (os.path.exists(MANIFEST_PATH)
            and os.path.exists(DENSE_OUTPUT_PATH)
            and os.path.exists(SPARSE_OUTPUT_PATH)):
        print("No existing cache found. Will generate embeddings for all resolutions.")
        return {}, None, None
    print("Found existing cache. Loading manifest and embeddings.")
    try:
        with open(MANIFEST_PATH, 'r', encoding='utf-8') as f:
            manifest = json.load(f)
        # JSON object keys are always strings; manifest keys are resolution
        # ids — presumably integers (TODO confirm against the parser output).
        manifest = {int(k): v for k, v in manifest.items()}
        dense = np.load(DENSE_OUTPUT_PATH)
        sparse = np.load(SPARSE_OUTPUT_PATH, allow_pickle=True)
        # Sanity check: the appending logic below assumes manifest size ==
        # array row count; a mismatch would silently mis-index every cached
        # embedding, so treat it as corruption.
        if len(manifest) != len(dense) or len(manifest) != len(sparse):
            raise ValueError(
                f"manifest has {len(manifest)} entries but arrays have "
                f"{len(dense)} / {len(sparse)} rows")
        print(f"Successfully loaded cache for {len(manifest)} resolutions.")
        return manifest, dense, sparse
    except Exception as e:
        # BUGFIX: previously only the manifest was reset here, so a partially
        # loaded cache (e.g. dense array loaded, sparse load failed) could be
        # vstacked with freshly computed embeddings for ALL resolutions,
        # corrupting the output files. Discard everything together.
        print(f"Warning: Could not load cache files correctly: {e}. Re-embedding all resolutions.")
        return {}, None, None


def encode_ga_resolutions_with_caching():
    """Incrementally embed GA resolutions with BGE-M3, with on-disk caching.

    Loads the parsed resolutions, determines which ids are missing from the
    cache, embeds only those, appends the new rows to the cached dense and
    sparse arrays, and rewrites the manifest (id -> row index). Errors are
    reported on stdout; the function returns ``None`` in all cases.
    """
    # 1. --- Load the source of truth: all resolutions ---
    all_resolutions_data = _load_source_resolutions()
    if all_resolutions_data is None:
        return

    # 2. --- Load existing cache (manifest and embeddings) ---
    cached_manifest, old_dense_embeddings, old_sparse_embeddings = _load_cache()

    # 3. --- Identify new resolutions to be encoded ---
    all_res_ids = {r['id'] for r in all_resolutions_data}
    new_res_ids = all_res_ids - set(cached_manifest.keys())
    if not new_res_ids:
        print("All resolutions are already embedded. Nothing to do. Exiting.")
        return
    print(f"Found {len(new_res_ids)} new resolutions to embed.")
    resolutions_to_encode = [r for r in all_resolutions_data if r['id'] in new_res_ids]
    # Sort by id so row order (and thus the manifest indices) is deterministic.
    resolutions_to_encode.sort(key=lambda x: x['id'])
    new_texts = [r['body'] for r in resolutions_to_encode]

    # 4. --- Initialize model and encode ONLY the new data ---
    print("Initializing BGEM3FlagModel...")
    try:
        model = BGEM3FlagModel(MODEL_PATH, use_fp16=True)
        print("Model loaded.")
    except Exception as e:
        print(f"Error loading model from {MODEL_PATH}: {e}")
        return

    print(f"Encoding {len(new_texts)} new resolutions (dense, sparse)...")
    try:
        new_embeddings = model.encode(new_texts,
                                      batch_size=8,
                                      max_length=8192,
                                      return_dense=True,
                                      return_sparse=True,
                                      return_colbert_vecs=False)
    except Exception as e:
        print(f"An error occurred during embedding generation: {e}")
        import traceback
        traceback.print_exc()
        return

    # 5. --- Combine old and new embeddings ---
    new_dense_vecs = new_embeddings['dense_vecs']
    # lexical_weights is a per-document list of token-weight dicts; store it
    # as an object array so it can round-trip through np.save/np.load.
    new_sparse_vecs = np.array(new_embeddings['lexical_weights'], dtype=object)
    if old_dense_embeddings is not None and old_sparse_embeddings is not None:
        print("Combining new embeddings with cached ones...")
        combined_dense_embeddings = np.vstack([old_dense_embeddings, new_dense_vecs])
        combined_sparse_embeddings = np.concatenate([old_sparse_embeddings, new_sparse_vecs])
    else:
        # First run, or a discarded corrupt cache: the new rows are everything.
        combined_dense_embeddings = new_dense_vecs
        combined_sparse_embeddings = new_sparse_vecs

    # 6. --- Update manifest and save everything ---
    print("Updating manifest file...")
    # New rows were appended after the cached ones, so their indices continue
    # from the current manifest size.
    start_index = len(cached_manifest)
    updated_manifest = cached_manifest.copy()
    for i, res in enumerate(resolutions_to_encode):
        updated_manifest[res['id']] = start_index + i
    try:
        # Ensure output directory exists.
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        # Save combined embeddings.
        np.save(DENSE_OUTPUT_PATH, combined_dense_embeddings)
        print(f"Saved combined semantic embeddings to {DENSE_OUTPUT_PATH} (Shape: {combined_dense_embeddings.shape})")
        np.save(SPARSE_OUTPUT_PATH, combined_sparse_embeddings, allow_pickle=True)
        print(
            f"Saved combined loose embeddings to {SPARSE_OUTPUT_PATH} (Total objects: {len(combined_sparse_embeddings)})")
        # Save the updated manifest (json.dump writes the int keys as strings;
        # _load_cache converts them back on the next run).
        with open(MANIFEST_PATH, 'w', encoding='utf-8') as f:
            json.dump(updated_manifest, f, indent=2)
        print(f"Saved updated manifest to {MANIFEST_PATH}")
        print("\nGA Resolution embedding process complete!")
    except Exception as e:
        print(f"An error occurred while saving the files: {e}")
# Script entry point: run the incremental embedding pipeline only when this
# file is executed directly, not when it is imported as a module.
if __name__ == "__main__":
    encode_ga_resolutions_with_caching()