Spaces:
Sleeping
Sleeping
| import torch | |
| from langchain_core.embeddings import Embeddings | |
| from typing import List | |
| from langchain_chroma import Chroma | |
| from langchain_core.documents import Document | |
| from sentence_transformers import SentenceTransformer | |
| import uuid # to generate ids | |
| from config import CHROMA_PERSIST_DIR,CHROMA_COLLECTION_NAME | |
| import os | |
| import shutil | |
| from core.downloader import delete_dir | |
| #This is fix for issue with model SFR-Embedding-Code-400M_R while working with latest RTX5050 | |
| def _inject_position_ids_hook(module, args, kwargs): | |
| if 'attention_mask' in kwargs and 'position_ids' not in kwargs: | |
| attention_mask = kwargs['attention_mask'] | |
| position_ids = (attention_mask.long().cumsum(-1) - 1) | |
| position_ids.masked_fill_(attention_mask == 0, 0) | |
| kwargs['position_ids'] = position_ids | |
| return args, kwargs | |
class _SFRCodeEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter around a local SFR-Embedding-Code model.

    Documents are encoded as-is, while queries are prefixed with
    ``QUERY_INSTRUCTION`` before encoding. Every vector is produced with
    ``normalize_embeddings=True``.
    """

    # Instruction prefix specified by the Salesforce AI Research team; it is
    # prepended to queries only, never to documents.
    QUERY_INSTRUCTION = "Instruct: Given Code or Text, retrieve relevant content. Query: "

    def __init__(
        self,
        model_path='Salesforce/SFR-Embedding-Code-400M_R',
        max_seq_length: int = 1024,
        batch_size: int = 60,
    ):
        """Load the model on CUDA when available, otherwise on CPU.

        Args:
            model_path: Model id or local path handed to ``SentenceTransformer``.
            max_seq_length: Value assigned to ``model.max_seq_length``.
            batch_size: Batch size used by :meth:`embed_documents`.
        """
        hardware_device = "cuda" if torch.cuda.is_available() else "cpu"
        # Report the device actually in use instead of always claiming GPU.
        print(f"Loading local SFR Code Model to {hardware_device.upper()} via ST...")
        # NOTE(review): trust_remote_code=True runs code shipped with the model
        # repository; only point model_path at trusted sources.
        self.model = SentenceTransformer(model_path, device=hardware_device, trust_remote_code=True)
        self.model.max_seq_length = max_seq_length
        self.batch_size = batch_size
        # Inject attention-mask-derived position_ids on every forward call of
        # the underlying transformer (see _inject_position_ids_hook).
        self.model[0].auto_model.register_forward_pre_hook(_inject_position_ids_hook, with_kwargs=True)
        print("Model loaded and position_ids hook attached!")

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Encode documents (without the query prefix) into normalized vectors.

        Args:
            texts: Document texts to encode.

        Returns:
            One embedding (list of floats) per input text, in input order.
            An empty input yields an empty list without calling the model.
        """
        if not texts:
            return []
        embeddings = self.model.encode(
            texts,
            batch_size=self.batch_size,
            show_progress_bar=True,
            normalize_embeddings=True,
        )
        return embeddings.tolist()

    def embed_query(self, text: str) -> List[float]:
        """Encode a single search query into a normalized vector.

        Args:
            text: The raw query text; ``QUERY_INSTRUCTION`` is prepended here.

        Returns:
            The query embedding as a list of floats.
        """
        # The query MUST carry the exact instruction prefix before encoding.
        prefixed_query = self.QUERY_INSTRUCTION + text
        embeddings = self.model.encode(
            [prefixed_query],
            batch_size=1,
            show_progress_bar=False,
            normalize_embeddings=True,
        )
        return embeddings[0].tolist()
def _custom_add_document(vector_db: Chroma, documents: List[Document], batch_size: int = 5000):
    """Embed all documents in one pass, then insert them into Chroma in batches.

    Every text goes to the store's embedding function in a single
    ``embed_documents`` call. The resulting vectors are then written to the
    underlying Chroma collection ``batch_size`` rows at a time, each row under
    a fresh UUID4 id.

    Args:
        vector_db: Target store; its ``embeddings`` and ``_collection`` are used directly.
        documents: Documents to insert. An empty list is a no-op.
        batch_size: Maximum rows per ``collection.add`` call (default 5000).
            NOTE(review): chromadb enforces a maximum batch size per add;
            keep this value below it.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if not documents:
        return
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    texts = [doc.page_content for doc in documents]
    # Bypassing Chroma.add_documents also bypasses langchain_chroma's handling
    # of empty metadata, and some chromadb releases reject empty metadata dicts,
    # so documents without metadata are sent with None instead of {}.
    metadatas = [doc.metadata or None for doc in documents]
    ids = [str(uuid.uuid4()) for _ in documents]
    print(f"Running Global Smart Batching on GPU for {len(texts)} documents...")
    all_embeddings = vector_db.embeddings.embed_documents(texts)
    print("Inserting into ChromaDB...")
    collection = vector_db._collection
    total = len(texts)
    for start in range(0, total, batch_size):
        end = min(start + batch_size, total)
        collection.add(
            documents=texts[start:end],
            metadatas=metadatas[start:end],
            embeddings=all_embeddings[start:end],
            ids=ids[start:end],
        )
        print(f"Successfully inserted documents {start} through {end}")
def build_vector_db(documents: List[Document]) -> Chroma:
    """Wipes the old DB and builds a fresh one.

    Removes ``CHROMA_PERSIST_DIR`` if present, then opens a new Chroma store
    there, backed by a freshly loaded ``_SFRCodeEmbeddings`` model. Any given
    documents are inserted via ``_custom_add_document``.

    Args:
        documents: Documents to index; may be empty.

    Returns:
        The newly created Chroma store.
    """
    if os.path.exists(CHROMA_PERSIST_DIR):
        print("Cleaning up old vector database...")
        delete_dir(CHROMA_PERSIST_DIR)

    fresh_db = Chroma(
        collection_name=CHROMA_COLLECTION_NAME,
        embedding_function=_SFRCodeEmbeddings(),
        persist_directory=CHROMA_PERSIST_DIR,
    )
    if not documents:
        return fresh_db

    _custom_add_document(fresh_db, documents)
    return fresh_db
# Opens the stored vector DB; used in agent/tools.py.
def get_vector_db() -> Chroma:
    """Loads the EXISTING database (Used by the Agent/Tools).

    Opens the Chroma collection ``CHROMA_COLLECTION_NAME`` under
    ``CHROMA_PERSIST_DIR`` with a newly loaded ``_SFRCodeEmbeddings`` model
    as its embedding function. Nothing is deleted or rebuilt.

    Returns:
        The Chroma store.
    """
    return Chroma(
        collection_name=CHROMA_COLLECTION_NAME,
        embedding_function=_SFRCodeEmbeddings(),
        persist_directory=CHROMA_PERSIST_DIR,
    )