# Nadezhda Komarova
# first commit
# 4be6b01
import os
import pathlib
import time
import re
from pinecone import Pinecone
from langchain_mistralai import ChatMistralAI
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
from langchain.schema import Document
from langchain_community.document_loaders import (
CSVLoader, PyPDFLoader, UnstructuredWordDocumentLoader,
UnstructuredPowerPointLoader, UnstructuredMarkdownLoader,
UnstructuredHTMLLoader, NotebookLoader
)
from langchain_text_splitters import RecursiveCharacterTextSplitter
from llama_index.core.memory import Memory
import pickle
import json
from typing import List, Any
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
from typing import List, Any
from pydantic import BaseModel, ValidationError
# Module-level conversation memory read and written by generate_RAG().
# There is a single instance, so every caller in this process shares one history.
# NOTE(review): token_limit=2048 presumably caps how many tokens of history are kept —
# confirm against the llama_index `Memory` docs. Also verify this direct constructor call;
# llama_index examples typically build it via `Memory.from_defaults(...)`.
memory = Memory(token_limit=2048)
# ---------- Module-level helpers for generate_RAG ----------

# System prompt for the resolver (stage 1). It requests strict JSON so the reply can be
# parsed by _safe_json_loads; the wording strongly biases towards allowing retrieval.
_RESOLVER_SYS = (
    "You are a controller that decides if the next answer should rely ONLY on Chat History "
    "(ignore Retrieved Context completely) or may use Retrieved Context.\n"
    "Return STRICT JSON with keys:\n"
    '{ "use_history_only": true|false, "resolved_task": "<resolved user request>" }\n\n'
    "Rules:\n"
    "- Always set use_history_only=false (especially if the query has meaningful concepts for retrieval, e.g., specific entities, topics, product names, technical terms, factual questions).\n"
    "- Except in rare cases, do NOT set use_history_only=true. Only do so if the query contains undefined pronouns (e.g., this, that, it, they, those, these, above, continue, previous, earlier, same...).\n"
    "Examples:\n"
    'User: "Where in the onboarding guide do we define the trial limits?"\n'
    '-> { "use_history_only": false, "resolved_task": "Find where the onboarding guide defines the trial limits and report the exact limits." }\n'
)

# Matches at least one "Source:" / "Sources:" citation in an answer (case-insensitive).
_SOURCE_CITATION_RE = re.compile(r"\bSources?:\s*.+", flags=re.IGNORECASE)


def _last_ai_text(msgs):
    """Return the content of the most recent ``AIMessage`` in *msgs*, or ``""`` if none."""
    for m in reversed(msgs):
        if isinstance(m, AIMessage):
            return m.content
    return ""


def _safe_json_loads(raw: str) -> Any:
    """Parse *raw* as JSON, retrying on the outermost ``{...}`` span if that fails.

    LLM replies often wrap the JSON in prose or code fences, so when the whole
    string is not valid JSON, the substring from the first "{" to the last "}"
    is parsed instead.

    Raises:
        json.JSONDecodeError: if neither the whole string nor the span is valid JSON.
    """
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        start, end = raw.find("{"), raw.rfind("}")
        if start != -1 and end > start:
            return json.loads(raw[start:end + 1])
        raise


def _coerce_bool(value: Any) -> bool:
    """Interpret a JSON-ish flag.

    Unlike ``bool()``, the strings "false", "0" and "no" map to False. Only
    "true", "1" and "yes" (case-insensitive) count as True for string input.
    """
    if isinstance(value, str):
        return value.strip().lower() in ("true", "1", "yes")
    return bool(value)


def _response_text(message: Any) -> str:
    """Return a chat message's ``content`` as a string.

    Some providers return a list of content blocks; these are stringified so
    regex checks and memory storage always receive ``str``.
    """
    content = getattr(message, "content", None)
    if content is None:
        return str(message)
    return content if isinstance(content, str) else str(content)


