Spaces:
Sleeping
Sleeping
File size: 5,038 Bytes
18ef2cd ee749be 18ef2cd 0438c70 18ef2cd ee749be 18ef2cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import sys
import subprocess
from typing import Any
import streamlit as st
from src.vectorstore import get_retriever
from src.qa_chain import make_conversational_chain
import os
import json
from typing import Dict, List, Tuple
# Knowledge-graph support is optional. If the KG subpackage (or any of its
# transitive dependencies, e.g. rdflib) is not installed, the app still runs:
# answer_with_kg() checks this flag and skips KG augmentation when it is False.
try:
    from src.kg.store import KGStore
    from src.kg.retriever import KGRetriever
    _HAS_KG = True
except Exception:
    # Broad except on purpose: any import-time failure disables KG features.
    _HAS_KG = False
def run_ingest_cli(data_dir: str, persist_dir: str) -> None:
    """Rebuild the vectorstore by running the ingestion module as a subprocess.

    Uses ``sys.executable -m src.ingest`` so the child process runs under the
    same interpreter (and virtualenv) as the app itself.

    Args:
        data_dir: Directory containing the raw text files.
        persist_dir: Directory where embeddings and Chroma DB should be stored.

    Raises:
        subprocess.CalledProcessError: If the ingestion subprocess exits non-zero.
    """
    subprocess.run(
        [
            sys.executable,
            "-m",
            "src.ingest",
            "--data-dir", data_dir,
            "--persist-dir", persist_dir,
        ],
        check=True,
    )
def _load_chunks_index(persist_dir: str) -> Dict[str, Dict]:
idx_path = os.path.join(persist_dir, "chunks_index.json")
if not os.path.exists(idx_path):
return {}
try:
with open(idx_path, "r", encoding="utf-8") as fh:
return json.load(fh)
except Exception:
return {}
def answer_with_kg(
    chain,
    question: str,
    chat_history: List[Tuple[str, str]],
    persist_dir: str,
    kg_hops: int = 1,
    kg_context_max_chars: int = 1000,
) -> Any:
    """Augment the question with KG context (if available) and run the chain.

    This is a low-risk integration: we build a short textual summary from the KG
    (node labels and short chunk snippets from chunks_index.json) and prepend it
    to the question. The chain's retriever still runs; KG context is additional
    grounding.

    Args:
        chain: Conversational QA chain, called as ``chain({"question": ..., "chat_history": ...})``.
        question: The user's question.
        chat_history: Prior (question, answer) turns passed through to the chain.
        persist_dir: Directory holding ``kg_store.ttl`` and ``chunks_index.json``.
        kg_hops: Graph-traversal depth forwarded to the KG retriever.
        kg_context_max_chars: Maximum characters per chunk snippet.

    Returns:
        Whatever the chain call returns.
    """
    kg_text_parts: List[str] = []
    # Mapping of chunk id -> metadata (including raw text) written at ingest time.
    chunks_index = _load_chunks_index(persist_dir)
    if _HAS_KG:
        kg_path = os.path.join(persist_dir, "kg_store.ttl")
        try:
            kg = KGStore(path=kg_path)
            retr = KGRetriever(kg)
            chunk_ids, summaries = retr.get_context_for_question(question, hops=kg_hops)
            if summaries:
                kg_text_parts.append("KG entities: " + ", ".join(summaries))
            # Add short chunk snippets for each KG-linked chunk we know about.
            for cid in chunk_ids:
                info = chunks_index.get(cid)
                if info:
                    txt = info.get("text", "")
                    if txt:
                        # Slicing clamps past the end on its own, so the previous
                        # min(len(txt), ...) bound was redundant.
                        snippet = txt.strip().replace("\n", " ")[:kg_context_max_chars]
                        kg_text_parts.append(f"[KG chunk {cid}]: {snippet}")
        except Exception:
            # If KG load fails, skip KG augmentation entirely (best-effort feature).
            kg_text_parts = []
    kg_context = "\n\n".join(kg_text_parts) if kg_text_parts else ""
    if kg_context:
        augmented_question = f"KG CONTEXT:\n{kg_context}\n\nUser Question:\n{question}"
    else:
        augmented_question = question
    return chain({"question": augmented_question, "chat_history": chat_history})
@st.cache_resource(show_spinner=False)
def build_or_load_retriever_cached(
    data_dir: str,
    persist_dir: str,
    top_k: int,
    retrieval_mode: str,
) -> Any:
    """Load a retriever from the persisted vectorstore or build a new one.

    If loading fails — usually because the vectorstore doesn't exist yet —
    ingestion is triggered once and the load is retried.

    Args:
        data_dir: Directory containing input documents.
        persist_dir: Directory where the Chroma vectorstore is stored.
        top_k: Number of chunks to retrieve for queries.
        retrieval_mode: Retrieval strategy (mmr, similarity, hybrid).

    Returns:
        An initialized retriever instance.
    """

    def _load() -> Any:
        # Single place that constructs the retriever, shared by both attempts.
        return get_retriever(
            persist_dir=persist_dir,
            top_k=top_k,
            retrieval_mode=retrieval_mode,
        )

    try:
        return _load()
    except Exception:
        # Vectorstore likely missing or unreadable: rebuild it, then retry once.
        run_ingest_cli(data_dir=data_dir, persist_dir=persist_dir)
        return _load()
@st.cache_resource(show_spinner=False)
def get_chain_cached(
    model_name: str,
    top_k: int,
    retrieval_mode: str,
    data_dir: str,
    persist_dir: str,
) -> Any:
    """Create or load a cached conversational QA chain.

    Args:
        model_name: The OpenAI model to use (gpt-3.5-turbo, gpt-4).
        top_k: Number of chunks to retrieve.
        retrieval_mode: Retrieval mode for the retriever.
        data_dir: Path to data directory.
        persist_dir: Path to vectorstore directory.

    Returns:
        A fully configured conversational QA chain.
    """
    # The retriever itself is cached separately, so repeated calls with the
    # same settings reuse the same underlying vectorstore handle.
    return make_conversational_chain(
        build_or_load_retriever_cached(
            data_dir=data_dir,
            persist_dir=persist_dir,
            top_k=top_k,
            retrieval_mode=retrieval_mode,
        ),
        model_name=model_name,
    )
|