Spaces:
Sleeping
Sleeping
import json
import os
import subprocess
import sys
from typing import Any, Dict, List, Tuple

import streamlit as st

from src.qa_chain import make_conversational_chain
from src.vectorstore import get_retriever

# Knowledge-graph support is optional: if the KG subpackage (or any of its
# dependencies) cannot be imported, the app falls back to plain vector
# retrieval and skips KG augmentation entirely.
try:
    from src.kg.retriever import KGRetriever
    from src.kg.store import KGStore

    _HAS_KG = True
except Exception:  # broad on purpose: any import failure disables the feature
    _HAS_KG = False
def run_ingest_cli(data_dir: str, persist_dir: str) -> None:
    """Rebuild the vectorstore by running the ingestion module in a subprocess.

    Args:
        data_dir: Directory containing the raw text files.
        persist_dir: Directory where embeddings and Chroma DB should be stored.

    Raises:
        CalledProcessError: If the underlying subprocess fails.
    """
    # Use the current interpreter so the child process sees the same
    # environment/virtualenv as the running app.
    ingest_cmd = [sys.executable, "-m", "src.ingest"]
    ingest_cmd += ["--data-dir", data_dir]
    ingest_cmd += ["--persist-dir", persist_dir]
    subprocess.run(ingest_cmd, check=True)
| def _load_chunks_index(persist_dir: str) -> Dict[str, Dict]: | |
| idx_path = os.path.join(persist_dir, "chunks_index.json") | |
| if not os.path.exists(idx_path): | |
| return {} | |
| try: | |
| with open(idx_path, "r", encoding="utf-8") as fh: | |
| return json.load(fh) | |
| except Exception: | |
| return {} | |
def answer_with_kg(
    chain,
    question: str,
    chat_history: List[Tuple[str, str]],
    persist_dir: str,
    kg_hops: int = 1,
    kg_context_max_chars: int = 1000,
) -> Any:
    """Augment the question with KG context (if available) and run the chain.

    Low-risk integration: a short textual summary is built from the KG
    (entity labels plus chunk snippets looked up in ``chunks_index.json``)
    and prepended to the question. The chain's own retriever still runs;
    KG context is purely additional grounding.

    Args:
        chain: Conversational chain callable taking ``{"question", "chat_history"}``.
        question: The raw user question.
        chat_history: Prior (question, answer) pairs.
        persist_dir: Directory holding the KG store and chunks index.
        kg_hops: Number of graph hops to expand around matched entities.
        kg_context_max_chars: Cap on each chunk snippet's length.

    Returns:
        Whatever the chain returns for the (possibly augmented) question.
    """
    kg_text_parts: List[str] = []
    chunks_index = _load_chunks_index(persist_dir)
    if _HAS_KG:
        kg_path = os.path.join(persist_dir, "kg_store.ttl")
        try:
            kg = KGStore(path=kg_path)
            retr = KGRetriever(kg)
            chunk_ids, summaries = retr.get_context_for_question(question, hops=kg_hops)
            if summaries:
                kg_text_parts.append("KG entities: " + ", ".join(summaries))
            # Add short snippets of the KG-linked chunks for extra grounding.
            for cid in chunk_ids:
                info = chunks_index.get(cid)
                if not info:
                    continue
                txt = info.get("text", "")
                if not txt:
                    continue
                # Plain slicing already caps the length; the original
                # min(len(txt), ...) bound was redundant (and measured the
                # untrimmed text rather than the snippet being sliced).
                snippet = txt.strip().replace("\n", " ")[:kg_context_max_chars]
                kg_text_parts.append(f"[KG chunk {cid}]: {snippet}")
        except Exception:
            # Best-effort augmentation: a missing/corrupt KG must not
            # break question answering.
            kg_text_parts = []
    if kg_text_parts:
        kg_context = "\n\n".join(kg_text_parts)
        augmented_question = f"KG CONTEXT:\n{kg_context}\n\nUser Question:\n{question}"
    else:
        augmented_question = question
    return chain({"question": augmented_question, "chat_history": chat_history})
def build_or_load_retriever_cached(
    data_dir: str,
    persist_dir: str,
    top_k: int,
    retrieval_mode: str,
) -> Any:
    """Return a retriever backed by the persisted vectorstore, building it if needed.

    A failed load — typically because the vectorstore has never been
    created — triggers a one-off ingestion run, after which loading is
    retried (and allowed to raise if it fails again).

    Args:
        data_dir: Directory containing input documents.
        persist_dir: Directory where the Chroma vectorstore is stored.
        top_k: Number of chunks to retrieve for queries.
        retrieval_mode: Retrieval strategy (mmr, similarity, hybrid).

    Returns:
        An initialized retriever instance.
    """
    retriever_kwargs = {
        "persist_dir": persist_dir,
        "top_k": top_k,
        "retrieval_mode": retrieval_mode,
    }
    try:
        return get_retriever(**retriever_kwargs)
    except Exception:
        run_ingest_cli(data_dir=data_dir, persist_dir=persist_dir)
        return get_retriever(**retriever_kwargs)
def get_chain_cached(
    model_name: str,
    top_k: int,
    retrieval_mode: str,
    data_dir: str,
    persist_dir: str,
) -> Any:
    """Build a conversational QA chain on top of the persisted retriever.

    Args:
        model_name: The OpenAI model to use (gpt-3.5-turbo, gpt-4).
        top_k: Number of chunks to retrieve.
        retrieval_mode: Retrieval mode for the retriever.
        data_dir: Path to data directory.
        persist_dir: Path to vectorstore directory.

    Returns:
        A fully configured conversational QA chain.
    """
    # Retriever creation may trigger ingestion on first run.
    retriever = build_or_load_retriever_cached(
        data_dir=data_dir,
        persist_dir=persist_dir,
        top_k=top_k,
        retrieval_mode=retrieval_mode,
    )
    chain = make_conversational_chain(retriever, model_name=model_name)
    return chain