import json
import os
import subprocess
import sys
from typing import Any, Dict, List, Tuple, cast

import streamlit as st

from src.vectorstore import get_retriever
from src.qa_chain import make_conversational_chain

# Unconditionally import KG modules; let import errors propagate so failures are visible.
from src.kg.store import KGStore
from src.kg.retriever import KGRetriever


def run_ingest_cli(data_dir: str, persist_dir: str) -> str:
    """Run the ingestion module to rebuild the vectorstore.

    Runs the ingest CLI as a subprocess and returns stdout on success. On failure,
    raises subprocess.CalledProcessError with captured stdout/stderr so callers
    (for example the Streamlit UI) can display a helpful error message.
    """
    cmd = [
        sys.executable,
        "-m",
        "src.ingest",
        "--data-dir",
        data_dir,
        "--persist-dir",
        persist_dir,
    ]
    try:
        # Add a timeout to avoid indefinite hanging; 600s (10 minutes) is generous for large ingests.
        completed = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
    except subprocess.TimeoutExpired as te:
        # Provide a helpful error, including any partial output captured before the timeout.
        raise subprocess.CalledProcessError(
            returncode=124,
            cmd=cmd,
            output=getattr(te, "output", "") or "",
            stderr=f"Ingest process timed out after {te.timeout} seconds",
        )
    # Check the return code and raise with captured output on failure.
    if completed.returncode != 0:
        # Raise with captured output to make it easy to present to the user.
        raise subprocess.CalledProcessError(
            returncode=completed.returncode,
            cmd=cmd,
            output=completed.stdout,
            stderr=completed.stderr,
        )
    return completed.stdout
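
# Illustrative usage (a sketch; the "data" and "vectorstore" paths are assumptions,
# not values this module defines):
#
#   try:
#       log = run_ingest_cli(data_dir="data", persist_dir="vectorstore")
#       st.text(log)
#   except subprocess.CalledProcessError as exc:
#       st.error(f"Ingest failed (exit {exc.returncode}): {exc.stderr}")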


def _load_chunks_index(persist_dir: str) -> Dict[str, Dict]:
    """Load the chunk-id -> chunk-metadata mapping from chunks_index.json in persist_dir.

    Returns an empty dict if the file is missing or cannot be parsed.
    """
    idx_path = os.path.join(persist_dir, "chunks_index.json")
    if not os.path.exists(idx_path):
        return {}
    try:
        with open(idx_path, "r", encoding="utf-8") as fh:
            return json.load(fh)
    except Exception:
        return {}
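
# Note: answer_with_kg below only relies on each entry exposing a "text" field; the full
# record shape is produced by the ingest step, and the example below is an assumption:
#
#   {"chunk-0001": {"text": "First chunk of the source document ..."}}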


def answer_with_kg(
    chain,
    question: str,
    chat_history: List[Tuple[str, str]],
    persist_dir: str,
    kg_hops: int = 1,
    kg_context_max_chars: int = 1000,
) -> Any:
    """Augment the question with KG context (if available) and run the chain.

    This is a low-risk integration: we build a short textual summary from the KG
    (node labels and short chunk snippets from chunks_index.json) and prepend it to
    the question. The chain's retriever still runs; KG context is additional grounding.
    """
    kg_text_parts: List[str] = []
    # Load the chunk-id -> metadata mapping produced at ingest time.
    chunks_index = _load_chunks_index(persist_dir)
    kg_path = os.path.join(persist_dir, "kg_store.ttl")
    try:
        kg = KGStore(path=kg_path)
        retr = KGRetriever(kg)
        chunk_ids, summaries = retr.get_context_for_question(question, hops=kg_hops)
        if summaries:
            kg_text_parts.append("KG entities: " + ", ".join(summaries))
        # Add short snippets of the chunks the KG points at.
        for cid in chunk_ids:
            info = chunks_index.get(cid)
            if info:
                txt = info.get("text", "")
                if txt:
                    # Slicing already clamps to the string length, so no min() is needed.
                    snippet = txt.strip().replace("\n", " ")[:kg_context_max_chars]
                    kg_text_parts.append(f"[KG chunk {cid}]: {snippet}")
    except Exception:
        # If loading or querying the KG fails, fall back to answering without KG
        # augmentation rather than failing the whole request.
        kg_text_parts = []

    kg_context = "\n\n".join(kg_text_parts) if kg_text_parts else ""
    if kg_context:
        augmented_question = f"KG CONTEXT:\n{kg_context}\n\nUser Question:\n{question}"
    else:
        augmented_question = question
    return chain({"question": augmented_question, "chat_history": chat_history})
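
# Illustrative call (a sketch; `chain` comes from get_chain_cached below, `history` is the
# list of (question, answer) pairs the UI has accumulated, and the "answer" key is an
# assumption about what the chain returns):
#
#   result = answer_with_kg(chain, "What does the report conclude?", history,
#                           persist_dir="vectorstore")
#   st.write(result["answer"])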


@st.cache_resource(show_spinner=False)
def build_or_load_retriever_cached(
    data_dir: str,
    persist_dir: str,
    top_k: int,
    retrieval_mode: str,
) -> Any:
    """Load a retriever from the persisted vectorstore or build a new one.

    If loading fails (usually because the vectorstore doesn't exist), this
    function triggers ingestion and retries loading.

    Args:
        data_dir: Directory containing input documents.
        persist_dir: Directory where the Chroma vectorstore is stored.
        top_k: Number of chunks to retrieve.
        retrieval_mode: Retrieval strategy (mmr, similarity, hybrid).

    Returns:
        An initialized retriever instance.
    """
    # Cast retrieval_mode to the expected literal type to satisfy type checkers.
    from typing import Literal

    RetrievalMode = Literal["mmr", "similarity", "hybrid"]
    mode = cast(RetrievalMode, retrieval_mode)
    try:
        return get_retriever(
            persist_dir=persist_dir,
            top_k=top_k,
            retrieval_mode=mode,
        )
    except Exception:
        # The vectorstore is probably missing; build it with the ingest CLI, then retry.
        run_ingest_cli(data_dir=data_dir, persist_dir=persist_dir)
        return get_retriever(
            persist_dir=persist_dir,
            top_k=top_k,
            retrieval_mode=mode,
        )
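
# Illustrative usage (a sketch; the argument values are assumptions chosen for the example):
#
#   retriever = build_or_load_retriever_cached(
#       data_dir="data", persist_dir="vectorstore", top_k=4, retrieval_mode="mmr"
#   )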


@st.cache_resource(show_spinner=False)
def get_chain_cached(
    model_name: str,
    top_k: int,
    retrieval_mode: str,
    data_dir: str,
    persist_dir: str,
) -> Any:
    """Create or load a cached conversational QA chain.

    Args:
        model_name: The OpenAI model to use (e.g. gpt-3.5-turbo, gpt-4).
        top_k: Number of chunks to retrieve.
        retrieval_mode: Retrieval mode for the retriever.
        data_dir: Path to the data directory.
        persist_dir: Path to the vectorstore directory.

    Returns:
        A fully configured conversational QA chain.
    """
    retriever = build_or_load_retriever_cached(
        data_dir=data_dir,
        persist_dir=persist_dir,
        top_k=top_k,
        retrieval_mode=retrieval_mode,
    )
    return make_conversational_chain(retriever, model_name=model_name)
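
# Illustrative wiring inside the Streamlit app (a sketch; the concrete argument values
# are assumptions, not defaults defined here):
#
#   chain = get_chain_cached(
#       model_name="gpt-3.5-turbo",
#       top_k=4,
#       retrieval_mode="mmr",
#       data_dir="data",
#       persist_dir="vectorstore",
#   )
#   result = answer_with_kg(chain, question, st.session_state.get("history", []),
#                           persist_dir="vectorstore")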