# studyrag/app/services/rag_service.py
# beerohan
# Flatten directory structure for deployment
# 5ac3946
import logging
import os
from threading import Lock
from typing import AsyncGenerator, List, Optional

import chromadb
from llama_index.core import Document, Settings, StorageContext, VectorStoreIndex
from llama_index.core.chat_engine.types import BaseChatEngine
from llama_index.embeddings.fastembed import FastEmbedEmbedding
from llama_index.llms.groq import Groq
from llama_index.vector_stores.chroma import ChromaVectorStore

from app.config import settings
from app.models.schemas import SourceInfo
# Guards the one-time global LLM/embedding setup performed by _ensure_llm().
_llm_lock = Lock()
# Set to True by _ensure_llm() once Settings.llm / Settings.embed_model are assigned.
_llm_initialized = False
def _ensure_llm() -> None:
    """Configure the global LlamaIndex ``Settings`` once per process.

    On the first call this exports the Groq API key to the environment,
    assigns ``Settings.llm`` and ``Settings.embed_model`` from the
    application settings, and flips ``_llm_initialized``. Later calls
    are no-ops. Both the check and the setup run while holding
    ``_llm_lock``.
    """
    global _llm_initialized

    with _llm_lock:
        if not _llm_initialized:
            api_key = settings.groq_api_key
            os.environ["GROQ_API_KEY"] = api_key
            Settings.llm = Groq(model=settings.groq_model, api_key=api_key)
            Settings.embed_model = FastEmbedEmbedding(
                model_name=settings.embed_model
            )
            _llm_initialized = True
_logger = logging.getLogger(__name__)


