WhatRepo / back_end /core /embeddings.py
Krishna172912's picture
Update embeddings.py
38105c3 unverified
import torch
from langchain_core.embeddings import Embeddings
from typing import List
from langchain_chroma import Chroma
from langchain_core.documents import Document
from sentence_transformers import SentenceTransformer
import uuid # to generate ids
from config import CHROMA_PERSIST_DIR,CHROMA_COLLECTION_NAME
import os
import shutil
from core.downloader import delete_dir
# Workaround for an issue with the SFR-Embedding-Code-400M_R model when running on the latest RTX 5050 GPUs.
def _inject_position_ids_hook(module, args, kwargs):
    """Forward pre-hook that derives ``position_ids`` from ``attention_mask``.

    Intended for ``register_forward_pre_hook(..., with_kwargs=True)`` so it
    receives the forward call's keyword arguments. When an attention mask is
    present and no ``position_ids`` were supplied (key missing or ``None``),
    positions are computed as the running count of attended tokens minus one.
    Masked (padding) positions are set to 0. For example, a mask row
    ``[0, 1, 1]`` yields ``[0, 0, 1]``, so left-padding is handled.

    Args:
        module: The hooked module (unused).
        args: Positional arguments of the forward call, returned unchanged.
        kwargs: Keyword arguments of the forward call. ``position_ids`` may be
            added to it in place.

    Returns:
        The ``(args, kwargs)`` tuple to use for the forward call.
    """
    attention_mask = kwargs.get('attention_mask')
    # Treat an explicit position_ids=None the same as an absent key.
    # Do nothing when there is no mask to derive positions from.
    if attention_mask is not None and kwargs.get('position_ids') is None:
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 0)
        kwargs['position_ids'] = position_ids
    return args, kwargs
