import json
import os
import subprocess
import sys
from typing import Any, Dict, List, Literal, Tuple, cast

import streamlit as st

from src.vectorstore import get_retriever
from src.qa_chain import make_conversational_chain
# Unconditionally import KG modules; let import errors propagate so failures are visible.
from src.kg.store import KGStore
from src.kg.retriever import KGRetriever

# Single Literal alias for the retrieval strategies get_retriever accepts,
# so the cached loader below does not have to redefine it inline.
RetrievalMode = Literal["mmr", "similarity", "hybrid"]

def run_ingest_cli(data_dir: str, persist_dir: str) -> str:
"""Run the ingestion module to rebuild the vectorstore.
Runs the ingest CLI as a subprocess and returns stdout on success.
On failure raises subprocess.CalledProcessError with captured stdout/stderr so callers
(for example the Streamlit UI) can display a helpful error message.
"""
cmd = [
sys.executable,
"-m",
"src.ingest",
"--data-dir",
data_dir,
"--persist-dir",
persist_dir,
]
try:
# Add a timeout to avoid indefinite hanging; 600s (10 minutes) is generous for large ingests
completed = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
    except subprocess.TimeoutExpired as te:
        # Re-raise as CalledProcessError with any partial output; exit code 124
        # mirrors the shell convention for a timed-out command.
        raise subprocess.CalledProcessError(
            returncode=124,
            cmd=cmd,
            output=te.output or "",
            stderr=f"Ingest process timed out after {te.timeout} seconds",
        ) from te
    # Non-zero exit: re-raise with the captured stdout/stderr so callers (for
    # example the Streamlit UI) can present a helpful error message.
    if completed.returncode != 0:
raise subprocess.CalledProcessError(
returncode=completed.returncode,
cmd=cmd,
output=completed.stdout,
stderr=completed.stderr,
)
return completed.stdout
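
# Example (sketch): how a Streamlit control might invoke run_ingest_cli and surface
# failures. The button label and directory paths below are illustrative, not this
# app's actual wiring.
#
#   if st.sidebar.button("Rebuild vectorstore"):
#       try:
#           out = run_ingest_cli(data_dir="data", persist_dir="vectorstore")
#           st.sidebar.success("Ingest complete")
#           st.sidebar.code(out)
#       except subprocess.CalledProcessError as exc:
#           st.sidebar.error(f"Ingest failed (exit code {exc.returncode})")
#           st.sidebar.code(exc.stderr or exc.output or "")
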
def _load_chunks_index(persist_dir: str) -> Dict[str, Dict]:
    """Load the chunk-id -> chunk-metadata mapping; return {} if missing or unreadable."""
    idx_path = os.path.join(persist_dir, "chunks_index.json")
    if not os.path.exists(idx_path):
        return {}
    try:
        with open(idx_path, "r", encoding="utf-8") as fh:
            return json.load(fh)
    except Exception:
        # A corrupt index must not break answering; fall back to no snippets.
        return {}
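
# Assumed shape of chunks_index.json (inferred from how answer_with_kg reads it
# below, not verified against src.ingest): each chunk id maps to a dict with at
# least a "text" field, e.g.
#
#   {
#       "chunk-0001": {"text": "First chunk of the source document ..."},
#       "chunk-0002": {"text": "Second chunk ..."}
#   }
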
def answer_with_kg(
chain,
question: str,
chat_history: List[Tuple[str, str]],
persist_dir: str,
kg_hops: int = 1,
kg_context_max_chars: int = 1000,
) -> Any:
"""Augment question with KG context (if available) and run the chain.
This is a low-risk integration: we build a short textual summary from the KG
(node labels and short chunk snippets from chunks_index.json) and prepend it to
the question. The chain's retriever still runs; KG context is additional grounding.
"""
    kg_text_parts: List[str] = []
    # Load the chunk-id -> text mapping used to turn KG hits into snippets.
    chunks_index = _load_chunks_index(persist_dir)
    # Attempt to load the KG; failures are caught below so the app can still
    # answer from the vector retriever alone.
    kg_path = os.path.join(persist_dir, "kg_store.ttl")
    try:
        kg = KGStore(path=kg_path)
        retr = KGRetriever(kg)
        chunk_ids, summaries = retr.get_context_for_question(question, hops=kg_hops)
        if summaries:
            kg_text_parts.append("KG entities: " + ", ".join(summaries))
        # Add short snippets for the chunks the KG pointed at.
        for cid in chunk_ids:
            info = chunks_index.get(cid)
            if info:
                txt = info.get("text", "")
                if txt:
                    # Slicing already clamps to the string length, so no min() is needed.
                    snippet = txt.strip().replace("\n", " ")[:kg_context_max_chars]
                    kg_text_parts.append(f"[KG chunk {cid}]: {snippet}")
    except Exception:
        # If the KG load or query fails, degrade gracefully: drop any partial KG
        # context and let the chain answer from the vector retriever alone.
        kg_text_parts = []
    kg_context = "\n\n".join(kg_text_parts)
if kg_context:
augmented_question = f"KG CONTEXT:\n{kg_context}\n\nUser Question:\n{question}"
else:
augmented_question = question
return chain({"question": augmented_question, "chat_history": chat_history})
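
# Example (sketch): calling answer_with_kg from a Streamlit handler. The session
# key is illustrative, and indexing the result with "answer" assumes the LangChain
# conversational-retrieval output dict.
#
#   history: List[Tuple[str, str]] = st.session_state.get("chat_history", [])
#   result = answer_with_kg(chain, question, history, persist_dir="vectorstore")
#   st.write(result["answer"])
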
@st.cache_resource(show_spinner=False)
def build_or_load_retriever_cached(
data_dir: str,
persist_dir: str,
top_k: int,
retrieval_mode: str,
) -> Any:
"""Load a retriever from the persisted vectorstore or build a new one.
If loading fails—usually because the vectorstore doesn't exist—this
function triggers ingestion and retries loading.
Args:
data_dir: Directory containing input documents.
persist_dir: Directory where the Chroma vectorstore is stored.
top_k: Number of chunks to retrieve.
retrieval_mode: Retrieval strategy (mmr, similarity, hybrid).
Returns:
An initialized retriever instance.
"""
    # Cast to the module-level Literal alias so type checkers accept the value.
    mode = cast(RetrievalMode, retrieval_mode)
    try:
        return get_retriever(
            persist_dir=persist_dir,
            top_k=top_k,
            retrieval_mode=mode,
        )
    except Exception:
        # Vectorstore missing or unreadable: rebuild it, then retry the load once.
        run_ingest_cli(data_dir=data_dir, persist_dir=persist_dir)
        return get_retriever(
            persist_dir=persist_dir,
            top_k=top_k,
            retrieval_mode=mode,
        )
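
# Note: st.cache_resource keys its cache on the argument values, so changing
# top_k or retrieval_mode in the UI builds and caches a separate retriever.
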
@st.cache_resource(show_spinner=False)
def get_chain_cached(
model_name: str,
top_k: int,
retrieval_mode: str,
data_dir: str,
persist_dir: str,
) -> Any:
"""Create or load a cached conversational QA chain.
Args:
model_name: The OpenAI model to use (gpt-3.5-turbo, gpt-4).
top_k: Number of chunks to retrieve.
retrieval_mode: Retrieval mode for the retriever.
data_dir: Path to data directory.
persist_dir: Path to vectorstore directory.
Returns:
A fully configured conversational QA chain.
"""
retriever = build_or_load_retriever_cached(
data_dir=data_dir,
persist_dir=persist_dir,
top_k=top_k,
retrieval_mode=retrieval_mode,
)
return make_conversational_chain(retriever, model_name=model_name)
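
# Example (sketch): building the cached chain from sidebar settings. The model
# name, k, mode, and directories below are illustrative defaults only.
#
#   chain = get_chain_cached(
#       model_name="gpt-3.5-turbo",
#       top_k=4,
#       retrieval_mode="mmr",
#       data_dir="data",
#       persist_dir="vectorstore",
#   )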