File size: 5,709 Bytes
b6e1b94
 
 
 
c07baa6
 
 
 
 
b6e1b94
c07baa6
 
 
 
 
 
 
b6e1b94
 
c07baa6
 
 
 
 
 
 
b6e1b94
c07baa6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6e1b94
 
 
c07baa6
b6e1b94
c07baa6
 
b6e1b94
c07baa6
 
 
 
 
b6e1b94
 
c07baa6
b6e1b94
 
 
c07baa6
b6e1b94
ab97519
 
c07baa6
 
 
ab97519
c07baa6
 
ab97519
 
c07baa6
 
ab97519
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6e1b94
 
c07baa6
 
 
 
b6e1b94
c07baa6
 
 
 
 
 
 
ab97519
c07baa6
 
 
 
 
 
b6e1b94
 
c07baa6
b6e1b94
 
c07baa6
 
 
 
 
b6e1b94
 
 
c07baa6
 
 
b6e1b94
c07baa6
 
 
b6e1b94
 
 
 
c07baa6
 
b6e1b94
c07baa6
 
 
 
 
 
b6e1b94
 
 
c07baa6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import os
import pickle
import faiss
import numpy as np
from typing import List, Optional
import logging
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModel

def load_faiss_index(index_path: str):
    """
    Read a serialized FAISS index from disk and return it.

    Raises:
        FileNotFoundError: if no file exists at ``index_path``.
    """
    if os.path.exists(index_path):
        return faiss.read_index(index_path)
    raise FileNotFoundError(f"FAISS index not found at {index_path}")


def normalize_embeddings(embeddings: np.ndarray) -> np.ndarray:
    """
    Apply row-wise L2 normalization to a 2-D embedding matrix.

    This converts Euclidean distance search into Cosine Similarity search.

    Args:
        embeddings: Array of shape (n, dim).

    Returns:
        Array of the same shape with each row scaled to unit L2 norm.
        All-zero rows are returned unchanged (the original code divided
        by a zero norm here, producing NaN).
    """
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    # Guard against division by zero for all-zero rows.
    norms = np.where(norms == 0, 1.0, norms)
    return embeddings / norms

