# abalone_chat_application/src/utils/rag_runtime.py
# (commit 18ef2cd — "Updating package file structure")
import sys
import subprocess
from typing import Any
import streamlit as st
from src.vectorstore import get_retriever
from src.qa_chain import make_conversational_chain
def run_ingest_cli(data_dir: str, persist_dir: str) -> None:
    """Rebuild the vectorstore by invoking the ingestion module as a subprocess.

    Runs ``python -m src.ingest`` using the current interpreter so the
    subprocess shares this environment's installed packages.

    Args:
        data_dir: Directory containing the raw text files.
        persist_dir: Directory where embeddings and Chroma DB should be stored.

    Raises:
        subprocess.CalledProcessError: If the underlying subprocess fails.
    """
    ingest_cmd = [sys.executable, "-m", "src.ingest"]
    ingest_cmd += ["--data-dir", data_dir]
    ingest_cmd += ["--persist-dir", persist_dir]
    subprocess.run(ingest_cmd, check=True)
@st.cache_resource(show_spinner=False)
def build_or_load_retriever_cached(
    data_dir: str,
    persist_dir: str,
    top_k: int,
    retrieval_mode: str,
) -> Any:
    """Return a retriever backed by the persisted vectorstore, building it on demand.

    A failed load — usually because the vectorstore doesn't exist yet at
    ``persist_dir`` — triggers a one-shot ingestion pass, after which loading
    is attempted exactly once more.

    Args:
        data_dir: Directory containing input documents.
        persist_dir: Directory where the Chroma vectorstore is stored.
        top_k: Number of chunks to retrieve for queries.
        retrieval_mode: Retrieval strategy (mmr, similarity, hybrid).

    Returns:
        An initialized retriever instance.
    """
    def _load() -> Any:
        return get_retriever(
            persist_dir=persist_dir,
            top_k=top_k,
            retrieval_mode=retrieval_mode,
        )

    try:
        return _load()
    except Exception:
        # Load failed — presumably no persisted store yet. Rebuild via
        # ingestion, then retry once; a second failure propagates.
        run_ingest_cli(data_dir=data_dir, persist_dir=persist_dir)
        return _load()
@st.cache_resource(show_spinner=False)
def get_chain_cached(
    model_name: str,
    top_k: int,
    retrieval_mode: str,
    data_dir: str,
    persist_dir: str,
) -> Any:
    """Create or load a cached conversational QA chain.

    Resolves a (possibly cached) retriever first, then wraps it in a
    conversational chain bound to the requested model.

    Args:
        model_name: The OpenAI model to use (gpt-3.5-turbo, gpt-4).
        top_k: Number of chunks to retrieve.
        retrieval_mode: Retrieval mode for the retriever.
        data_dir: Path to data directory.
        persist_dir: Path to vectorstore directory.

    Returns:
        A fully configured conversational QA chain.
    """
    return make_conversational_chain(
        build_or_load_retriever_cached(
            data_dir=data_dir,
            persist_dir=persist_dir,
            top_k=top_k,
            retrieval_mode=retrieval_mode,
        ),
        model_name=model_name,
    )