File size: 6,354 Bytes
3a373f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7392937
3a373f3
 
7392937
 
 
 
 
 
3a373f3
7392937
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a373f3
 
 
 
 
 
 
 
7392937
3a373f3
7392937
 
 
 
 
 
3a373f3
7392937
 
 
3a373f3
 
7392937
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a373f3
 
 
 
 
7392937
 
 
3a373f3
7392937
 
 
3a373f3
7392937
 
 
 
3a373f3
7392937
3a373f3
 
7392937
 
3a373f3
 
 
7392937
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os
import json
import numpy as np
from FlagEmbedding import BGEM3FlagModel

# Directory containing this script; the data paths below are built relative
# to it rather than to the current working directory.
script_dir = os.path.dirname(os.path.abspath(__file__))

# --- Configuration ---
# Location of the BGE-M3 model handed to BGEM3FlagModel.
# Set the BGE_M3_MODEL_PATH environment variable to point at your local copy
# instead of editing this file; when it is unset (or empty) the previous
# hard-coded location is used.
# NOTE(review): the fallback is not joined to script_dir, so it is presumably
# resolved against the current working directory -- confirm.
MODEL_PATH = os.environ.get('BGE_M3_MODEL_PATH') or '../../../../Downloads/bge-m3'

# Path to the input JSON file for GA resolutions.
GA_RESOLUTIONS_JSON_PATH = os.path.join(script_dir, '..', '..', 'parsed_ga_resolutions.json')

# Output directory for the generated embedding files (two levels above this script).
OUTPUT_DIR = os.path.join(script_dir, '..', '..')

# --- Output and Cache File Paths ---
# Dense ("semantic") and sparse ("loose") embedding arrays, one row per resolution.
DENSE_OUTPUT_PATH = os.path.join(OUTPUT_DIR, 'ns_ga_resolutions_semantic_bge-m3.npy')
SPARSE_OUTPUT_PATH = os.path.join(OUTPUT_DIR, 'ns_ga_resolutions_loose_bge-m3.npy')
# JSON manifest mapping each resolution id to its row index in the two arrays
# above. Note that it is stored next to this script, not in OUTPUT_DIR.
MANIFEST_PATH = os.path.join(script_dir, 'embeddings_manifest.json')


# --- Main Embedding Function ---
def _load_source_resolutions(path):
    """Read the resolutions JSON at *path* and keep only the embeddable entries.

    An entry is kept when it is a dict with an ``id`` and a non-blank string
    ``body``. If an id occurs more than once only the first entry is kept, so
    that every id corresponds to exactly one embedding row.

    Returns:
        The filtered list of resolution dicts, or ``None`` when the file
        cannot be read or parsed (an error message is printed in that case).
    """
    try:
        with open(path, 'r', encoding='utf-8') as file:
            raw_entries = json.load(file)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Fatal Error: Could not load or parse the source resolutions file. Cannot proceed. Error: {e}")
        return None

    if not isinstance(raw_entries, list):
        print("Fatal Error: Could not load or parse the source resolutions file. Cannot proceed. "
              f"Error: expected a JSON array, got {type(raw_entries).__name__}")
        return None

    seen_keys = set()
    resolutions = []
    duplicates = 0
    for entry in raw_entries:
        if not isinstance(entry, dict) or 'id' not in entry:
            continue
        body = entry.get('body')
        # A null or non-string body cannot be embedded; skip it instead of crashing.
        if not isinstance(body, str) or not body.strip():
            continue
        key = str(entry['id'])
        if key in seen_keys:
            duplicates += 1
            continue
        seen_keys.add(key)
        resolutions.append(entry)

    if duplicates:
        print(f"Warning: Skipped {duplicates} resolutions with duplicate ids.")
    return resolutions