class _SFRCodeEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter for Salesforce's SFR-Embedding-Code model.

    The model is loaded through ``SentenceTransformer`` with
    ``trust_remote_code=True``. A forward pre-hook
    (``_inject_position_ids_hook``) is attached to the underlying model so it
    always receives ``position_ids`` derived from the attention mask.

    Documents are embedded as-is. Queries are prefixed with
    ``QUERY_INSTRUCTION`` before encoding. All embeddings are L2-normalized
    (``normalize_embeddings=True``).
    """

    # Instruction prefix for queries, as specified by the Salesforce AI Research team.
    QUERY_INSTRUCTION = "Instruct: Given Code or Text, retrieve relevant content. Query: "

    def __init__(
        self,
        model_path='Salesforce/SFR-Embedding-Code-400M_R',
        *,
        device=None,
        batch_size=60,
        max_seq_length=1024,
    ):
        """Load the model and attach the position_ids hook.

        Args:
            model_path: Hugging Face model id or local path of the model.
            device: Torch device string such as ``"cuda"`` or ``"cpu"``. When
                ``None``, CUDA is used if available, otherwise CPU.
            batch_size: Batch size used by ``embed_documents``.
            max_seq_length: Value assigned to the model's ``max_seq_length``
                (maximum tokens per input).
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.batch_size = batch_size
        # Report the device actually used; it may be CPU when CUDA is unavailable.
        print(f"Loading local SFR Code Model on {device} via SentenceTransformer...")
        self.model = SentenceTransformer(model_path, device=device, trust_remote_code=True)
        self.model.max_seq_length = max_seq_length
        # The first pipeline module is expected to be the Transformer wrapper
        # exposing ``auto_model`` (the HF model that receives position_ids).
        self.model[0].auto_model.register_forward_pre_hook(
            _inject_position_ids_hook, with_kwargs=True
        )
        print("Model loaded and position_ids hook attached!")

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed document texts without any instruction prefix.

        Args:
            texts: Texts to embed.

        Returns:
            One normalized vector per input text, in input order. Returns an
            empty list for empty input.
        """
        if not texts:
            # Nothing to encode; skip the model call and the progress bar.
            return []
        embeddings = self.model.encode(
            texts,
            batch_size=self.batch_size,
            show_progress_bar=True,
            normalize_embeddings=True,
        )
        return embeddings.tolist()

    def embed_query(self, text: str) -> List[float]:
        """Embed a search query after prepending ``QUERY_INSTRUCTION``.

        Args:
            text: Raw query text.

        Returns:
            The normalized query vector.
        """
        # The query must carry the exact instruction prefix before encoding.
        prefixed_query = self.QUERY_INSTRUCTION + text
        embeddings = self.model.encode(
            [prefixed_query],
            batch_size=1,
            show_progress_bar=False,
            normalize_embeddings=True,
        )
        return embeddings[0].tolist()
def _custom_add_document(vector_db: Chroma, documents: List[Document], batch_size: int = 5000):
    """Embed ``documents`` in one pass and insert them into ``vector_db``.

    All texts are embedded with a single ``embed_documents`` call, so the
    embedding model batches across the whole corpus. The results are then
    written directly to the underlying Chroma collection in chunks of at most
    ``batch_size`` records. Each document gets a fresh random UUID4 id.

    Args:
        vector_db: Store whose embedding function and collection are used.
        documents: Documents to add. An empty list is a no-op.
        batch_size: Maximum records per ``collection.add`` call. Chroma limits
            the size of a single insert, so large corpora are split.

    Raises:
        ValueError: If ``batch_size`` is not positive, or the embedding
            function returns a different number of vectors than texts.
    """
    if not documents:
        return
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    texts = [doc.page_content for doc in documents]
    # Some chromadb versions reject empty metadata dicts; None means "no metadata".
    metadatas = [doc.metadata or None for doc in documents]
    ids = [str(uuid.uuid4()) for _ in documents]

    print(f"Running Global Smart Batching for {len(texts)} documents...")
    all_embeddings = vector_db.embeddings.embed_documents(texts)
    # A count mismatch would silently attach vectors to the wrong documents.
    if len(all_embeddings) != len(texts):
        raise ValueError(
            f"Embedding function returned {len(all_embeddings)} vectors for {len(texts)} texts"
        )

    print("Inserting into ChromaDB...")
    collection = vector_db._collection
    for start in range(0, len(texts), batch_size):
        end = min(start + batch_size, len(texts))
        collection.add(
            documents=texts[start:end],
            metadatas=metadatas[start:end],
            embeddings=all_embeddings[start:end],
            ids=ids[start:end],
        )
        # Report an inclusive range (start..end-1), matching the word "through".
        print(f"Successfully inserted documents {start} through {end - 1}")
def build_vector_db(documents: List[Document]) -> Chroma:
    """Create a fresh Chroma store populated with ``documents``.

    If ``CHROMA_PERSIST_DIR`` exists, it is passed to ``delete_dir`` first, so
    the new store starts from a clean directory.

    Args:
        documents: Chunks to embed and insert. If empty, the store is created
            but nothing is added.

    Returns:
        The newly created Chroma vector store.
    """
    # Start from a clean slate: drop whatever a previous run persisted.
    if os.path.exists(CHROMA_PERSIST_DIR):
        print("Cleaning up old vector database...")
        delete_dir(CHROMA_PERSIST_DIR)

    store = Chroma(
        persist_directory=CHROMA_PERSIST_DIR,
        embedding_function=_SFRCodeEmbeddings(),
        collection_name=CHROMA_COLLECTION_NAME,
    )

    # Insertion goes through the module's batched helper rather than
    # Chroma.add_documents.
    if documents:
        _custom_add_document(store, documents)
    return store
# Loads the stored vector DB; used by agent/tools.py.
def get_vector_db() -> Chroma:
    """Open the persisted Chroma store (used by the Agent/Tools).

    The store is opened at ``CHROMA_PERSIST_DIR`` / ``CHROMA_COLLECTION_NAME``.
    A new ``_SFRCodeEmbeddings`` instance is created on every call, which
    loads the embedding model each time.

    Returns:
        The Chroma vector store backed by the persisted directory.
    """
    return Chroma(
        persist_directory=CHROMA_PERSIST_DIR,
        embedding_function=_SFRCodeEmbeddings(),
        collection_name=CHROMA_COLLECTION_NAME,
    )