File size: 4,519 Bytes
6017881
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38105c3
 
 
 
6017881
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch
from langchain_core.embeddings import Embeddings
from typing import List
from langchain_chroma import Chroma
from langchain_core.documents import Document
from sentence_transformers import SentenceTransformer
import uuid  # to generate ids 
from config import CHROMA_PERSIST_DIR,CHROMA_COLLECTION_NAME
import os
import shutil
from core.downloader import delete_dir

# Workaround for an issue with the SFR-Embedding-Code-400M_R model on the latest RTX 5050 GPU: inject position_ids derived from the attention mask.
def _inject_position_ids_hook(module, args, kwargs):
    if 'attention_mask' in kwargs and 'position_ids' not in kwargs:
        attention_mask = kwargs['attention_mask']
        position_ids = (attention_mask.long().cumsum(-1) - 1)
        position_ids.masked_fill_(attention_mask == 0, 0)
        kwargs['position_ids'] = position_ids
    return args, kwargs


class _SFRCodeEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter around the Salesforce SFR code model.

    The model is loaded through ``SentenceTransformer`` on CUDA when available
    (CPU otherwise). A forward pre-hook (``_inject_position_ids_hook``) is attached
    to the underlying model. Documents are encoded as-is; queries are prefixed
    with ``QUERY_INSTRUCTION``. All vectors are produced with
    ``normalize_embeddings=True``.
    """

    # Instruction prefix specified by the Salesforce AI Research team. Only
    # queries get it; documents are encoded without it.
    QUERY_INSTRUCTION = "Instruct: Given Code or Text, retrieve relevant content. Query: "

    def __init__(
        self,
        model_path='Salesforce/SFR-Embedding-Code-400M_R',
        device=None,
        max_seq_length=1024,
        batch_size=60,
    ):
        """Load the model and attach the position-ids hook.

        Args:
            model_path: Model id or local path passed to ``SentenceTransformer``.
            device: Torch device string. ``None`` (default) picks ``"cuda"`` when
                ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.
            max_seq_length: Value assigned to ``self.model.max_seq_length``.
            batch_size: Batch size used by ``embed_documents``.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # Report the device actually chosen; the old message always said "GPU".
        print(f"Loading local SFR Code Model to {device} via ST...")

        self.batch_size = batch_size
        self.model = SentenceTransformer(model_path, device=device, trust_remote_code=True)
        self.model.max_seq_length = max_seq_length
        # Hook the wrapped HF model (assumed to be module 0 of the ST pipeline) so
        # position_ids are always supplied. Keep the handle so it can be removed later.
        self._position_ids_hook = self.model[0].auto_model.register_forward_pre_hook(
            _inject_position_ids_hook, with_kwargs=True
        )

        print("Model loaded and position_ids hook attached!")

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` without any instruction prefix.

        Returns:
            One normalised embedding (list of floats) per input text, in order.
            An empty input returns an empty list without touching the model.
        """
        if not texts:
            return []
        embeddings = self.model.encode(
            texts,
            batch_size=self.batch_size,
            show_progress_bar=True,
            normalize_embeddings=True,
        )
        return embeddings.tolist()

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query after prepending ``QUERY_INSTRUCTION``.

        Returns:
            The normalised embedding of the prefixed query as a list of floats.
        """
        # The query MUST have the exact instruction prefix applied before encoding.
        prefixed_query = self.QUERY_INSTRUCTION + text
        embeddings = self.model.encode(
            [prefixed_query],
            batch_size=1,
            show_progress_bar=False,
            normalize_embeddings=True,
        )
        return embeddings[0].tolist()

    
def _custom_add_document(vector_db: Chroma, documents: List[Document], batch_size: int = 5000):
    """Embed ``documents`` in one pass and insert them into ``vector_db`` in batches.

    All texts are embedded with a single ``embed_documents`` call, so the embedder
    can batch across the whole corpus. The rows are then written to the underlying
    Chroma collection in slices of at most ``batch_size``, each with a fresh UUID4 id.

    Args:
        vector_db: Target store. Its ``embeddings`` object computes the vectors
            and its ``_collection`` receives the rows.
        documents: LangChain documents to insert. An empty list is a no-op.
        batch_size: Maximum rows per ``collection.add`` call. The default of 5000
            is presumably chosen to stay under chromadb's per-call limit; confirm
            against the installed chromadb before raising it.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if not documents:
        return

    texts = [doc.page_content for doc in documents]
    # chromadb's metadata validation rejects an empty dict but accepts None,
    # so documents without metadata are passed as None.
    metadatas = [doc.metadata or None for doc in documents]
    ids = [str(uuid.uuid4()) for _ in documents]

    print(f"Running Global Smart Batching on GPU for {len(texts)} documents...")
    all_embeddings = vector_db.embeddings.embed_documents(texts)

    print("Inserting into ChromaDB...")
    collection = vector_db._collection

    for start in range(0, len(texts), batch_size):
        stop = start + batch_size
        batch_texts = texts[start:stop]

        collection.add(
            documents=batch_texts,
            metadatas=metadatas[start:stop],
            embeddings=all_embeddings[start:stop],
            ids=ids[start:stop],
        )
        # Inclusive range: rows start .. start + len(batch) - 1 were written.
        print(f"Successfully inserted documents {start} through {start + len(batch_texts) - 1}")


def build_vector_db(documents: List[Document]) -> Chroma:
    """Create a fresh Chroma vector store, discarding any previous one.

    If ``CHROMA_PERSIST_DIR`` exists it is removed with ``delete_dir``. A new store
    is then opened there with a newly loaded ``_SFRCodeEmbeddings`` and filled
    with ``documents`` via ``_custom_add_document``.

    Args:
        documents: Documents to embed and insert. If empty, nothing is inserted.

    Returns:
        The newly created ``Chroma`` store.
    """
    # Remove the previous database so the rebuild starts empty.
    stale_db_present = os.path.exists(CHROMA_PERSIST_DIR)
    if stale_db_present:
        print("Cleaning up old vector database...")
        delete_dir(CHROMA_PERSIST_DIR)

    embedder = _SFRCodeEmbeddings()
    store = Chroma(
        collection_name=CHROMA_COLLECTION_NAME,
        embedding_function=embedder,
        persist_directory=CHROMA_PERSIST_DIR,
    )

    if not documents:
        return store

    _custom_add_document(store, documents)
    return store

# Load the stored vector DB (used by agent/tools.py).
def get_vector_db() -> Chroma:
    """Open the EXISTING Chroma store for the agent and its tools.

    The store is bound to ``CHROMA_PERSIST_DIR`` / ``CHROMA_COLLECTION_NAME``
    with a newly constructed ``_SFRCodeEmbeddings``. Note that this loads the
    embedding model on every call.

    Returns:
        A ``Chroma`` instance for the persisted collection.
    """
    # NOTE(review): there is no check that the database was built first;
    # confirm what Chroma does when CHROMA_PERSIST_DIR is missing.
    embedder = _SFRCodeEmbeddings()
    return Chroma(
        collection_name=CHROMA_COLLECTION_NAME,
        embedding_function=embedder,
        persist_directory=CHROMA_PERSIST_DIR,
    )