# src/vector_store.py
import json
import os
from typing import Dict, List, Optional

import chromadb
import numpy as np
from sentence_transformers import SentenceTransformer

from config import CHROMA_DB_PATH, EMBEDDING_MODEL, EMBEDDING_DIM
class CLIPEmbedder:
    """Thin wrapper around a SentenceTransformer model that returns
    embeddings as plain Python lists.

    On embedding failure the methods degrade to zero vectors of length
    EMBEDDING_DIM instead of raising, so ingestion can proceed best-effort.
    """

    def __init__(self, model_name: Optional[str] = None):
        """Load the sentence-transformers model.

        Args:
            model_name: Model identifier; defaults to config.EMBEDDING_MODEL.
                The config constant is resolved lazily (at call time, not at
                class-definition time) so importing this class does not
                require the config value to be evaluated.
        """
        if model_name is None:
            model_name = EMBEDDING_MODEL
        print(f"πŸ”„ Loading embedding model: {model_name}")
        self.model = SentenceTransformer(model_name)
        print(f"βœ… Model loaded successfully")

    def embed(self, text: str) -> List[float]:
        """Embed one string; return a zero vector of EMBEDDING_DIM on error."""
        try:
            embedding = self.model.encode(text, convert_to_numpy=False)
            # encode() may return a tensor/ndarray or a plain list depending
            # on the backend; normalize to list when a .tolist() exists.
            return embedding.tolist() if hasattr(embedding, 'tolist') else embedding
        except Exception as e:
            print(f"Error embedding text: {e}")
            return [0.0] * EMBEDDING_DIM

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of strings; return zero vectors on error."""
        try:
            embeddings = self.model.encode(texts, convert_to_numpy=False)
            return [e.tolist() if hasattr(e, 'tolist') else e for e in embeddings]
        except Exception as e:
            print(f"Error embedding batch: {e}")
            # Build independent zero vectors. The previous
            # `[[0.0] * EMBEDDING_DIM] * len(texts)` repeated ONE list object
            # len(texts) times, so mutating any row mutated every row.
            return [[0.0] * EMBEDDING_DIM for _ in texts]
class VectorStore:
    """Persistent ChromaDB-backed vector store for multimodal RAG content.

    Indexes text chunks, image OCR text, and table content for a document
    into a single "multimodal_rag" collection, embedding everything with
    CLIPEmbedder. ChromaDB's PersistentClient persists automatically.
    """

    def __init__(self):
        # Where ChromaDB persists its data (from config).
        self.persist_directory = CHROMA_DB_PATH
        self.embedder = CLIPEmbedder()
        print(f"\nπŸ”„ Initializing ChromaDB at: {self.persist_directory}")
        try:
            self.client = chromadb.PersistentClient(
                path=self.persist_directory
            )
            print(f"βœ… ChromaDB PersistentClient initialized")
        except Exception as e:
            # No meaningful fallback exists: retrying the byte-identical
            # call cannot succeed, so surface the failure to the caller.
            print(f"❌ Error initializing ChromaDB: {e}")
            raise
        try:
            self.collection = self.client.get_or_create_collection(
                name="multimodal_rag",
                metadata={"hnsw:space": "cosine"}
            )
            count = self.collection.count()
            print(f"βœ… Collection loaded: {count} items in store")
        except Exception as e:
            # Fall back to a collection without the cosine-space hint.
            print(f"Error with collection: {e}")
            self.collection = self.client.get_or_create_collection(
                name="multimodal_rag"
            )

    def add_documents(self, documents: Dict, doc_id: str):
        """Embed and index one parsed document's text, images, and tables.

        Args:
            documents: Parsed content. Optional keys: 'text' (str),
                'images' (list of dicts with 'ocr_text' / 'path'),
                'tables' (list of dicts with 'content').
                NOTE: this is a single dict, not a list — the previous
                `List[Dict]` annotation contradicted the body.
            doc_id: Identifier used to namespace the generated item IDs.
        """
        texts: List[str] = []
        metadatas: List[Dict] = []
        ids: List[str] = []
        print(f"\nπŸ“š Adding documents for: {doc_id}")

        # Plain text: split into overlapping chunks.
        if 'text' in documents and documents['text']:
            chunks = self._chunk_text(documents['text'], chunk_size=1000, overlap=200)
            for idx, chunk in enumerate(chunks):
                texts.append(chunk)
                metadatas.append({
                    'doc_id': doc_id,
                    'type': 'text',
                    'chunk_idx': str(idx)
                })
                ids.append(f"{doc_id}_text_{idx}")
            print(f" βœ… Text: {len(chunks)} chunks")

        # Images: only those with OCR text are indexable.
        if 'images' in documents:
            image_count = 0
            for idx, image_data in enumerate(documents['images']):
                if image_data.get('ocr_text'):
                    texts.append(f"Image {idx}: {image_data['ocr_text']}")
                    metadatas.append({
                        'doc_id': doc_id,
                        'type': 'image',
                        'image_idx': str(idx),
                        'image_path': image_data.get('path', '')
                    })
                    ids.append(f"{doc_id}_image_{idx}")
                    image_count += 1
            if image_count > 0:
                print(f" βœ… Images: {image_count} with OCR text")

        # Tables: index their textual content.
        if 'tables' in documents:
            table_count = 0
            for idx, table_data in enumerate(documents['tables']):
                if table_data.get('content'):
                    texts.append(f"Table {idx}: {table_data.get('content', '')}")
                    metadatas.append({
                        'doc_id': doc_id,
                        'type': 'table',
                        'table_idx': str(idx)
                    })
                    ids.append(f"{doc_id}_table_{idx}")
                    table_count += 1
            if table_count > 0:
                print(f" βœ… Tables: {table_count}")

        # Nothing indexable → silently no-op (intentional best-effort).
        if texts:
            print(f" πŸ”„ Generating {len(texts)} embeddings...")
            embeddings = self.embedder.embed_batch(texts)
            try:
                self.collection.add(
                    ids=ids,
                    documents=texts,
                    embeddings=embeddings,
                    metadatas=metadatas
                )
                print(f"βœ… Successfully added {len(texts)} items to vector store")
                print(f"βœ… Data persisted automatically to: {self.persist_directory}")
            except Exception as e:
                print(f"❌ Error adding to collection: {e}")

    def search(self, query: str, n_results: int = 5) -> List[Dict]:
        """Semantic search over the collection.

        Args:
            query: Free-text query, embedded with the same model as the docs.
            n_results: Maximum number of hits to return.

        Returns:
            A list of dicts with 'content', 'metadata', 'distance' (cosine
            distance, lower = closer), and 'type'; empty list on any error.
        """
        try:
            query_embedding = self.embedder.embed(query)
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=n_results
            )
            formatted_results = []
            # Chroma returns per-query lists; we issued one query → index [0].
            if results['documents']:
                for i, doc in enumerate(results['documents'][0]):
                    metadata = results['metadatas'][0][i] if results['metadatas'] else {}
                    distance = results['distances'][0][i] if results['distances'] else 0
                    formatted_results.append({
                        'content': doc,
                        'metadata': metadata,
                        'distance': distance,
                        'type': metadata.get('type', 'unknown')
                    })
            return formatted_results
        except Exception as e:
            print(f"Error searching vector store: {e}")
            return []

    def _chunk_text(self, text: str, chunk_size: int = 1000, overlap: int = 200) -> List[str]:
        """Split text into fixed-size chunks overlapping by `overlap` chars.

        Raises:
            ValueError: if overlap >= chunk_size (the window would never
                advance, which previously caused an infinite loop).
        """
        if overlap >= chunk_size:
            raise ValueError("overlap must be smaller than chunk_size")
        chunks = []
        start = 0
        while start < len(text):
            end = start + chunk_size
            chunks.append(text[start:end])
            start = end - overlap
        return chunks

    def get_collection_info(self) -> Dict:
        """Return name/count/status for the collection, or an error dict."""
        try:
            count = self.collection.count()
            return {
                'name': 'multimodal_rag',
                'count': count,
                'status': 'active',
                'persist_path': self.persist_directory
            }
        except Exception as e:
            print(f"Error getting collection info: {e}")
            return {'status': 'error', 'message': str(e)}

    def delete_by_doc_id(self, doc_id: str):
        """Delete every stored item whose metadata doc_id matches."""
        try:
            # Get all IDs with this doc_id
            results = self.collection.get(where={'doc_id': doc_id})
            if results['ids']:
                self.collection.delete(ids=results['ids'])
                print(f"βœ… Deleted {len(results['ids'])} documents for {doc_id}")
                # Auto-persist on delete
                print(f"βœ… Changes persisted automatically")
        except Exception as e:
            print(f"Error deleting documents: {e}")

    def persist(self):
        """Kept for API compatibility; PersistentClient auto-persists."""
        print("βœ… Vector store is using auto-persist")

    def clear_all(self):
        """Drop and recreate the collection, emptying the store."""
        try:
            self.client.delete_collection(name="multimodal_rag")
            self.collection = self.client.get_or_create_collection(
                name="multimodal_rag",
                metadata={"hnsw:space": "cosine"}
            )
            print("βœ… Collection cleared and reset")
        except Exception as e:
            print(f"Error clearing collection: {e}")