class RAGService:
    """Persistent, multi-session RAG service backed by Chroma.

    One Chroma collection holds every indexed document. Chat engines are
    cached per session id; the vector index itself is built lazily on
    first use.
    """

    # Name of the single Chroma collection used by this service.
    _COLLECTION_NAME = "studyson"
    # similarity_top_k handed to the query engine used by summarize().
    _SUMMARY_TOP_K = 8
    # Maximum number of characters of node text returned per source in query().
    _SOURCE_PREVIEW_CHARS = 300

    def __init__(self) -> None:
        """Open (or create) the on-disk Chroma store and load known sources."""
        settings.chroma_dir.mkdir(parents=True, exist_ok=True)
        self._chroma = chromadb.PersistentClient(path=str(settings.chroma_dir))
        self._bind_collection()
        # Built on demand by _ensure_index().
        self._index: Optional[VectorStoreIndex] = None
        # session_id -> chat engine created by _get_chat_engine().
        self._chat_engines: dict[str, BaseChatEngine] = {}
        self._indexed_documents: list[str] = self._load_indexed_documents()

    def _bind_collection(self) -> None:
        """Get or create the collection and rebuild the objects wrapping it."""
        self._collection = self._chroma.get_or_create_collection(
            self._COLLECTION_NAME
        )
        self._vector_store = ChromaVectorStore(chroma_collection=self._collection)
        self._storage = StorageContext.from_defaults(vector_store=self._vector_store)

    def _load_indexed_documents(self) -> list[str]:
        """Return the sorted, de-duplicated ``source`` values in the collection.

        Entries with missing/empty metadata or an empty ``source`` are
        skipped. This is best-effort: any failure is logged and an empty
        list is returned instead of raising.
        """
        try:
            data = self._collection.get(include=["metadatas"])
            sources = {m.get("source") for m in (data.get("metadatas") or []) if m}
            return sorted(s for s in sources if s)
        except Exception:
            _logger.warning(
                "Could not read indexed document names from the collection",
                exc_info=True,
            )
            return []

    def _ensure_index(self) -> VectorStoreIndex:
        """Return the vector index, creating it from the vector store if needed.

        ``_ensure_llm()`` is called first on every invocation so the global
        LLM/embedding settings are in place before the index is used.
        """
        _ensure_llm()
        if self._index is None:
            self._index = VectorStoreIndex.from_vector_store(
                vector_store=self._vector_store,
                storage_context=self._storage,
            )
        return self._index

    def add_document(self, text: str, source_name: str) -> None:
        """Insert ``text`` into the index, tagged with ``source_name``.

        Args:
            text: Document content to index.
            source_name: Stored as the ``source`` metadata value and added
                to the list of indexed documents if not already present.
        """
        index = self._ensure_index()
        index.insert(Document(text=text, metadata={"source": source_name}))
        if source_name not in self._indexed_documents:
            self._indexed_documents.append(source_name)
        # Drop every cached engine; the next chat in any session builds a new one.
        self._chat_engines.clear()

    def _get_chat_engine(self, session_id: str) -> BaseChatEngine:
        """Return the cached chat engine for ``session_id``, creating it if absent."""
        engine = self._chat_engines.get(session_id)
        if engine is None:
            engine = self._ensure_index().as_chat_engine(
                chat_mode="condense_plus_context",
                similarity_top_k=settings.similarity_top_k,
                verbose=False,
            )
            self._chat_engines[session_id] = engine
        return engine

    async def stream_query(
        self, question: str, session_id: str
    ) -> AsyncGenerator[str, None]:
        """Stream the answer to ``question`` token by token for one session.

        Args:
            question: User message passed to the session's chat engine.
            session_id: Key selecting (or creating) the cached chat engine.

        Yields:
            Response tokens as they are produced.

        Raises:
            ValueError: If no documents are indexed. Because this is an
                async generator, the error surfaces when iteration starts,
                not when the method is called.
        """
        if not self.has_documents():
            raise ValueError("No documents indexed.")
        engine = self._get_chat_engine(session_id)
        response = await engine.astream_chat(question)
        async for token in response.async_response_gen():
            yield token

    async def query(self, question: str) -> tuple[str, List[SourceInfo]]:
        """Answer ``question`` with a fresh query engine (no session involved).

        Returns:
            A tuple of the response text and one ``SourceInfo`` per source
            node, each carrying the node's ``source`` metadata (or
            ``"Unknown"``), a truncated text preview, and its score if any.

        Raises:
            ValueError: If no documents are indexed.
        """
        if not self.has_documents():
            raise ValueError("No documents indexed.")
        index = self._ensure_index()
        query_engine = index.as_query_engine(similarity_top_k=settings.similarity_top_k)
        response = await query_engine.aquery(question)
        # Tolerate responses with a missing or None ``source_nodes`` attribute.
        source_nodes = getattr(response, "source_nodes", []) or []
        sources: list[SourceInfo] = [
            SourceInfo(
                file_name=node.metadata.get("source", "Unknown"),
                text=node.text[: self._SOURCE_PREVIEW_CHARS],
                score=getattr(node, "score", None),
            )
            for node in source_nodes
        ]
        return str(response), sources

    async def summarize(self, max_length: int = 500) -> str:
        """Ask the query engine for a summary of the indexed documents.

        Args:
            max_length: Approximate word count requested in the prompt; it
                is only interpolated into the prompt text, not enforced.

        Returns:
            The response text.

        Raises:
            ValueError: If no documents are indexed.
        """
        if not self.has_documents():
            raise ValueError("No documents indexed.")
        index = self._ensure_index()
        # NOTE(review): presumably only the top-k retrieved chunks inform the
        # summary, so it may not cover every document; confirm if that matters.
        query_engine = index.as_query_engine(similarity_top_k=self._SUMMARY_TOP_K)
        prompt = (
            "Provide a comprehensive summary of all indexed documents in approximately "
            f"{max_length} words. Cover the main ideas, key arguments, and important details. "
            "Use clear paragraphs."
        )
        response = await query_engine.aquery(prompt)
        return str(response)

    def reset_session(self, session_id: str) -> None:
        """Forget the cached chat engine for ``session_id`` (no-op if absent)."""
        self._chat_engines.pop(session_id, None)

    def reset_all(self) -> None:
        """Delete the collection, recreate it, and clear all cached state.

        Deleting is best-effort: a failure is logged and the collection is
        then fetched or created anyway. The document list is re-read from
        the collection afterwards, so it stays accurate even if the delete
        did not take effect.
        """
        try:
            self._chroma.delete_collection(self._COLLECTION_NAME)
        except Exception:
            _logger.warning(
                "Could not delete Chroma collection %r; reusing or recreating it",
                self._COLLECTION_NAME,
                exc_info=True,
            )
        self._bind_collection()
        self._index = None
        self._chat_engines.clear()
        # Empty after a successful delete; reflects leftover data otherwise.
        self._indexed_documents = self._load_indexed_documents()

    def get_indexed_documents(self) -> List[str]:
        """Return a copy of the known document source names."""
        return list(self._indexed_documents)

    def has_documents(self) -> bool:
        """Report whether the collection holds at least one entry.

        If counting the collection fails, the failure is logged and the
        answer falls back to whether the in-memory document list is
        non-empty.
        """
        try:
            return self._collection.count() > 0
        except Exception:
            _logger.warning(
                "Could not count collection entries; using in-memory document list",
                exc_info=True,
            )
            return bool(self._indexed_documents)