import json
import os
import subprocess
import sys
from typing import Any, Dict, List, Literal, Tuple, cast

import streamlit as st

from src.vectorstore import get_retriever
from src.qa_chain import make_conversational_chain

# Import KG modules unconditionally; let import errors propagate so failures are visible.
from src.kg.store import KGStore
from src.kg.retriever import KGRetriever


def run_ingest_cli(data_dir: str, persist_dir: str) -> str:
    """Run the ingestion module to rebuild the vectorstore.

    Runs the ingest CLI as a subprocess and returns stdout on success.
    On failure raises subprocess.CalledProcessError with captured stdout/stderr so callers
    (for example the Streamlit UI) can display a helpful error message.
    """
    cmd = [
        sys.executable,
        "-m",
        "src.ingest",
        "--data-dir",
        data_dir,
        "--persist-dir",
        persist_dir,
    ]
    try:
        # Add a timeout to avoid indefinite hanging; 600s (10 minutes) is generous for large ingests
        completed = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
    except subprocess.TimeoutExpired as te:
        # Re-raise as CalledProcessError with any partial output; 124 is the
        # conventional exit status for a timed-out command.
        raise subprocess.CalledProcessError(
            returncode=124,
            cmd=cmd,
            output=getattr(te, 'output', '') or '',
            stderr=f"Ingest process timed out after {te.timeout} seconds",
        )

    # Check return code and raise with captured output on failure
    if completed.returncode != 0:
        # Raise with captured output to make it easy to present to the user
        raise subprocess.CalledProcessError(
            returncode=completed.returncode,
            cmd=cmd,
            output=completed.stdout,
            stderr=completed.stderr,
        )
    return completed.stdout

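# Usage sketch (assumed paths, not the app's exact wiring; the Streamlit UI
# can surface the captured output on failure, since CalledProcessError
# carries stdout/stderr):
#
#     try:
#         log = run_ingest_cli(data_dir="data", persist_dir="chroma_db")
#         st.success("Ingest complete")
#         st.code(log)
#     except subprocess.CalledProcessError as exc:
#         st.error(f"Ingest failed (exit {exc.returncode}): {exc.stderr}")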

def _load_chunks_index(persist_dir: str) -> Dict[str, Dict]:
    """Load the chunk-id -> metadata mapping from chunks_index.json.

    Returns an empty dict if the file is missing or cannot be parsed.
    """
    idx_path = os.path.join(persist_dir, "chunks_index.json")
    if not os.path.exists(idx_path):
        return {}
    try:
        with open(idx_path, "r", encoding="utf-8") as fh:
            return json.load(fh)
    except Exception:
        return {}

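# Assumed shape of chunks_index.json (inferred from how answer_with_kg reads
# it: chunk id -> metadata dict with a "text" field):
#
#     {
#         "chunk-0001": {"text": "First chunk of the source document ..."},
#         "chunk-0002": {"text": "Second chunk ..."}
#     }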

def answer_with_kg(
        chain,
        question: str,
        chat_history: List[Tuple[str, str]],
        persist_dir: str,
        kg_hops: int = 1,
        kg_context_max_chars: int = 1000,
) -> Any:
    """Augment question with KG context (if available) and run the chain.

    This is a low-risk integration: we build a short textual summary from the KG
    (node labels and short chunk snippets from chunks_index.json) and prepend it to
    the question. The chain's retriever still runs; KG context is additional grounding.
    """
    kg_text_parts: List[str] = []
    # Load chunks index mapping
    chunks_index = _load_chunks_index(persist_dir)

    # Query the KG for related chunks and entity summaries; failures here are
    # caught below so the app can still answer without KG context.
    kg_path = os.path.join(persist_dir, "kg_store.ttl")
    try:
        kg = KGStore(path=kg_path)
        retr = KGRetriever(kg)
        chunk_ids, summaries = retr.get_context_for_question(question, hops=kg_hops)
        if summaries:
            kg_text_parts.append("KG entities: " + ", ".join(summaries))
        # add chunk snippets
        for cid in chunk_ids:
            info = chunks_index.get(cid)
            if info:
                txt = info.get("text", "")
                if txt:
                    snippet = txt.strip().replace("\n", " ")[:kg_context_max_chars]
                    kg_text_parts.append(f"[KG chunk {cid}]: {snippet}")
    except Exception:
        # If the KG store is missing or the query fails, skip KG augmentation
        # and fall back to answering from the vectorstore alone.
        kg_text_parts = []

    kg_context = "\n\n".join(kg_text_parts) if kg_text_parts else ""
    if kg_context:
        augmented_question = f"KG CONTEXT:\n{kg_context}\n\nUser Question:\n{question}"
    else:
        augmented_question = question

    return chain({"question": augmented_question, "chat_history": chat_history})

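# Illustrative call (a sketch, not the app's exact wiring; assumes a
# LangChain-style chain whose result dict exposes an "answer" key):
#
#     result = answer_with_kg(
#         chain,
#         "What does the design doc say about caching?",
#         chat_history=[],
#         persist_dir="chroma_db",
#     )
#     print(result["answer"])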

@st.cache_resource(show_spinner=False)
def build_or_load_retriever_cached(
        data_dir: str,
        persist_dir: str,
        top_k: int,
        retrieval_mode: str,
) -> Any:
    """Load a retriever from the persisted vectorstore or build a new one.

    If loading fails (usually because the vectorstore doesn't exist yet), this
    function triggers ingestion and retries loading.

    Args:
        data_dir: Directory containing input documents.
        persist_dir: Directory where the Chroma vectorstore is stored.
        top_k: Number of chunks to retrieve.
        retrieval_mode: Retrieval strategy (mmr, similarity, hybrid).

    Returns:
        An initialized retriever instance.
    """
    # Cast retrieval_mode to the expected literal type to satisfy type checkers.
    mode = cast(Literal["mmr", "similarity", "hybrid"], retrieval_mode)
    try:
        return get_retriever(
            persist_dir=persist_dir,
            top_k=top_k,
            retrieval_mode=mode,
        )
    except Exception:
        # Loading failed, most likely because the vectorstore has not been
        # built yet: run ingestion once, then retry the load.
        run_ingest_cli(data_dir=data_dir, persist_dir=persist_dir)
        return get_retriever(
            persist_dir=persist_dir,
            top_k=top_k,
            retrieval_mode=mode,
        )


@st.cache_resource(show_spinner=False)
def get_chain_cached(
        model_name: str,
        top_k: int,
        retrieval_mode: str,
        data_dir: str,
        persist_dir: str,
) -> Any:
    """Create or load a cached conversational QA chain.

    Args:
        model_name: The OpenAI model to use (e.g. gpt-3.5-turbo or gpt-4).
        top_k: Number of chunks to retrieve.
        retrieval_mode: Retrieval mode for the retriever.
        data_dir: Path to data directory.
        persist_dir: Path to vectorstore directory.

    Returns:
        A fully configured conversational QA chain.
    """
    retriever = build_or_load_retriever_cached(
        data_dir=data_dir,
        persist_dir=persist_dir,
        top_k=top_k,
        retrieval_mode=retrieval_mode,
    )
    return make_conversational_chain(retriever, model_name=model_name)
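

# End-to-end wiring sketch (hypothetical widget labels and defaults; the real
# page layout may differ):
#
#     st.title("Document Q&A")
#     model_name = st.sidebar.selectbox("Model", ["gpt-3.5-turbo", "gpt-4"])
#     top_k = st.sidebar.slider("Top K", 1, 10, 4)
#     mode = st.sidebar.selectbox("Mode", ["mmr", "similarity", "hybrid"])
#
#     chain = get_chain_cached(model_name, top_k, mode, "data", "chroma_db")
#     question = st.chat_input("Ask a question")
#     if question:
#         history = st.session_state.setdefault("history", [])
#         result = answer_with_kg(chain, question, history, "chroma_db")
#         history.append((question, result["answer"]))
#         st.write(result["answer"])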