File size: 11,089 Bytes
4be6b01 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 | import os
import pathlib
import time
import re
from pinecone import Pinecone
from langchain_mistralai import ChatMistralAI
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
from langchain.schema import Document
from langchain_community.document_loaders import (
CSVLoader, PyPDFLoader, UnstructuredWordDocumentLoader,
UnstructuredPowerPointLoader, UnstructuredMarkdownLoader,
UnstructuredHTMLLoader, NotebookLoader
)
from langchain_text_splitters import RecursiveCharacterTextSplitter
from llama_index.core.memory import Memory
import pickle
import json
from typing import List, Any
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
from typing import List, Any
from pydantic import BaseModel, ValidationError
# Token budget forwarded to llama_index's Memory below; its exact effect is
# defined by llama_index (presumably it caps the short-term history) — confirm.
MEMORY_TOKEN_LIMIT = 2048

# Process-wide chat memory shared by every generate_RAG call: recent turns are
# read from it before answering and each new exchange is appended afterwards.
memory = Memory(token_limit=MEMORY_TOKEN_LIMIT)
def generate_RAG(
    prompt_message,
    llm,
    retrieved_chunks,
    graph_context="",
    graphRAG=False,
    info=True
):
    """Answer ``prompt_message`` with ``llm``, optionally grounded in retrieved chunks.

    Two-stage flow (single function):
      1. Resolver (non-streaming, no callbacks): a variant of ``llm`` decides
         whether this turn should rely on chat history only and produces a
         self-contained ``resolved_task``.
      2. Answer: ``llm.invoke`` is called with the retrieved context included
         only when the resolver allowed it; otherwise the system prompt forbids
         using it.

    Message order sent to the answer LLM (history comes after the retrieved
    context so follow-ups favour the conversation):
        SystemMessage(rules)
        -> [SystemMessage(retrieved context)]            (only if allowed and non-empty)
        -> [SystemMessage(history header) + history]     (only if history exists)
        -> HumanMessage(original request + resolved task)

    If context was supplied and the answer contains no ``Source:``/``Sources:``
    line, the answer is regenerated once with an explicit correction request.
    The user turn, the context block (if any) and the final answer are then
    appended to the module-level ``memory``.

    Args:
        prompt_message: The user's current message.
        llm: LangChain chat model used for the answer; a non-streaming,
            callback-free variant of it is used for the resolver.
        retrieved_chunks: Iterable of mapping-like chunks read via
            ``.get("source")`` and ``.get("text")``; ``None`` means no chunks.
        graph_context: Extra context text, appended only when ``graphRAG`` is true.
        graphRAG: Whether to append ``graph_context`` to the retrieved context.
        info: When true, print diagnostic output.

    Returns:
        The object returned by ``llm.invoke`` (the regenerated one if the
        citation check triggered a retry).
    """
    # ---------- Helpers ----------
    def _log(*parts):
        # Single switch so ``info=False`` silences every diagnostic print.
        if info:
            print(*parts)

    def _text_of(message) -> str:
        """Return a message's ``content`` as text (``str(message)`` if it has none)."""
        content = getattr(message, "content", str(message))
        # Some providers return list-shaped content; regex/memory need a str.
        return content if isinstance(content, str) else str(content)

    def _last_ai_text(msgs: List[BaseMessage]) -> str:
        """Return the content of the most recent AIMessage in ``msgs`` ('' if none)."""
        for m in reversed(msgs):
            if isinstance(m, AIMessage):
                return m.content
        return ""

    def _safe_json_loads(raw: str):
        """Parse ``raw`` as JSON, falling back to its outermost ``{...}`` span.

        Re-raises the parse error when no such span exists or it is invalid.
        """
        try:
            return json.loads(raw)
        except (ValueError, TypeError):
            start, end = raw.find("{"), raw.rfind("}")
            if start != -1 and end > start:
                return json.loads(raw[start:end + 1])
            raise

    def _as_bool(value) -> bool:
        """Interpret a JSON flag; string forms other than true/yes/1 are False."""
        # bool("false") is True, so string flags must be parsed, not cast.
        if isinstance(value, str):
            return value.strip().lower() in ("true", "yes", "1")
        return bool(value)

    def _make_non_streaming_resolver(llm_):
        """Return a non-streaming, callback-free variant of ``llm_``.

        Prefers a pydantic ``model_copy`` so credentials, base URL and other
        settings carry over. Otherwise it builds a fresh instance of the same
        class from the model name (``model`` or ``model_name`` keyword) and the
        temperature.
        """
        model_copy = getattr(llm_, "model_copy", None)
        if callable(model_copy):
            try:
                return model_copy(update={"streaming": False, "callbacks": []})
            except TypeError:
                pass  # not a pydantic-v2 style copy; reconstruct below
        model_name = getattr(llm_, "model_name", getattr(llm_, "model", None))
        kwargs = {}
        if hasattr(llm_, "temperature"):
            kwargs["temperature"] = getattr(llm_, "temperature")
        try:
            return llm_.__class__(model=model_name, streaming=False, callbacks=[], **kwargs)
        except TypeError:
            return llm_.__class__(model_name=model_name, streaming=False, callbacks=[], **kwargs)

    def _resolver(user_text: str, history_msgs: List[BaseMessage]) -> dict:
        """Ask the resolver LLM for a routing plan, tolerating malformed output.

        Returns a dict whose ``use_history_only`` is a real bool (default False)
        and whose ``resolved_task`` is a non-empty str (default ``user_text``).
        """
        resolver_llm = _make_non_streaming_resolver(llm)
        resolver_sys = (
            "You are a controller that decides if the next answer should rely ONLY on Chat History "
            "(ignore Retrieved Context completely) or may use Retrieved Context.\n"
            "Return STRICT JSON with keys:\n"
            '{ "use_history_only": true|false, "resolved_task": "<resolved user request>" }\n\n'
            "Rules:\n"
            "- Always set set use_history_only=false (especially if the query has meaningful concepts for retrieval, e.g., specific entities, topics, product names, technical terms, factual questions).\n"
            "- Except in rare cases, do NOT set use_history_only=true. Only do so if the query contains undefined pronouns (e.g., this, that, it, they, those, these, above, continue, previous, earlier, same...).\n"
            "Examples:\n"
            'User: "Where in the onboarding guide do we define the trial limits?"\n'
            '-> { "use_history_only": false, "resolved_task": "Find where the onboarding guide defines the trial limits and report the exact limits." }\n'
        )
        resolver_msgs: List[BaseMessage] = [SystemMessage(resolver_sys)]
        last_ai = _last_ai_text(history_msgs)
        if last_ai:
            # The last answer is repeated up front so pronouns resolve against it.
            resolver_msgs.append(AIMessage(content=f"[Last assistant answer]\n{last_ai}"))
        resolver_msgs.extend(history_msgs)
        resolver_msgs.append(HumanMessage(content=f"User message: {user_text}"))
        raw = resolver_llm.invoke(resolver_msgs).content
        try:
            data = _safe_json_loads(raw)
        except Exception:
            data = None
        if not isinstance(data, dict):
            # Unparseable or non-object JSON (list/string/number): use defaults.
            data = {}
        data["use_history_only"] = _as_bool(data.get("use_history_only", False))
        resolved = data.get("resolved_task")
        if not isinstance(resolved, str) or not resolved.strip():
            data["resolved_task"] = user_text
        return data

    _log("Generate RAG with", prompt_message, llm)

    # ---------- Prepare history (last 8 stored messages) ----------
    history_messages: List[BaseMessage] = []
    if memory is not None:
        for msg in memory.get_all()[-8:]:
            content = msg.content or ""
            if msg.role == "user":
                history_messages.append(HumanMessage(content=content))
            elif msg.role in ("ai", "assistant"):
                history_messages.append(AIMessage(content=content))
            # Other roles (e.g. system/tool) are not forwarded.

    # ---------- Stage 1: Resolve (non-streaming) ----------
    plan = _resolver(prompt_message, history_messages)
    use_history_only = plan["use_history_only"]
    resolved_task = plan["resolved_task"]
    _log("[Resolver]", plan)

    # ---------- Build retrieval context block ----------
    context_lines = []
    if not use_history_only:
        for i, chunk in enumerate(retrieved_chunks or [], start=1):
            source_filename = os.path.basename(chunk.get("source") or "unknown")
            text = chunk.get("text") or ""
            context_lines.append(f"Source {i} ({source_filename}):\n{text}")
        if graphRAG and graph_context:
            context_lines.append("[Graph context]\n" + graph_context)
    context_for_llm = "\n\n".join(context_lines)
    # Empty whenever use_history_only is set, so this alone gates context use.
    has_context = bool(context_for_llm.strip())

    # ---------- System prompt (first) ----------
    base_rules = (
        "You are an expert assistant. Answer in English. Use:\n"
        "- Chat History\n"
        "- Retrieved Context (reference-only facts; not user intent).\n\n"
        "Decision rubric before answering:\n"
        "- Important: you MUST ALWAYS cite a source, i.e., always use exactly the filename from the 'source' metadata (e.g., 'Source: sample.pdf.' in the same paragraph as the claim).\n"
        "- If the answer is not supported by Retrieved Context and not implied by history, say you cannot answer.\n\n"
        "Important: output should be very well-structured Markdown (always different headings, hierarchical structure, bullets, tables and code blocks when needed), with a few emojis for scannability."
    )
    turn_rule = (
        "\n\nTURN-SPECIFIC RULE: For THIS turn, you MUST NOT use any Retrieved Context. "
        "Base your answer ONLY on Chat History and the user's current request."
        if use_history_only else ""
    )
    prompt_parts: List[BaseMessage] = [SystemMessage(content=base_rules + turn_rule)]

    # ---------- Retrieved context as a system message (only if allowed) ----------
    if has_context:
        prompt_parts.append(
            SystemMessage(
                content="📚 Retrieved Context (reference-only; not user intent, Use info only from here and nothing else, if info not present, say you do not know. You are only allowed to base your answer on this info and not use your own):\n\n" + context_for_llm
            )
        )

    # ---------- History next (more recent than retrieval context) ----------
    if history_messages:
        prompt_parts.append(SystemMessage(content="🕘 Chat History (most recent last):"))
        prompt_parts.extend(history_messages)

    # ---------- Current user last (include BOTH original and resolved) ----------
    final_human = (
        "User request (original):\n"
        f"{prompt_message}\n\n"
        "Resolved task (use this when pronouns/references appear):\n"
        f"{resolved_task}"
    )
    prompt_parts.append(HumanMessage(content=final_human))

    # ---------- Stage 2: Answer via the passed llm ----------
    _log(f"[Info] The final prompt is the following: {prompt_parts}")
    response = llm.invoke(prompt_parts)
    _log(f"[Info] The final response is the following: {response}")

    # ---------- Citation check: require a "Source:" line when context was given ----------
    # Without context there is nothing to cite; forcing a retry would only
    # encourage the model to invent a source.
    source_line = re.compile(r"\bSources?:\s*.+", flags=re.IGNORECASE)
    answer_text = _text_of(response)
    if has_context and not source_line.search(answer_text):
        _log("[Validation] Source structure check failed: Missing any 'Source:' structure in the answer.")
        # Show the model its own answer, and send the correction as a user turn
        # so the conversation still ends on a user message, which some chat
        # APIs require.
        retry_prompt_parts = prompt_parts + [
            AIMessage(content=answer_text),
            HumanMessage(
                content="⚠️ IMPORTANT: Your previous answer did not include any 'Source:' citation. "
                        "Regenerate your answer and make sure to include at least one 'Source: ...' or 'Sources: ...' line "
                        "that cites the relevant documents or context."
            ),
        ]
        response = llm.invoke(retry_prompt_parts)
        _log("[Retry] Regenerated answer with source emphasis.")

    # ---------- Persist to memory ----------
    from llama_index.core.llms import ChatMessage

    if memory is not None:
        memory.put(ChatMessage(role="user", content=prompt_message))
        if has_context:
            # Stored so later history-only follow-ups can still see what was cited.
            memory.put(ChatMessage(role="assistant", content=f"The context was: [start context] {context_for_llm} [end context]"))
        memory.put(ChatMessage(role="assistant", content=_text_of(response)))
        _log("[Info] The following is the current memory:", memory.get_all())
    return response
|