def _load_cache(manifest_path, dense_path, sparse_path):
    """Load the manifest and both embedding arrays, validating that they agree.

    Manifest keys are kept as strings, exactly as JSON stores them, so ids
    compare consistently whether the source uses integers or strings.

    Returns:
        ``(manifest, dense, sparse)``. When any file is missing, unreadable
        or inconsistent with the others, ``({}, None, None)`` is returned so
        the caller re-embeds everything.
    """
    no_cache = ({}, None, None)

    if not (os.path.exists(manifest_path)
            and os.path.exists(dense_path)
            and os.path.exists(sparse_path)):
        print("No existing cache found. Will generate embeddings for all resolutions.")
        return no_cache

    print("Found existing cache. Loading manifest and embeddings.")
    try:
        with open(manifest_path, 'r', encoding='utf-8') as f:
            raw_manifest = json.load(f)
        if not isinstance(raw_manifest, dict):
            raise ValueError("manifest is not a JSON object")
        manifest = {str(k): int(v) for k, v in raw_manifest.items()}

        dense = np.load(dense_path)
        # SECURITY NOTE: allow_pickle=True executes pickled code on load. This
        # is only acceptable because the file is produced by this script; do
        # not point SPARSE_OUTPUT_PATH at untrusted data.
        sparse = np.load(sparse_path, allow_pickle=True)

        if dense.ndim != 2 or sparse.ndim != 1 or dense.shape[0] != sparse.shape[0]:
            raise ValueError("dense and sparse embedding arrays are misaligned")
        row_count = dense.shape[0]
        if any(not 0 <= row < row_count for row in manifest.values()):
            raise ValueError("manifest refers to rows missing from the embedding arrays")
    except Exception as e:
        # Best effort: any problem with the cache means we rebuild from scratch.
        print(f"Warning: Could not load cache files correctly: {e}. Re-embedding all resolutions.")
        return no_cache

    print(f"Successfully loaded cache for {len(manifest)} resolutions.")
    return manifest, dense, sparse


def _commit_cache(dense_path, sparse_path, manifest_path, dense, sparse, manifest):
    """Write all three cache files to temporaries, then swap them into place.

    Staging first means a failure while serialising leaves the previous cache
    untouched instead of a mix of old and new files.
    """
    staged = [
        (dense_path + '.tmp', dense_path),
        (sparse_path + '.tmp', sparse_path),
        (manifest_path + '.tmp', manifest_path),
    ]
    try:
        # Passing an open file object stops np.save from appending ".npy"
        # to the temporary name.
        with open(staged[0][0], 'wb') as f:
            np.save(f, dense)
        with open(staged[1][0], 'wb') as f:
            np.save(f, sparse, allow_pickle=True)
        with open(staged[2][0], 'w', encoding='utf-8') as f:
            json.dump(manifest, f, indent=2)

        for tmp_path, final_path in staged:
            os.replace(tmp_path, final_path)
    finally:
        # Remove any temporaries left behind by a failure (no-op on success).
        for tmp_path, _ in staged:
            try:
                os.remove(tmp_path)
            except OSError:
                pass


