import json
import os
import subprocess
import sys
from typing import Any, Dict, List, Literal, Tuple, cast

import streamlit as st

from src.kg.retriever import KGRetriever
from src.kg.store import KGStore
from src.qa_chain import make_conversational_chain
from src.vectorstore import get_retriever

# Retrieval strategies accepted by get_retriever.
RetrievalMode = Literal["mmr", "similarity", "hybrid"]


def run_ingest_cli(data_dir: str, persist_dir: str) -> str:
    """Run the ingestion module to rebuild the vectorstore.

    Runs the ingest CLI as a subprocess and returns its stdout on success.
    On failure, raises subprocess.CalledProcessError with the captured
    stdout/stderr so callers (for example the Streamlit UI) can display a
    helpful error message.
    """
    cmd = [
        sys.executable,
        "-m",
        "src.ingest",
        "--data-dir",
        data_dir,
        "--persist-dir",
        persist_dir,
    ]
    try:
        completed = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
    except subprocess.TimeoutExpired as te:
        # Normalize timeouts to CalledProcessError so callers only need to
        # handle one exception type. 124 mirrors the shell `timeout` exit code.
        raise subprocess.CalledProcessError(
            returncode=124,
            cmd=cmd,
            output=getattr(te, "output", "") or "",
            stderr=f"Ingest process timed out after {te.timeout} seconds",
        ) from te

    if completed.returncode != 0:
        raise subprocess.CalledProcessError(
            returncode=completed.returncode,
            cmd=cmd,
            output=completed.stdout,
            stderr=completed.stderr,
        )
    return completed.stdout


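# Illustrative caller (not wired up here): how a Streamlit view might surface
# ingest results and failures. The widget choices and the "data"/"chroma_db"
# paths below are placeholders, not part of this module's API.
#
#     try:
#         log = run_ingest_cli(data_dir="data", persist_dir="chroma_db")
#         st.success("Ingest complete")
#         st.code(log)
#     except subprocess.CalledProcessError as exc:
#         st.error(f"Ingest failed (exit {exc.returncode}): {exc.stderr}")

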
def _load_chunks_index(persist_dir: str) -> Dict[str, Dict]:
    """Load chunks_index.json from persist_dir, or {} if missing or unreadable."""
    idx_path = os.path.join(persist_dir, "chunks_index.json")
    if not os.path.exists(idx_path):
        return {}
    try:
        with open(idx_path, "r", encoding="utf-8") as fh:
            return json.load(fh)
    except Exception:
        # A corrupt index should not break answering; fall back to no snippets.
        return {}


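# Shape assumed for chunks_index.json: chunk ids mapped to metadata dicts, of
# which only "text" is read (in answer_with_kg below). An illustrative entry;
# fields other than "text" are hypothetical:
#
#     {"chunk-0001": {"text": "First page of the user guide ...", "source": "docs/guide.md"}}

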
def answer_with_kg(
    chain,
    question: str,
    chat_history: List[Tuple[str, str]],
    persist_dir: str,
    kg_hops: int = 1,
    kg_context_max_chars: int = 1000,
) -> Any:
    """Augment the question with KG context (if available) and run the chain.

    This is a low-risk integration: we build a short textual summary from the
    KG (node labels plus short chunk snippets from chunks_index.json) and
    prepend it to the question. The chain's retriever still runs; the KG
    context is additional grounding.
    """
    kg_text_parts: List[str] = []
    chunks_index = _load_chunks_index(persist_dir)

    kg_path = os.path.join(persist_dir, "kg_store.ttl")
    try:
        kg = KGStore(path=kg_path)
        retr = KGRetriever(kg)
        chunk_ids, summaries = retr.get_context_for_question(question, hops=kg_hops)
        if summaries:
            kg_text_parts.append("KG entities: " + ", ".join(summaries))

        # Attach a truncated snippet for each chunk the KG points at.
        for cid in chunk_ids:
            info = chunks_index.get(cid)
            if info:
                txt = info.get("text", "")
                if txt:
                    snippet = txt.strip().replace("\n", " ")[:kg_context_max_chars]
                    kg_text_parts.append(f"[KG chunk {cid}]: {snippet}")
    except Exception:
        # The KG is optional: if the store is absent or unreadable, fall back
        # to the plain question rather than failing the request.
        kg_text_parts = []

    kg_context = "\n\n".join(kg_text_parts)
    if kg_context:
        augmented_question = f"KG CONTEXT:\n{kg_context}\n\nUser Question:\n{question}"
    else:
        augmented_question = question

    return chain({"question": augmented_question, "chat_history": chat_history})


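# For reference, an augmented question produced above might look like this
# (entity names and snippet text invented purely for illustration):
#
#     KG CONTEXT:
#     KG entities: Alice, ACME Corp
#
#     [KG chunk chunk-0001]: Alice founded ACME Corp in 2019 ...
#
#     User Question:
#     Who founded ACME Corp?

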
@st.cache_resource(show_spinner=False)
def build_or_load_retriever_cached(
    data_dir: str,
    persist_dir: str,
    top_k: int,
    retrieval_mode: str,
) -> Any:
    """Load a retriever from the persisted vectorstore or build a new one.

    If loading fails (usually because the vectorstore doesn't exist yet), this
    function triggers ingestion and retries loading once.

    Args:
        data_dir: Directory containing input documents.
        persist_dir: Directory where the Chroma vectorstore is stored.
        top_k: Number of chunks to retrieve.
        retrieval_mode: Retrieval strategy ("mmr", "similarity", or "hybrid").

    Returns:
        An initialized retriever instance.
    """
    mode = cast(RetrievalMode, retrieval_mode)
    try:
        return get_retriever(
            persist_dir=persist_dir,
            top_k=top_k,
            retrieval_mode=mode,
        )
    except Exception:
        # Vectorstore missing or unreadable: rebuild it, then retry once.
        run_ingest_cli(data_dir=data_dir, persist_dir=persist_dir)
        return get_retriever(
            persist_dir=persist_dir,
            top_k=top_k,
            retrieval_mode=mode,
        )


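# Note: st.cache_resource memoizes per argument tuple, so after an out-of-band
# re-ingest the cached retriever may point at stale data. Streamlit cached
# functions expose .clear() for exactly this case:
#
#     build_or_load_retriever_cached.clear()
#     get_chain_cached.clear()

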
@st.cache_resource(show_spinner=False)
def get_chain_cached(
    model_name: str,
    top_k: int,
    retrieval_mode: str,
    data_dir: str,
    persist_dir: str,
) -> Any:
    """Create or load a cached conversational QA chain.

    Args:
        model_name: The OpenAI model to use (e.g. "gpt-3.5-turbo", "gpt-4").
        top_k: Number of chunks to retrieve.
        retrieval_mode: Retrieval mode for the retriever.
        data_dir: Path to the data directory.
        persist_dir: Path to the vectorstore directory.

    Returns:
        A fully configured conversational QA chain.
    """
    retriever = build_or_load_retriever_cached(
        data_dir=data_dir,
        persist_dir=persist_dir,
        top_k=top_k,
        retrieval_mode=retrieval_mode,
    )
    return make_conversational_chain(retriever, model_name=model_name)
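

# A minimal sketch of how these pieces might be wired into a Streamlit page.
# It is illustrative only: the sidebar labels, the default paths ("data",
# "chroma_db"), the session key, and the chain's "answer" output key are
# assumptions, not part of this module's API.
def _example_page() -> None:
    st.title("Docs Q&A")
    model_name = st.sidebar.selectbox("Model", ["gpt-3.5-turbo", "gpt-4"])
    top_k = st.sidebar.slider("Top-k chunks", 1, 10, 4)
    retrieval_mode = st.sidebar.selectbox("Retrieval", ["mmr", "similarity", "hybrid"])

    chain = get_chain_cached(
        model_name=model_name,
        top_k=top_k,
        retrieval_mode=retrieval_mode,
        data_dir="data",
        persist_dir="chroma_db",
    )

    # Keep (question, answer) pairs across reruns.
    if "chat_history" not in st.session_state:
        st.session_state["chat_history"] = []
    history: List[Tuple[str, str]] = st.session_state["chat_history"]

    question = st.chat_input("Ask a question about the documents")
    if question:
        result = answer_with_kg(chain, question, history, persist_dir="chroma_db")
        answer = result.get("answer", "")  # "answer" key assumed from the chain
        history.append((question, answer))
        st.write(answer)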