AskXeno / src /vector_store.py
github-actions
Sync from GitHub
3cdce90
"""
Vector Store module for XENO Bot
Handles ChromaDB vector store operations
"""
from typing import Any, List, Tuple
import chromadb
import numpy as np
import torch
from langchain_chroma import Chroma
from sentence_transformers import util
from src.config import (CHROMA_DB_PATH, COLLECTION_NAME, EMBEDDING_MODEL,
RAG_MAX_RESULTS, RAG_TOP_K, genai_client)
from src.knowledge_base import get_knowledge_base_data
def initialize_vector_store() -> Tuple[chromadb.Collection, Chroma, Any]:
    """
    Initialize ChromaDB vector store

    Opens a persistent ChromaDB client at CHROMA_DB_PATH and loads the
    collection named COLLECTION_NAME. If the collection cannot be loaded,
    it is created and populated with the knowledge base data. A LangChain
    Chroma wrapper and a similarity retriever (k = RAG_TOP_K) are then
    built on top of the same client.

    Returns:
        Tuple of (collection, vector_store, retriever)

    Raises:
        Exception: any error raised while setting up ChromaDB is printed
            and re-raised unchanged.
    """
    # Get knowledge base data
    documents, metadatas, ids = get_knowledge_base_data()

    # Initialize ChromaDB client
    try:
        client = chromadb.PersistentClient(path=CHROMA_DB_PATH)

        # Try to get existing collection
        try:
            collection = client.get_collection(name=COLLECTION_NAME)
            print(f"Loaded existing ChromaDB collection: {COLLECTION_NAME}")
        except Exception:
            # Deliberately not a bare ``except:`` so KeyboardInterrupt and
            # SystemExit still propagate. The exception class raised for a
            # missing collection differs between chromadb releases, hence
            # the broad (but not bare) catch.
            # Create new collection if it doesn't exist
            print(f"Creating new ChromaDB collection: {COLLECTION_NAME}")
            collection = client.create_collection(name=COLLECTION_NAME)
            collection.add(documents=documents, metadatas=metadatas, ids=ids)

        # Create vector store and retriever
        vector_store = Chroma(client=client, collection_name=COLLECTION_NAME)
        retriever = vector_store.as_retriever(
            search_type="similarity", search_kwargs={"k": RAG_TOP_K}
        )
        return collection, vector_store, retriever
    except Exception as e:
        print(f"Failed to initialize ChromaDB: {e}")
        raise
def generate_embeddings(
    query: str, documents: List[Any], timer=None
) -> Tuple[List[float], List[List[float]]]:
    """
    Embed a query together with its retrieved documents

    Thin wrapper around _generate_embeddings_impl. When a truthy timer is
    supplied, the call runs inside its "embedding_generation" step.

    Args:
        query: User query
        documents: List of retrieved documents
        timer: Optional object whose time_step(name) is used as a
            context manager around the call

    Returns:
        Tuple of (query_embedding, doc_embeddings)
    """
    # Untimed path: nothing to wrap when no timer was given
    if not timer:
        return _generate_embeddings_impl(query, documents)

    with timer.time_step("embedding_generation"):
        return _generate_embeddings_impl(query, documents)
def _generate_embeddings_impl(
    query: str, documents: List[Any]
) -> Tuple[List[float], List[List[float]]]:
    """Internal implementation of embedding generation

    Issues two embed_content requests through genai_client using
    EMBEDDING_MODEL: first for the query string, then for the list of
    each document's page_content.
    """

    def _embed(contents):
        # The response object exposes the vectors via its 'embeddings'
        # attribute; each entry carries the raw floats in '.values'.
        response = genai_client.models.embed_content(
            model=EMBEDDING_MODEL, contents=contents
        )
        return response.embeddings

    # Query first: a single input, so only the first embedding is used
    query_embedding = _embed(query)[0].values

    # Then all documents in one request
    page_texts = [document.page_content for document in documents]
    doc_embeddings = [embedding.values for embedding in _embed(page_texts)]

    return query_embedding, doc_embeddings
