import logging
import os
from threading import Lock
from typing import AsyncGenerator, List, Optional

import chromadb
from llama_index.core import Document, Settings, StorageContext, VectorStoreIndex
from llama_index.core.chat_engine.types import BaseChatEngine
from llama_index.embeddings.fastembed import FastEmbedEmbedding
from llama_index.llms.groq import Groq
from llama_index.vector_stores.chroma import ChromaVectorStore

from app.config import settings
from app.models.schemas import SourceInfo

logger = logging.getLogger(__name__)

# Name of the single Chroma collection that holds every indexed document.
_COLLECTION_NAME = "studyson"

_llm_lock = Lock()
_llm_initialized = False


def _ensure_llm() -> None:
    """Configure the global LlamaIndex LLM and embedding model exactly once.

    The work is guarded by ``_llm_lock`` and the ``_llm_initialized`` flag, so
    repeated (and concurrent) calls after the first successful one return
    without touching ``Settings`` again.
    """
    global _llm_initialized
    with _llm_lock:
        if _llm_initialized:
            return
        # The key is also exported to the process environment in addition to
        # being passed to the client explicitly.
        os.environ["GROQ_API_KEY"] = settings.groq_api_key
        Settings.llm = Groq(model=settings.groq_model, api_key=settings.groq_api_key)
        Settings.embed_model = FastEmbedEmbedding(model_name=settings.embed_model)
        # Set last: if anything above raises, the next call retries the setup.
        _llm_initialized = True


