import os
import json
import re
from typing import Any, List

from pydantic import BaseModel, ValidationError

from langchain_openai import ChatOpenAI
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage

from llama_index.core.llms import ChatMessage
from llama_index.core.memory import Memory


# Module-level conversation memory shared by all generate_RAG calls;
# token_limit caps how much history is retained.
memory = Memory(token_limit=2048)


def generate_RAG(
    prompt_message,
    llm,
    retrieved_chunks,
    graph_context="",
    graphRAG=False,
    info=True
):
    """
    Two-stage flow (single function):
      1) Resolver (non-streaming, no callbacks): decide if this turn should be history-only. Produce resolved_task.
      2) Answer (streaming via the passed llm): include retrieved context only if allowed; otherwise forbid it.

    Message order (to favor history for follow-ups):
      System (first) -> (Optional) AIMessage with Retrieved Context -> History -> Human (last)
    """

    if info:
        print("Generate RAG with", prompt_message, llm)

    # ---------- Helpers ----------
    def _to_list_messages(history: Any) -> List[BaseMessage]:
        """Normalizes memory history: supports list[BaseMessage] or a summary string."""
        if isinstance(history, list):
            return history
        if isinstance(history, str) and history.strip():
            return [AIMessage(content=f"[Conversation summary]\n{history.strip()}")]
        return []

    def _last_ai_text(msgs: List[BaseMessage]) -> str:
        for m in reversed(msgs):
            if isinstance(m, AIMessage):
                return m.content
        return ""

    def _safe_json_loads(raw: str) -> dict:
        try:
            return json.loads(raw)
        except Exception:
            start, end = raw.find("{"), raw.rfind("}")
            if start != -1 and end != -1 and end > start:
                return json.loads(raw[start:end+1])
            raise
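
    # Example of the fallback behavior (illustrative input/output):
    #   _safe_json_loads('Sure! {"use_history_only": false, "resolved_task": "x"} Done.')
    #   -> {"use_history_only": False, "resolved_task": "x"}
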
    def _make_non_streaming_resolver(llm_):
        """
        Create a non-streaming, callback-free copy of the same LLM class for the resolver step.
        Works for ChatOpenAI-style classes that accept 'model' or 'model_name'.
        """
        model_name = getattr(llm_, "model_name", getattr(llm_, "model", None))
        kwargs = {}
        if hasattr(llm_, "temperature"):
            kwargs["temperature"] = getattr(llm_, "temperature")
        try:
            return llm_.__class__(model=model_name, streaming=False, callbacks=[], **kwargs)
        except TypeError:
            return llm_.__class__(model_name=model_name, streaming=False, callbacks=[], **kwargs)

    def _resolver(user_text: str, history_msgs: List[BaseMessage]) -> dict:
        resolver_llm = _make_non_streaming_resolver(llm)

        RESOLVER_SYS = (
            "You are a controller that decides if the next answer should rely ONLY on Chat History "
            "(ignore Retrieved Context completely) or may use Retrieved Context.\n"
            "Return STRICT JSON with keys:\n"
            '{ "use_history_only": true|false, "resolved_task": "<resolved user request>" }\n\n'
            "Rules:\n"
            "- Always set set use_history_only=false (especially if the query has meaningful concepts for retrieval, e.g., specific entities, topics, product names, technical terms, factual questions).\n"
            "- Except in rare cases, do NOT set use_history_only=true. Only do so if the query contains undefined pronouns (e.g., this, that, it, they, those, these, above, continue, previous, earlier, same...).\n"
            "Examples:\n"
            'User: "Where in the onboarding guide do we define the trial limits?"\n'
            '-> { "use_history_only": false, "resolved_task": "Find where the onboarding guide defines the trial limits and report the exact limits." }\n'
        )

        resolver_msgs: List[BaseMessage] = [SystemMessage(RESOLVER_SYS)]
        last_ai = _last_ai_text(history_msgs)
        if last_ai:
            resolver_msgs.append(AIMessage(content=f"[Last assistant answer]\n{last_ai}"))
        resolver_msgs.extend(history_msgs)
        resolver_msgs.append(HumanMessage(content=f"User message: {user_text}"))

        raw = resolver_llm.invoke(resolver_msgs).content
        try:
            data = _safe_json_loads(raw)
        except Exception:
            data = {"use_history_only": False, "resolved_task": user_text}

        data.setdefault("use_history_only", False)
        data.setdefault("resolved_task", user_text)
        return data


    # ---------- Prepare history ----------
    history_messages: List[BaseMessage] = []
    if memory:
        # Get the last messages from LlamaIndex memory
        last_msgs = memory.get_all()[-8:]

        # Convert LlamaIndex messages to LangChain message types
        for msg in last_msgs:
            if msg.role == "user":
                history_messages.append(HumanMessage(content=msg.content))
            elif msg.role in ("ai", "assistant"):
                history_messages.append(AIMessage(content=msg.content))
            # Add more roles if needed

    # ---------- Stage 1: Resolve (non-streaming) ----------
    plan = _resolver(prompt_message, history_messages)

    use_history_only = bool(plan.get("use_history_only", False))
    resolved_task = plan.get("resolved_task", prompt_message)

    if info:
        print("[Resolver]", plan)


    # ---------- Build retrieval context block ----------
    context_lines = []
    if not use_history_only:
        for i, chunk in enumerate(retrieved_chunks or []):
            source_filename = os.path.basename((chunk.get("source") or "unknown"))
            text = chunk.get("text") or ""
            context_lines.append(f"Source {i+1} ({source_filename}):\n{text}")
        
        if graphRAG and graph_context:
            context_lines.append("[Graph context]\n" + graph_context)

    context_for_llm = "\n\n".join(context_lines)

    # ---------- System prompt (first) ----------
    base_rules = (
        "You are an expert assistant. Answer in English. Use:\n"
        "- Chat History\n"
        "- Retrieved Context (reference-only facts; not user intent).\n\n"
        "Decision rubric before answering:\n"
        "- Important: you MUST ALWAYS cite a source, i.e., always use exactly the filename from the 'source' metadata (e.g., 'Source: sample.pdf.' in the same paragraph as the claim).\n"
        "- If the answer is not supported by Retrieved Context and not implied by history, say you cannot answer.\n\n"
        "Important: output should be very well-structured Markdown (always different headings, hierarchical structure, bullets, tables and code blocks when needed), with a few emojis for scannability."
    )
    turn_rule = (
        "\n\nTURN-SPECIFIC RULE: For THIS turn, you MUST NOT use any Retrieved Context. "
        "Base your answer ONLY on Chat History and the user's current request."
        if use_history_only else ""
    )

    prompt_parts: List[BaseMessage] = [SystemMessage(content=base_rules + turn_rule)]

    # ---------- Retrieved context as a second system message (only if allowed) ----------
    if (not use_history_only) and context_for_llm.strip():
        prompt_parts.append(
            SystemMessage(
                content=(
                    "📚 Retrieved Context (reference-only; not user intent). "
                    "Use information from here and nowhere else; if the information is not present, "
                    "say you do not know. You must base your answer solely on this context:\n\n"
                    + context_for_llm
                )
            )
        )

    # ---------- History next (more recent than retrieval context) ----------
    if history_messages:
        prompt_parts.append(SystemMessage(content="🕘 Chat History (most recent last):"))
        prompt_parts.extend(history_messages)

    # ---------- Current user last (include BOTH original and resolved) ----------
    final_human = (
        "User request (original):\n"
        f"{prompt_message}\n\n"
        "Resolved task (use this when pronouns/references appear):\n"
        f"{resolved_task}"
    )
    prompt_parts.append(HumanMessage(content=final_human))

    # ---------- Stage 2: Answer (via the passed llm) ----------
    if info:
        print(f"[Info] The final prompt is the following: {prompt_parts}")
    response = llm.invoke(prompt_parts)
    if info:
        print(f"[Info] The final response is the following: {response}")

    # ---------- Validation: ensure some "Source:" structure is present ----------
    class _AnswerWithCitationStructure(BaseModel):
        content: str

        @classmethod
        def ensure_source_structure(cls, content: str):
            """Check that at least one 'Source:' or 'Sources:' pattern appears in the text."""
            # Relies on the module-level `import re`.
            if not re.search(r"\bSources?:\s*.+", content, flags=re.IGNORECASE):
                raise ValueError("Missing any 'Source:' structure in the answer.")


    # Run validation only when we expected citations (retrieval was allowed)
    try:
        if not use_history_only:
            _AnswerWithCitationStructure.ensure_source_structure(
                getattr(response, "content", str(response))
            )
    except (ValidationError, ValueError) as ve:
        print(f"[Validation] Source structure check failed: {ve}")

        # Retry answer generation with stronger emphasis on sources
        retry_prompt_parts = prompt_parts.copy()
        retry_prompt_parts.append(SystemMessage(
            content="⚠️ IMPORTANT: Your previous answer did not include any 'Source:' citation. "
                    "Regenerate your answer and make sure to include at least one 'Source: ...' or 'Sources: ...' line "
                    "that cites the relevant documents or context."
        ))
        response = llm.invoke(retry_prompt_parts)
        print("[Retry] Regenerated answer with source emphasis.")







    # ---------- Persist to memory ----------
    
    from llama_index.core.llms import ChatMessage

    # ---------- Persist to memory ----------
    if memory:
        # Add user message
        memory.put(ChatMessage(role="user", content=prompt_message))

        if not use_history_only:
            # Add context as AI message
            memory.put(ChatMessage(role="assistant", content=f"The context was: [start context] {context_for_llm} [end context]"))

        # Add final AI response
        memory.put(ChatMessage(role="assistant", content=getattr(response, "content", str(response))))

        # Debug: dump the current memory contents
        if info:
            print("[Info] The following is the current memory:", memory.get_all())


    return response
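

if __name__ == "__main__":
    # Minimal usage sketch, not part of the pipeline. The model name, the API-key
    # handling, and the chunk contents below are assumptions for illustration only.
    demo_llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)  # assumes OPENAI_API_KEY is set
    demo_chunks = [
        {
            "source": "docs/onboarding_guide.pdf",  # hypothetical file
            "text": "Trial accounts are limited to 3 projects and a 14-day window.",  # illustrative
        },
    ]
    answer = generate_RAG("What are the trial limits?", demo_llm, demo_chunks)
    print(getattr(answer, "content", answer))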