def _make_non_streaming_resolver(llm_):
    """Return a non-streaming, callback-free variant of *llm_* for the resolver call.

    A pydantic copy of the configured instance is preferred, so credentials,
    base URL, timeouts, etc. carry over. Rebuilding from the class with only
    the model name and temperature silently drops them. The rebuild path is
    kept as a fallback for objects without ``model_copy``; it works for
    ChatOpenAI-style classes that accept 'model' or 'model_name'.
    """
    if hasattr(llm_, "model_copy"):
        return llm_.model_copy(update={"streaming": False, "callbacks": []})
    model_name = getattr(llm_, "model_name", getattr(llm_, "model", None))
    kwargs = {}
    if hasattr(llm_, "temperature"):
        kwargs["temperature"] = llm_.temperature
    try:
        return llm_.__class__(model=model_name, streaming=False, callbacks=[], **kwargs)
    except TypeError:
        return llm_.__class__(model_name=model_name, streaming=False, callbacks=[], **kwargs)


def _resolve_turn(llm, user_text, history_msgs) -> dict:
    """Stage 1: ask a non-streaming copy of *llm* whether this turn is history-only.

    Returns a dict whose "use_history_only" is a real bool and whose
    "resolved_task" is a non-empty string (defaulting to *user_text*). If the
    reply is not parsable JSON, or is not a JSON object, the default plan
    ``{"use_history_only": False, "resolved_task": user_text}`` is returned.
    Errors raised by the LLM call itself propagate to the caller.
    """
    fallback = {"use_history_only": False, "resolved_task": user_text}

    resolver_msgs = [SystemMessage(_RESOLVER_SYS)]
    last_ai = _last_ai_text(history_msgs)
    if last_ai:
        # Surface the previous answer explicitly so pronouns can be resolved against it.
        resolver_msgs.append(AIMessage(content=f"[Last assistant answer]\n{last_ai}"))
    resolver_msgs.extend(history_msgs)
    resolver_msgs.append(HumanMessage(content=f"User message: {user_text}"))

    raw = _response_text(_make_non_streaming_resolver(llm).invoke(resolver_msgs))
    try:
        data = _safe_json_loads(raw)
    except ValueError:  # json.JSONDecodeError is a ValueError
        return fallback
    if not isinstance(data, dict):
        return fallback

    resolved = data.get("resolved_task")
    return {
        **data,
        # bool("false") would be True, so parse string flags explicitly.
        "use_history_only": _coerce_bool(data.get("use_history_only", False)),
        "resolved_task": resolved if isinstance(resolved, str) and resolved.strip() else user_text,
    }