def encode_ga_resolutions_with_caching():
    """Embed GA resolutions with BGE-M3, re-using previously saved embeddings.

    Steps:
      1. Load the resolutions from ``GA_RESOLUTIONS_JSON_PATH``.
      2. Load the cached manifest (id -> row index) and embedding arrays.
      3. Encode only the resolutions whose id is not in the manifest.
      4. Append the new rows to the cached arrays and save the arrays and the
         updated manifest.

    Errors are reported with ``print`` and end the run early; nothing is
    raised to the caller and the function always returns ``None``.
    """
    # 1. --- Load the source of truth: all resolutions ---
    print(f"Loading all GA resolutions from: {GA_RESOLUTIONS_JSON_PATH}")
    all_resolutions_data = _load_source_resolutions(GA_RESOLUTIONS_JSON_PATH)
    if all_resolutions_data is None:
        return

    # 2. --- Load existing cache (manifest and embeddings) ---
    cached_manifest, old_dense_embeddings, old_sparse_embeddings = _load_cache(
        MANIFEST_PATH, DENSE_OUTPUT_PATH, SPARSE_OUTPUT_PATH)

    # 3. --- Identify new resolutions to be encoded ---
    # Ids are compared as strings because that is how the JSON manifest stores them.
    resolutions_to_encode = [
        r for r in all_resolutions_data if str(r['id']) not in cached_manifest
    ]
    if not resolutions_to_encode:
        print("All resolutions are already embedded. Nothing to do. Exiting.")
        return

    print(f"Found {len(resolutions_to_encode)} new resolutions to embed.")
    # Sort by ID to ensure a consistent order; fall back to the string form
    # when ids of different types cannot be compared with each other.
    try:
        resolutions_to_encode.sort(key=lambda r: r['id'])
    except TypeError:
        resolutions_to_encode.sort(key=lambda r: str(r['id']))

    new_texts = [r['body'] for r in resolutions_to_encode]

    # 4. --- Initialize model and encode ONLY the new data ---
    print("Initializing BGEM3FlagModel...")
    try:
        model = BGEM3FlagModel(MODEL_PATH, use_fp16=True)
        print("Model loaded.")
    except Exception as e:
        print(f"Error loading model from {MODEL_PATH}: {e}")
        return

    print(f"Encoding {len(new_texts)} new resolutions (dense, sparse)...")
    try:
        new_embeddings = model.encode(new_texts,
                                      batch_size=8,
                                      max_length=8192,
                                      return_dense=True,
                                      return_sparse=True,
                                      return_colbert_vecs=False)
    except Exception as e:
        print(f"An error occurred during embedding generation: {e}")
        import traceback
        traceback.print_exc()
        return

    # 5. --- Combine old and new embeddings ---
    new_dense_vecs = np.asarray(new_embeddings['dense_vecs'])
    new_sparse_vecs = np.array(new_embeddings['lexical_weights'], dtype=object)

    if not (len(new_dense_vecs) == len(new_sparse_vecs) == len(resolutions_to_encode)):
        print("An error occurred during embedding generation: "
              "the model returned a different number of embeddings than inputs.")
        return

    if old_dense_embeddings is not None and old_sparse_embeddings is not None:
        print("Combining new embeddings with cached ones...")
        try:
            combined_dense_embeddings = np.vstack([old_dense_embeddings, new_dense_vecs])
            combined_sparse_embeddings = np.concatenate([old_sparse_embeddings, new_sparse_vecs])
        except ValueError as e:
            # Typically a dimension mismatch, e.g. the cache came from another model.
            print(f"An error occurred while combining cached and new embeddings: {e}")
            return
        # New rows start after the last cached row; counting array rows (not
        # manifest entries) keeps indices right even if the two ever differ.
        start_index = old_dense_embeddings.shape[0]
    else:
        # This branch is for the first run when no usable cache exists
        combined_dense_embeddings = new_dense_vecs
        combined_sparse_embeddings = new_sparse_vecs
        start_index = 0

    # 6. --- Update manifest and save everything ---
    print("Updating manifest file...")
    updated_manifest = dict(cached_manifest)
    for offset, res in enumerate(resolutions_to_encode):
        updated_manifest[str(res['id'])] = start_index + offset

    try:
        # Ensure output directory exists
        os.makedirs(OUTPUT_DIR, exist_ok=True)

        _commit_cache(DENSE_OUTPUT_PATH, SPARSE_OUTPUT_PATH, MANIFEST_PATH,
                      combined_dense_embeddings, combined_sparse_embeddings,
                      updated_manifest)

        print(f"Saved combined semantic embeddings to {DENSE_OUTPUT_PATH} (Shape: {combined_dense_embeddings.shape})")
        print(
            f"Saved combined loose embeddings to {SPARSE_OUTPUT_PATH} (Total objects: {len(combined_sparse_embeddings)})")
        print(f"Saved updated manifest to {MANIFEST_PATH}")

        print("\nGA Resolution embedding process complete!")

    except Exception as e:
        print(f"An error occurred while saving the files: {e}")


# Start the embedding process only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    encode_ga_resolutions_with_caching()