Krishna172912 commited on
Commit
6017881
·
unverified ·
1 Parent(s): 3246962

Create embeddings.py

Browse files
Files changed (1) hide show
  1. back_end/core/embeddings.py +115 -0
back_end/core/embeddings.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from langchain_core.embeddings import Embeddings
3
+ from typing import List
4
+ from langchain_chroma import Chroma
5
+ from langchain_core.documents import Document
6
+ from sentence_transformers import SentenceTransformer
7
+ import uuid # to generate ids
8
+ from config import CHROMA_PERSIST_DIR,CHROMA_COLLECTION_NAME
9
+ import os
10
+ import shutil
11
+ from core.downloader import delete_dir
12
+
13
+ #This is fix for issue with model SFR-Embedding-Code-400M_R while working with latest RTX5050
14
+ def _inject_position_ids_hook(module, args, kwargs):
15
+ if 'attention_mask' in kwargs and 'position_ids' not in kwargs:
16
+ attention_mask = kwargs['attention_mask']
17
+ position_ids = (attention_mask.long().cumsum(-1) - 1)
18
+ position_ids.masked_fill_(attention_mask == 0, 0)
19
+ kwargs['position_ids'] = position_ids
20
+ return args, kwargs
21
+
22
+
23
+ class _SFRCodeEmbeddings(Embeddings):
24
+
25
+ #instruction prefix specified by the Salesforce AI Research team
26
+ QUERY_INSTRUCTION = "Instruct: Given Code or Text, retrieve relevant content. Query: "
27
+
28
+ def __init__(self, model_path='Salesforce/SFR-Embedding-Code-400M_R'):
29
+ print("Loading local SFR Code Model to GPU via ST...")
30
+
31
+ self.model = SentenceTransformer(model_path, device='cuda', trust_remote_code=True)
32
+ self.model.max_seq_length = 1024
33
+ self.model[0].auto_model.register_forward_pre_hook(_inject_position_ids_hook, with_kwargs=True)
34
+
35
+ print("Model loaded and position_ids hook attached!")
36
+
37
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
38
+ embeddings = self.model.encode(
39
+ texts,
40
+ batch_size=60,
41
+ show_progress_bar=True,
42
+ normalize_embeddings=True,
43
+ )
44
+ return embeddings.tolist()
45
+
46
+ def embed_query(self, text: str) -> List[float]:
47
+ # The query MUST have the exact instruction prefix applied before encoding
48
+ prefixed_query = self.QUERY_INSTRUCTION + text
49
+ embeddings = self.model.encode(
50
+ [prefixed_query],
51
+ batch_size=1,
52
+ show_progress_bar=False,
53
+ normalize_embeddings=True,
54
+ )
55
+ return embeddings[0].tolist()
56
+
57
+
58
+ def _custom_add_document(vector_db: Chroma, documents: List[Document]):
59
+ texts = [doc.page_content for doc in documents]
60
+ metadatas = [doc.metadata for doc in documents]
61
+ ids = [str(uuid.uuid4()) for _ in range(len(documents))]
62
+
63
+ print(f"Running Global Smart Batching on GPU for {len(texts)} documents...")
64
+ all_embeddings = vector_db.embeddings.embed_documents(texts)
65
+
66
+ CHROMA_BATCH_SIZE = 5000
67
+ print("Inserting into ChromaDB...")
68
+ collection = vector_db._collection
69
+
70
+ for i in range(0, len(texts), CHROMA_BATCH_SIZE):
71
+ batch_texts = texts[i : i + CHROMA_BATCH_SIZE]
72
+ batch_metadatas = metadatas[i : i + CHROMA_BATCH_SIZE]
73
+ batch_embeddings = all_embeddings[i : i + CHROMA_BATCH_SIZE]
74
+ batch_ids = ids[i : i + CHROMA_BATCH_SIZE]
75
+
76
+ collection.add(
77
+ documents=batch_texts,
78
+ metadatas=batch_metadatas,
79
+ embeddings=batch_embeddings,
80
+ ids=batch_ids,
81
+ )
82
+ print(f"Successfully inserted documents {i} through {i + len(batch_texts)}")
83
+
84
+
85
+ def build_vector_db(documents: List[Document]) -> Chroma:
86
+ """Wipes the old DB and builds a fresh one."""
87
+ # 1. Cleanup previous database
88
+ if os.path.exists(CHROMA_PERSIST_DIR):
89
+ print("Cleaning up old vector database...")
90
+ delete_dir(CHROMA_PERSIST_DIR)
91
+
92
+ # 2. Initialize new database
93
+ local_embedding_fn = _SFRCodeEmbeddings()
94
+ vector_db = Chroma(
95
+ persist_directory=CHROMA_PERSIST_DIR,
96
+ embedding_function=local_embedding_fn,
97
+ collection_name=CHROMA_COLLECTION_NAME,
98
+ )
99
+
100
+ # 3. Add documents using our custom batcher
101
+ if documents:
102
+ _custom_add_document(vector_db, documents)
103
+
104
+ return vector_db
105
+
106
+ #to get stored vector_bd used in agent/tools.py
107
+ def get_vector_db() -> Chroma:
108
+ """Loads the EXISTING database (Used by the Agent/Tools)."""
109
+ local_embedding_fn = _SFRCodeEmbeddings()
110
+ vector_db = Chroma(
111
+ persist_directory=CHROMA_PERSIST_DIR,
112
+ embedding_function=local_embedding_fn,
113
+ collection_name=CHROMA_COLLECTION_NAME,
114
+ )
115
+ return vector_db