from __future__ import annotations

import hashlib
import json
import os
from pathlib import Path
from typing import List

import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer

from src.utils.logger import get_logger
from config.settings import settings

logger = get_logger(__name__)


class ChromaVectorDBManager:
    """Corporate-friendly ChromaDB manager - completely offline."""

    def __init__(self, model_name: str = None, db_path: str = None):
        self.model_name = model_name or getattr(
            settings, 'EMBEDDING_MODEL', 'sentence-transformers/all-MiniLM-L6-v2'
        )
        self.embedding_model = SentenceTransformer(self.model_name)

        self.db_path = db_path or getattr(settings, 'CHROMADB_PATH', './chroma_db')
        os.makedirs(self.db_path, exist_ok=True)

        self.client = chromadb.PersistentClient(
            path=self.db_path,
            settings=Settings(
                anonymized_telemetry=False,
                allow_reset=True,
                is_persistent=True
            )
        )

        self.collection_name = getattr(settings, 'COLLECTION_NAME', 'rag_chunks')
        self.collection = self._get_collection()
        logger.info(f"ChromaDB initialized at: {self.db_path}")

    def _get_collection(self):
        """Get or create collection without embedding function."""
        try:
            return self.client.get_collection(name=self.collection_name)
        except Exception:
            try:
                self.client.delete_collection(name=self.collection_name)
            except Exception:
                pass
            return self.client.create_collection(
                name=self.collection_name,
                metadata={"description": "RAG chunks"}
            )

    def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings using local sentence-transformers."""
        embeddings = self.embedding_model.encode(
            texts,
            batch_size=32,
            show_progress_bar=len(texts) > 100,
            convert_to_tensor=False
        )
        return embeddings.tolist()

    def add_chunks_to_db(self, chunks: list, source_file: str) -> bool:
        """Add chunks (list of dicts) to ChromaDB with manual embedding generation."""
        if not chunks:
            return True

        texts, ids, metadatas = [], [], []
        seen_hashes = set()

        for chunk in chunks:
            text = chunk.get("text", "").strip()
            if not text:
                continue

            text_hash = hashlib.md5(text.encode()).hexdigest()
            if text_hash in seen_hashes:
                continue
            seen_hashes.add(text_hash)

            chunk_id = f"{source_file}_{chunk.get('chunk_id', 0)}_{text_hash[:8]}"
            try:
                if self.collection.get(ids=[chunk_id])['ids']:
                    continue
            except Exception:
                pass

            texts.append(text)
            ids.append(chunk_id)
            metadata = {
                "source_file": source_file,
                "chunk_index": chunk.get("chunk_id", 0),
                "char_length": len(text),
                "text_hash": text_hash
            }
            metadatas.append(metadata)

        if not texts:
            return True

        embeddings = self.generate_embeddings(texts)
        self.collection.add(
            embeddings=embeddings,
            documents=texts,
            metadatas=metadatas,
            ids=ids
        )
        logger.info(f"Added {len(texts)} chunks from {source_file} to ChromaDB")
        return True

    def search_for_rag(
        self,
        query: str,
        n_results: int = 5,
        use_truncated: bool = False,
        filter_128_context: bool = False
    ) -> List[dict]:
        """Search using manual query embedding generation - completely offline."""
        query_embedding = self.generate_embeddings([query])[0]

        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=min(n_results * 2, 20),
            include=["documents", "metadatas", "distances"]
        )

        search_results = []
        for i, (doc, metadata, distance) in enumerate(zip(
            results['documents'][0],
            results['metadatas'][0],
            results['distances'][0]
        )):
            if len(search_results) >= n_results:
                break

            similarity_score = 1 / (1 + distance)
            result = {
                "id": results['ids'][0][i],
                "score": similarity_score,
                "distance": distance,
                "text": doc,
                "source_file": metadata["source_file"],
                "chunk_index": metadata["chunk_index"]
            }
            search_results.append(result)

        return search_results

    def reset_database(self):
        """Reset/delete existing collection."""
        try:
            self.client.delete_collection(name=self.collection_name)
            self.collection = self._get_collection()
            logger.info(f"Reset collection: {self.collection_name}")
            return True
        except Exception as e:
            logger.error(f"Failed to reset database: {e}")
            return False

    def get_collection_stats(self) -> dict:
        """Get collection statistics."""
        count = self.collection.count()

        db_size_mb = 0
        try:
            for file_path in Path(self.db_path).rglob("*"):
                if file_path.is_file():
                    db_size_mb += file_path.stat().st_size
            db_size_mb /= (1024 * 1024)
        except Exception:
            db_size_mb = 0

        return {
            "total_chunks": count,
            "collection_name": self.collection_name,
            "embedding_model": self.model_name,
            "db_path": self.db_path,
            "db_size_mb": db_size_mb
        }

    def process_all_chunks(self, chunks_dir: str = None) -> bool:
        """Process all *_extracted.json files and build ChromaDB."""
        if not chunks_dir:
            chunks_dir = getattr(settings, 'PROCESSED_TEXT_DIR', './data/processed_text')

        chunk_files = list(Path(chunks_dir).glob("*_extracted.json"))
        logger.info(f"Found {len(chunk_files)} extracted JSON files to process")

        total_processed = 0
        for chunk_file in chunk_files:
            try:
                with open(chunk_file, "r", encoding="utf-8") as f:
                    data = json.load(f)

                # Handle the actual structure of extracted JSON files
                if isinstance(data, dict) and "initial_chunks" in data:
                    # New format: { "source_info": {...}, "initial_chunks": [...] }
                    chunks = data["initial_chunks"]
                elif isinstance(data, list):
                    # Old format: list of chunks directly
                    chunks = data
                else:
                    logger.warning(f"Unexpected format in {chunk_file.name}")
                    continue

                if chunks and self.add_chunks_to_db(chunks, source_file=chunk_file.name):
                    total_processed += len(chunks)
                    logger.info(f"Processed {chunk_file.name}: {len(chunks)} chunks")
            except Exception as e:
                logger.error(f"Error processing {chunk_file}: {e}")
                continue

        logger.info(f"Successfully processed {total_processed} total chunks")
        return total_processed > 0