def calculate_similarity(
    query_embedding: List[float], doc_embeddings: List[List[float]], timer=None
) -> List[float]:
    """
    Score each document embedding against the query embedding

    Thin wrapper around _calculate_similarity_impl. When a truthy timer is
    supplied, the call runs inside its "similarity_calculation" step.

    Args:
        query_embedding: Query embedding vector
        doc_embeddings: List of document embedding vectors
        timer: Optional object whose time_step(name) is used as a
            context manager around the call

    Returns:
        List of cosine similarity scores
    """
    # Untimed path: nothing to wrap when no timer was given
    if not timer:
        return _calculate_similarity_impl(query_embedding, doc_embeddings)

    with timer.time_step("similarity_calculation"):
        return _calculate_similarity_impl(query_embedding, doc_embeddings)
def _calculate_similarity_impl(
    query_embedding: List[float], doc_embeddings: List[List[float]]
) -> List[float]:
    """Internal implementation of similarity calculation

    Converts both inputs to float tensors and returns the cosine
    similarity of the query against every document embedding, in the same
    order as doc_embeddings. Returns an empty list when there are no
    document embeddings.
    """
    # No documents to score: return early rather than building a tensor
    # from an empty list, which has no embedding dimension to compare
    # against the query.
    if len(doc_embeddings) == 0:
        return []

    query_tensor = torch.tensor(query_embedding).float()
    doc_tensor = torch.tensor(doc_embeddings).float()

    # cos_sim returns a matrix; row 0 holds the single query's scores
    cosine_scores = util.cos_sim(query_tensor, doc_tensor)[0].tolist()
    return cosine_scores
def process_context(
    results: List[Any],
    cosine_scores: List[float],
    max_results: int = RAG_MAX_RESULTS,
    timer=None,
) -> Tuple[str, List[str], List[Tuple[str, str]]]:
    """
    Turn retrieved documents into a formatted context for the LLM

    Thin wrapper around _process_context_impl. When a truthy timer is
    supplied, the call runs inside its "context_processing" step.

    Args:
        results: List of retrieved documents
        cosine_scores: List of similarity scores
        max_results: Maximum number of results to include
        timer: Optional object whose time_step(name) is used as a
            context manager around the call

    Returns:
        Tuple of (formatted_context, source_ids, knowledge_pairs)
    """
    # Untimed path: nothing to wrap when no timer was given
    if not timer:
        return _process_context_impl(results, cosine_scores, max_results)

    with timer.time_step("context_processing"):
        return _process_context_impl(results, cosine_scores, max_results)
def _process_context_impl(
    results: List[Any], cosine_scores: List[float], max_results: int
) -> Tuple[str, List[str], List[Tuple[str, str]]]:
    """Internal implementation of context processing

    Picks the max_results highest-scoring documents (best first) and, for
    each one, reads "question", "content" and "id" from its metadata,
    substituting "N/A" for any missing key.

    Returns:
        Tuple of (formatted_context, source_ids, knowledge_pairs) where
        formatted_context is the concatenated "Knowledge Entry" text,
        source_ids lists the metadata ids, and knowledge_pairs lists
        (question, answer) tuples, all in ranked order.
    """
    # argsort is ascending, so reverse it to rank the best score first
    top_indices = np.argsort(cosine_scores)[::-1][:max_results]

    separator = "-" * 40
    context_parts = []
    source_ids = []
    knowledge_pairs = []

    for entry_number, idx in enumerate(top_indices, 1):
        metadata = results[idx].metadata
        question = metadata.get("question", "N/A")
        answer = metadata.get("content", "N/A")

        # Collect pieces and join once at the end instead of growing a
        # string with repeated concatenation.
        context_parts.append(
            f"Knowledge Entry {entry_number}:\n"
            f"Q: {question}\n"
            f"A: {answer}\n"
            f"{separator}\n"
        )
        source_ids.append(metadata.get("id", "N/A"))
        knowledge_pairs.append((question, answer))

    return "".join(context_parts), source_ids, knowledge_pairs