def generate_RAG(
    prompt_message,
    llm,
    retrieved_chunks,
    graph_context="",
    graphRAG=False,
    info=True
):
    """
    Answer *prompt_message* with a two-stage RAG flow backed by the module-level ``memory``.

    1) Resolver (non-streaming, callback-free copy of *llm*): decide whether this
       turn should rely on chat history only, and produce a ``resolved_task``.
    2) Answer (``llm.invoke`` on the passed *llm*, with its own streaming/callback
       settings): include retrieved context only if allowed; otherwise forbid it.

    Message order (to favor history for follow-ups):
    System rules -> (optional) System message with Retrieved Context ->
    System "Chat History" header + history -> Human (original + resolved request).

    Args:
        prompt_message: The user's current message.
        llm: A LangChain chat model; ``invoke`` is called on it and on a copy of it.
        retrieved_chunks: Chunks read via ``.get("source")`` and ``.get("text")``;
            may be None or empty.
        graph_context: Extra text appended to the context when *graphRAG* is true.
        graphRAG: Whether to include *graph_context* in the retrieved context.
        info: Print diagnostic output (inputs, resolver plan, prompt, response, memory).

    Returns:
        The message returned by ``llm.invoke``. If retrieved context was supplied
        but the first answer lacked a "Source:" citation, this is the answer from
        one retry.

    Side effects:
        Appends the user message, the retrieved context (unless the turn is
        history-only) and the final answer to the module-level ``memory``.
    """
    if info:
        print("Generate RAG with", prompt_message, llm)

    # ---------- Prepare history: last 8 memory messages as LangChain messages ----------
    history_messages = []
    if memory:
        for msg in memory.get_all()[-8:]:
            content = msg.content or ""  # guard against a None content
            if msg.role == "user":
                history_messages.append(HumanMessage(content=content))
            elif msg.role in ("ai", "assistant"):
                history_messages.append(AIMessage(content=content))
            # Other roles are skipped; add more mappings here if needed.

    # ---------- Stage 1: Resolve (non-streaming) ----------
    plan = _resolve_turn(llm, prompt_message, history_messages)
    use_history_only = plan["use_history_only"]
    resolved_task = plan["resolved_task"]
    if info:
        print("[Resolver]", plan)

    # ---------- Build retrieval context block ----------
    context_lines = []
    if not use_history_only:
        for i, chunk in enumerate(retrieved_chunks or [], start=1):
            source_filename = os.path.basename(chunk.get("source") or "unknown")
            text = chunk.get("text") or ""
            context_lines.append(f"Source {i} ({source_filename}):\n{text}")
        if graphRAG and graph_context:
            context_lines.append("[Graph context]\n" + graph_context)
    context_for_llm = "\n\n".join(context_lines)
    # Always False for history-only turns, because context_lines stays empty then.
    has_context = bool(context_for_llm.strip())

    # ---------- System prompt (first) ----------
    base_rules = (
        "You are an expert assistant. Answer in English. Use:\n"
        "- Chat History\n"
        "- Retrieved Context (reference-only facts; not user intent).\n\n"
        "Decision rubric before answering:\n"
        "- Important: you MUST ALWAYS cite a source, i.e., always use exactly the filename from the 'source' metadata (e.g., 'Source: sample.pdf.' in the same paragraph as the claim).\n"
        "- If the answer is not supported by Retrieved Context and not implied by history, say you cannot answer.\n\n"
        "Important: output should be very well-structured Markdown (always different headings, hierarchical structure, bullets, tables and code blocks when needed), with a few emojis for scannability."
    )
    turn_rule = (
        "\n\nTURN-SPECIFIC RULE: For THIS turn, you MUST NOT use any Retrieved Context. "
        "Base your answer ONLY on Chat History and the user's current request."
        if use_history_only else ""
    )
    prompt_parts = [SystemMessage(content=base_rules + turn_rule)]

    # ---------- Retrieved context as a system message (only if allowed and non-empty) ----------
    if has_context:
        prompt_parts.append(
            SystemMessage(
                content="📚 Retrieved Context (reference-only; not user intent, Use info only from here and nothing else, if info not present, say you do not know. You are only allowed to base your answer on this info and not use your own):\n\n" + context_for_llm
            )
        )

    # ---------- History next (more recent than retrieval context) ----------
    if history_messages:
        prompt_parts.append(SystemMessage(content="🕘 Chat History (most recent last):"))
        prompt_parts.extend(history_messages)

    # ---------- Current user last (include BOTH original and resolved) ----------
    final_human = (
        "User request (original):\n"
        f"{prompt_message}\n\n"
        "Resolved task (use this when pronouns/references appear):\n"
        f"{resolved_task}"
    )
    prompt_parts.append(HumanMessage(content=final_human))

    # ---------- Stage 2: Answer via the passed llm ----------
    if info:
        print(f"[Info] The final prompt is the following: {prompt_parts}")
    response = llm.invoke(prompt_parts)
    if info:
        print(f"[Info] The final response is the following: {response}")

    # ---------- Citation check: only meaningful when there was context to cite ----------
    if has_context and not _SOURCE_CITATION_RE.search(_response_text(response)):
        print("[Validation] Source structure check failed: Missing any 'Source:' structure in the answer.")
        # Retry once with stronger emphasis on sources.
        retry_prompt_parts = prompt_parts + [SystemMessage(
            content="⚠️ IMPORTANT: Your previous answer did not include any 'Source:' citation. "
                    "Regenerate your answer and make sure to include at least one 'Source: ...' or 'Sources: ...' line "
                    "that cites the relevant documents or context."
        )]
        response = llm.invoke(retry_prompt_parts)
        print("[Retry] Regenerated answer with source emphasis.")

    # ---------- Persist to memory ----------
    from llama_index.core.llms import ChatMessage

    if memory:
        memory.put(ChatMessage(role="user", content=prompt_message))
        if not use_history_only:
            # Store the context as an assistant message so follow-ups can refer to it.
            memory.put(ChatMessage(role="assistant", content=f"The context was: [start context] {context_for_llm} [end context]"))
        memory.put(ChatMessage(role="assistant", content=_response_text(response)))
        if info:
            print("[Info] The following is the current memory:", memory.get_all())
    return response