class RAGService:
    """Persistent, multi-session RAG service backed by Chroma."""

    def __init__(self) -> None:
        """Open (or create) the on-disk Chroma store and load known sources."""
        settings.chroma_dir.mkdir(parents=True, exist_ok=True)
        self._chroma = chromadb.PersistentClient(path=str(settings.chroma_dir))
        self._bind_collection()
        # Built lazily by _ensure_index() on first use.
        self._index: Optional[VectorStoreIndex] = None
        # One chat engine per session id, created on demand.
        self._chat_engines: dict[str, BaseChatEngine] = {}
        self._indexed_documents: list[str] = self._load_indexed_documents()

    def _bind_collection(self) -> None:
        """(Re)create the collection handle and the objects wrapping it."""
        self._collection = self._chroma.get_or_create_collection(_COLLECTION_NAME)
        self._vector_store = ChromaVectorStore(chroma_collection=self._collection)
        self._storage = StorageContext.from_defaults(vector_store=self._vector_store)

    def _load_indexed_documents(self) -> list[str]:
        """Return the sorted, de-duplicated ``source`` values in the collection.

        Best-effort: any failure while reading or processing the metadata is
        logged and an empty list is returned instead of raising.
        """
        try:
            data = self._collection.get(include=["metadatas"])
            # Skip empty metadata entries and entries with a falsy source.
            sources = {m.get("source") for m in (data.get("metadatas") or []) if m}
            return sorted(s for s in sources if s)
        except Exception:
            logger.warning(
                "Could not read indexed document names from collection %r",
                _COLLECTION_NAME,
                exc_info=True,
            )
            return []

    def _ensure_index(self) -> VectorStoreIndex:
        """Return the vector index, building it from the vector store if needed.

        Also makes sure the global LLM / embedding settings are configured.
        """
        _ensure_llm()
        if self._index is None:
            self._index = VectorStoreIndex.from_vector_store(
                vector_store=self._vector_store,
                storage_context=self._storage,
            )
        return self._index

    def add_document(self, text: str, source_name: str) -> None:
        """Insert ``text`` into the index, tagged with ``source_name``.

        Args:
            text: Document body to index.
            source_name: Stored as the ``source`` metadata value and recorded
                in the list returned by ``get_indexed_documents``.
        """
        index = self._ensure_index()
        index.insert(Document(text=text, metadata={"source": source_name}))
        if source_name not in self._indexed_documents:
            self._indexed_documents.append(source_name)
        # Every cached chat engine (and whatever state it holds) is dropped
        # after an insert; presumably so sessions do not keep state built
        # before the new document existed -- TODO confirm intent.
        self._chat_engines.clear()

    def _get_chat_engine(self, session_id: str) -> BaseChatEngine:
        """Return the chat engine for ``session_id``, creating it on first use."""
        engine = self._chat_engines.get(session_id)
        if engine is None:
            engine = self._ensure_index().as_chat_engine(
                chat_mode="condense_plus_context",
                similarity_top_k=settings.similarity_top_k,
                verbose=False,
            )
            self._chat_engines[session_id] = engine
        return engine

    async def stream_query(
        self, question: str, session_id: str
    ) -> AsyncGenerator[str, None]:
        """Stream the answer to ``question`` token by token for one session.

        Args:
            question: User message passed to the session's chat engine.
            session_id: Key selecting (or creating) the session's chat engine.

        Yields:
            Response tokens as they are produced.

        Raises:
            ValueError: If no documents are indexed. Because this is an async
                generator, the error surfaces on first iteration, not at call
                time.
        """
        if not self.has_documents():
            raise ValueError("No documents indexed.")
        engine = self._get_chat_engine(session_id)
        response = await engine.astream_chat(question)
        async for token in response.async_response_gen():
            yield token

    async def query(self, question: str) -> tuple[str, List[SourceInfo]]:
        """Answer ``question`` with a one-off query engine (no session state).

        Returns:
            A ``(answer, sources)`` pair; each source carries the node's
            ``source`` metadata (``"Unknown"`` when absent), the first 300
            characters of its text, and its score if it has one.

        Raises:
            ValueError: If no documents are indexed.
        """
        if not self.has_documents():
            raise ValueError("No documents indexed.")
        index = self._ensure_index()
        query_engine = index.as_query_engine(similarity_top_k=settings.similarity_top_k)
        response = await query_engine.aquery(question)
        # The response may lack ``source_nodes`` or carry None; treat both as empty.
        nodes = getattr(response, "source_nodes", []) or []
        sources: list[SourceInfo] = [
            SourceInfo(
                file_name=node.metadata.get("source", "Unknown"),
                text=node.text[:300],
                score=getattr(node, "score", None),
            )
            for node in nodes
        ]
        return str(response), sources

    async def summarize(self, max_length: int = 500) -> str:
        """Ask the LLM for a summary of the indexed documents.

        Args:
            max_length: Approximate word count requested in the prompt; it is
                only stated in the prompt text, not enforced by this method.

        Raises:
            ValueError: If no documents are indexed.
        """
        if not self.has_documents():
            raise ValueError("No documents indexed.")
        index = self._ensure_index()
        # Uses a fixed top-k of 8 rather than settings.similarity_top_k.
        query_engine = index.as_query_engine(similarity_top_k=8)
        prompt = (
            "Provide a comprehensive summary of all indexed documents in approximately "
            f"{max_length} words. Cover the main ideas, key arguments, and important details. "
            "Use clear paragraphs."
        )
        response = await query_engine.aquery(prompt)
        return str(response)

    def reset_session(self, session_id: str) -> None:
        """Forget the chat engine for ``session_id``; a no-op if none exists."""
        self._chat_engines.pop(session_id, None)

    def reset_all(self) -> None:
        """Drop the collection, recreate it, and clear all in-memory state.

        Deleting the collection is best-effort: a failure is logged and the
        reset continues. The document list is then reloaded from the re-opened
        collection instead of being assumed empty, so it stays accurate if the
        delete did not take effect.
        """
        try:
            self._chroma.delete_collection(_COLLECTION_NAME)
        except Exception:
            logger.warning(
                "Could not delete collection %r during reset",
                _COLLECTION_NAME,
                exc_info=True,
            )
        self._bind_collection()
        self._index = None
        self._chat_engines.clear()
        self._indexed_documents = self._load_indexed_documents()

    def get_indexed_documents(self) -> List[str]:
        """Return a copy of the known source names."""
        return list(self._indexed_documents)

    def has_documents(self) -> bool:
        """Report whether anything is indexed.

        Prefers the collection's own count; if that call fails, falls back to
        whether the in-memory source list is non-empty.
        """
        try:
            return self._collection.count() > 0
        except Exception:
            logger.debug(
                "Collection count failed; falling back to in-memory list",
                exc_info=True,
            )
            return bool(self._indexed_documents)