# sss-ai / utils / chroma.py
# (scraped from a Hugging Face Space file page; last commit 15eef8b,
# "Update utils/chroma.py" by miyukicodes — kept here as provenance)
import chromadb
import os
import json
from sentence_transformers import SentenceTransformer
class ChromaDBManager:
    """Manage a ChromaDB vector store of SSS documents.

    On Hugging Face Spaces (detected via the ``SPACE_ID`` env var) an
    in-memory client is used and documents are re-ingested from a JSON
    export; locally a persistent on-disk client is opened instead.
    """

    def __init__(self, persist_dir="sss_vectors"):
        """Store configuration only; call :meth:`load` to open the store.

        Args:
            persist_dir: directory used by the local persistent client.
        """
        self.persist_dir = persist_dir
        self.collection = None   # chromadb collection handle, set by load()
        self.is_loaded = False   # True only after load() finds >= 1 document
        self.client = None       # chromadb client (in-memory or persistent)
        self.embedder = None     # SentenceTransformer model, set by load()

    def load(self):
        """Load existing ChromaDB.

        Returns:
            bool: True when the collection was opened and contains at least
            one document; False otherwise. Errors are printed, not raised
            (best-effort loader — callers check the return value).
        """
        try:
            # Check if running on HF Spaces
            if os.getenv("SPACE_ID"):
                print("Running on HF Spaces - Using in-memory ChromaDB")
                # Use in-memory client for HF Spaces
                self.client = chromadb.Client()
                self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
                self.collection = self.client.create_collection(
                    name="sss_documents",
                    metadata={"hnsw:space": "cosine"}
                )
                # Re-ingest documents from the bundled JSON export, since
                # the in-memory store starts empty on every boot.
                self._load_from_json_export()
            else:
                # Use persistent client for local development
                self.client = chromadb.PersistentClient(path=self.persist_dir)
                self.collection = self.client.get_collection("sss_documents")
                self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
            doc_count = self.collection.count()
            print(f"✓ ChromaDB loaded with {doc_count} documents")
            # FIX: idiomatic boolean instead of `True if ... else False`.
            self.is_loaded = doc_count > 0
            return self.is_loaded
        except Exception as e:
            print(f"Error loading ChromaDB: {e}")
            self.is_loaded = False
            return False

    def _load_from_json_export(self):
        """Load documents from JSON export file into ``self.collection``.

        Missing file or malformed JSON is reported and swallowed (the
        caller treats an empty collection as "not loaded").
        """
        json_path = "sss_documents_export.json"
        if not os.path.exists(json_path):
            print(f"Warning: {json_path} not found!")
            return
        print(f"Loading documents from {json_path}...")
        try:
            with open(json_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            documents = []
            metadatas = []
            ids = []
            embeddings = []
            # FIX: tolerate an export with no 'documents' key.
            for doc in data.get('documents', []):
                documents.append(doc['content'])
                metadatas.append(doc.get('metadata', {}))
                ids.append(doc['id'])
                # Use existing embeddings or generate new ones
                if doc.get('embedding'):
                    embeddings.append(doc['embedding'])
                else:
                    # Generate embedding if not available
                    embedding = self.embedder.encode(doc['content'])
                    embeddings.append(embedding.tolist())
            # FIX: guard the empty case explicitly instead of silently
            # printing a success line for zero documents.
            if not documents:
                print(f"✓ Loaded {len(documents)} documents from JSON export")
                return
            # Add to collection in batches to keep each add() call small.
            batch_size = 100
            # Hoisted loop invariant: ceil-division batch total.
            total_batches = (len(documents) + batch_size - 1) // batch_size
            for i in range(0, len(documents), batch_size):
                end_idx = min(i + batch_size, len(documents))
                self.collection.add(
                    documents=documents[i:end_idx],
                    embeddings=embeddings[i:end_idx],
                    metadatas=metadatas[i:end_idx],
                    ids=ids[i:end_idx]
                )
                print(f"Added batch {i//batch_size + 1}/{total_batches}")
            print(f"✓ Loaded {len(documents)} documents from JSON export")
        except Exception as e:
            print(f"Error loading from JSON: {e}")

    def search(self, query_embedding, k=3):
        """Search for relevant documents.

        Args:
            query_embedding: the query vector — a numpy array or any
                sequence of floats (FIX: the original required ``.tolist()``
                and crashed on a plain list).
            k: number of results to return.

        Returns:
            list[dict]: ``{"content", "source", "relevance"}`` per hit;
            empty list when the store is not loaded or the query fails.
        """
        if not self.is_loaded or not self.collection:
            return []
        try:
            # Accept numpy arrays and plain sequences alike.
            if hasattr(query_embedding, "tolist"):
                query_embedding = query_embedding.tolist()
            else:
                query_embedding = list(query_embedding)
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=k,
                include=["documents", "metadatas", "distances"]
            )
            formatted_results = []
            if results['documents'] and results['documents'][0]:
                for doc, meta, dist in zip(
                    results['documents'][0],
                    results['metadatas'][0],
                    results['distances'][0]
                ):
                    formatted_results.append({
                        "content": doc,
                        "source": meta.get("source", "") if meta else "",
                        # Cosine distance -> similarity-style score.
                        "relevance": 1 - dist
                    })
            return formatted_results
        except Exception as e:
            print(f"Search error: {e}")
            return []