def save_faiss_index(index: faiss.Index, index_path: str):
    """
    Save the FAISS binary index to disk, creating parent directories as
    needed.

    Args:
        index: The FAISS index object to serialize.
        index_path: Destination file path.

    Raises:
        Re-raises any exception after logging it.
    """
    try:
        # os.path.dirname returns '' for a bare filename, and
        # os.makedirs('') raises FileNotFoundError — only create
        # directories when a parent component actually exists.
        parent = os.path.dirname(index_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        faiss.write_index(index, index_path)
        logging.info(f"Successfully saved FAISS index to {index_path}")
    except Exception as e:
        logging.error(f"Failed to save FAISS index: {e}")
        raise

def save_metadata(metadata: list, meta_path: str):
    """
    Save the document metadata list to disk using pickle, creating parent
    directories as needed.

    Args:
        metadata: List of metadata records to serialize.
        meta_path: Destination file path.

    Raises:
        Re-raises any exception after logging it.
    """
    try:
        # os.path.dirname returns '' for a bare filename, and
        # os.makedirs('') raises — skip directory creation in that case.
        parent = os.path.dirname(meta_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(meta_path, 'wb') as f:
            pickle.dump(metadata, f)
        logging.info(f"Successfully saved metadata to {meta_path}")
    except Exception as e:
        logging.error(f"Failed to save metadata: {e}")
        raise


def add_embeddings_to_index(index_path: str, embeddings: np.ndarray):
    """Append embeddings to a FAISS index file, creating it if absent.

    Vectors are L2-normalized so that the inner-product index performs
    cosine-similarity search.

    Raises:
        ValueError: if an existing index has a different dimensionality.
    """
    vectors = embeddings.astype('float32')
    # Unit-length vectors + inner product == cosine similarity.
    faiss.normalize_L2(vectors)
    dim = vectors.shape[1]

    if not os.path.exists(index_path):
        # Fresh flat inner-product index.
        idx = faiss.IndexFlatIP(dim)
    else:
        idx = faiss.read_index(index_path)
        if idx.d != dim:
            raise ValueError(f"Dimension mismatch: Index {idx.d} vs New {dim}")

    idx.add(vectors)
    faiss.write_index(idx, index_path)


def append_metadata(meta_path: str, new_meta: list) -> int:
    """
    Append ``new_meta`` to a pickle file in 'ab' (append binary) mode.

    Appending avoids loading the entire existing metadata list into
    memory; the file becomes a sequence of stacked pickle objects.

    Args:
        meta_path: Path to the (possibly not-yet-existing) pickle file.
        new_meta: The list of metadata chunks to append.

    Returns:
        The TOTAL count of chunks now in the file (lists contribute their
        length; any non-list object counts as 1).
    """
    # os.path.dirname returns '' for a bare filename, and os.makedirs('')
    # raises — only create directories when a parent component exists.
    parent = os.path.dirname(meta_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    # 1. Perform the append
    with open(meta_path, "ab") as f:
        pickle.dump(new_meta, f, protocol=pickle.HIGHEST_PROTOCOL)

    # 2. Calculate the total size by reading the "stacked" objects
    total_count = 0
    try:
        with open(meta_path, "rb") as f:
            while True:
                try:
                    data = pickle.load(f)
                    # If data is a list, add its length; if it's a single dict, add 1
                    total_count += len(data) if isinstance(data, list) else 1
                except EOFError:
                    break
    except Exception as e:
        logging.error(f"Error calculating metadata size: {e}")

    logging.info(f"Total metadata chunks after append: {total_count}")
    return total_count


def load_metadata(path: str) -> list:
    """Load all stacked pickle objects from ``path`` into one flat list.

    Mirrors ``append_metadata``'s counting semantics: a pickled list is
    flattened into its elements, while any non-list object is kept as a
    single element. (The previous version blindly ``extend``-ed every
    object, which would have exploded a pickled dict into its keys.)

    Returns:
        [] if the file does not exist.
    """
    if not os.path.exists(path):
        return []
    all_data = []
    with open(path, "rb") as f:
        while True:
            try:
                obj = pickle.load(f)
            except EOFError:
                break
            if isinstance(obj, list):
                all_data.extend(obj)
            else:
                # Non-list dumps count as one chunk in append_metadata;
                # keep loading consistent with that contract.
                all_data.append(obj)
    return all_data


def compute_embeddings(
    texts: List[str], 
    model_name: str = "nomic-ai/nomic-embed-text-v1", 
    batch_size: int = 32
) -> np.ndarray:
    """
    Compute float32 embeddings for ``texts``.

    Tries SentenceTransformer first (it handles batching and pooling
    internally); on any failure falls back to a raw transformers model
    with mean pooling.

    Args:
        texts: Input strings to embed.
        model_name: HF model identifier for both paths.
        batch_size: Batch size for encoding.

    Returns:
        Array of shape (len(texts), dim); shape (0, 0) for empty input.

    Raises:
        RuntimeError: if the fallback dependencies are unavailable.
    """
    if not texts:
        return np.zeros((0, 0), dtype='float32')

    # Path 1: SentenceTransformer (highly optimized)
    # NOTE(review): the top-of-file import is unconditional, so this
    # None-check only matters if the import is later made optional.
    if SentenceTransformer is not None:
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            model = SentenceTransformer(model_name, device=device)
            # ST handles batching and progress internally
            return model.encode(texts, batch_size=batch_size, show_progress_bar=False).astype('float32')
        except Exception as e:
            # Surface the reason for the fallback instead of silently
            # swallowing it (the original `pass` hid config/model errors).
            logging.warning(
                f"SentenceTransformer path failed ({e}); falling back to transformers"
            )

    # Path 2: HF Fallback with 'device_map' for memory safety
    if not all([AutoTokenizer, AutoModel, torch]):
        raise RuntimeError("Missing dependencies: torch, transformers")

    # Use 'auto' to shard large models across GPU/CPU automatically
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model_hf = AutoModel.from_pretrained(model_name, trust_remote_code=True, device_map="auto")

    all_emb = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        toks = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(model_hf.device)
        
        with torch.no_grad():
            out = model_hf(**toks)
            # Mean pooling: zero out padding positions, sum over the
            # sequence axis, divide by the count of real tokens.
            mask = toks["attention_mask"].unsqueeze(-1).expand(out.last_hidden_state.size()).float()
            summed = torch.sum(out.last_hidden_state * mask, 1)
            counts = torch.clamp(mask.sum(1), min=1e-9)
            emb_batch = (summed / counts).cpu().numpy()
            all_emb.append(emb_batch.astype('float32'))

    return np.vstack(all_emb)