import json
import os
import pathlib
import pickle
import re
import time
from typing import Any, Dict, List, Tuple

from langchain.schema import Document
from langchain_community.document_loaders import (
    CSVLoader,
    NotebookLoader,
    PyPDFLoader,
    UnstructuredHTMLLoader,
    UnstructuredMarkdownLoader,
    UnstructuredPowerPointLoader,
    UnstructuredWordDocumentLoader,
)
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_mistralai import ChatMistralAI
from langchain_openai import ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter
from llama_index.core.llms import ChatMessage
from llama_index.core.memory import Memory
from pinecone import Pinecone
from pydantic import BaseModel, ValidationError

# Module-level conversation memory: generate_RAG reads history from it and
# appends each turn to it, so all calls in this process share one history.
memory = Memory(token_limit=2048)

# Number of most recent memory entries replayed to the LLMs as chat history.
_HISTORY_WINDOW = 8

# Detects at least one "Source: ..." / "Sources: ..." citation in an answer.
_SOURCE_PATTERN = re.compile(r"\bSources?:\s*.+", flags=re.IGNORECASE)

_RESOLVER_SYS = (
    "You are a controller that decides if the next answer should rely ONLY on Chat History "
    "(ignore Retrieved Context completely) or may use Retrieved Context.\n"
    "Return STRICT JSON with keys:\n"
    '{ "use_history_only": true|false, "resolved_task": "" }\n\n'
    "Rules:\n"
    "- Always set use_history_only=false (especially if the query has meaningful concepts for retrieval, e.g., specific entities, topics, product names, technical terms, factual questions).\n"
    "- Except in rare cases, do NOT set use_history_only=true. Only do so if the query contains undefined pronouns (e.g., this, that, it, they, those, these, above, continue, previous, earlier, same...).\n"
    "Examples:\n"
    'User: "Where in the onboarding guide do we define the trial limits?"\n'
    '-> { "use_history_only": false, "resolved_task": "Find where the onboarding guide defines the trial limits and report the exact limits." }\n'
)

_BASE_RULES = (
    "You are an expert assistant. Answer in English. Use:\n"
    "- Chat History\n"
    "- Retrieved Context (reference-only facts; not user intent).\n\n"
    "Decision rubric before answering:\n"
    "- Important: you MUST ALWAYS cite a source, i.e., always use exactly the filename from the 'source' metadata (e.g., 'Source: sample.pdf.' in the same paragraph as the claim).\n"
    "- If the answer is not supported by Retrieved Context and not implied by history, say you cannot answer.\n\n"
    "Important: output should be very well-structured Markdown (always different headings, hierarchical structure, bullets, tables and code blocks when needed), with a few emojis for scannability."
)

_HISTORY_ONLY_RULE = (
    "\n\nTURN-SPECIFIC RULE: For THIS turn, you MUST NOT use any Retrieved Context. "
    "Base your answer ONLY on Chat History and the user's current request."
)

_CONTEXT_PREFIX = (
    "📚 Retrieved Context (reference-only; not user intent, Use info only from here and nothing else, "
    "if info not present, say you do not know. You are only allowed to base your answer on this info "
    "and not use your own):\n\n"
)

_CITATION_RETRY_NOTE = (
    "⚠️ IMPORTANT: Your previous answer did not include any 'Source:' citation. "
    "Regenerate your answer and make sure to include at least one 'Source: ...' or 'Sources: ...' line "
    "that cites the relevant documents or context."
)


def _response_text(message: Any) -> str:
    """Return the textual content of an LLM message (or of any object) as ``str``.

    Handles plain-string ``.content`` as well as list-of-parts content (strings
    or dicts with a ``"text"`` entry), which some chat models return.
    """
    content = getattr(message, "content", message)
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                pieces.append(part["text"])
        return "".join(pieces)
    return str(content)


def _last_ai_text(messages: List[BaseMessage]) -> str:
    """Return the text of the most recent AIMessage in ``messages``, or ``""``."""
    for message in reversed(messages):
        if isinstance(message, AIMessage):
            return _response_text(message)
    return ""


def _safe_json_loads(raw: str) -> Any:
    """Parse ``raw`` as JSON, falling back to the outermost ``{...}`` span.

    The fallback tolerates prose or code fences around the JSON object.

    Raises:
        json.JSONDecodeError: if neither the whole string nor the brace span parses.
    """
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        start, end = raw.find("{"), raw.rfind("}")
        if start != -1 and end > start:
            return json.loads(raw[start:end + 1])
        raise


def _coerce_bool(value: Any) -> bool:
    """Interpret an LLM-produced flag; string forms like ``"false"`` map to False.

    A plain ``bool("false")`` would be True, so strings are compared textually.
    """
    if isinstance(value, str):
        return value.strip().lower() in {"true", "yes", "1"}
    return bool(value)


def _normalize_plan(data: Any, user_text: str) -> Dict[str, Any]:
    """Turn parsed resolver output into a plan dict with well-typed keys.

    Non-dict JSON (e.g. a list or string) is treated as an empty plan. An
    empty or non-string ``resolved_task`` falls back to ``user_text``. Extra
    keys returned by the resolver are preserved for logging.
    """
    plan = dict(data) if isinstance(data, dict) else {}
    resolved_task = plan.get("resolved_task")
    if not isinstance(resolved_task, str) or not resolved_task.strip():
        resolved_task = user_text
    plan["use_history_only"] = _coerce_bool(plan.get("use_history_only", False))
    plan["resolved_task"] = resolved_task
    return plan


def _make_non_streaming_resolver(llm_: Any) -> Any:
    """Return a non-streaming, callback-free variant of ``llm_`` for the resolver.

    Prefers a pydantic ``model_copy`` so client configuration (API key, base
    URL, limits, ...) is kept. Otherwise it falls back to rebuilding the same
    class from its model name and temperature, which works for
    ChatOpenAI-style classes accepting ``model`` or ``model_name``.
    """
    update: Dict[str, Any] = {"callbacks": None}
    if hasattr(llm_, "streaming"):
        update["streaming"] = False
    copier = getattr(llm_, "model_copy", None)
    if callable(copier):
        try:
            return copier(update=update)
        except Exception:
            # Best effort: fall through to reconstructing the model below.
            pass

    model_name = getattr(llm_, "model_name", None) or getattr(llm_, "model", None)
    kwargs: Dict[str, Any] = {}
    if hasattr(llm_, "temperature"):
        kwargs["temperature"] = llm_.temperature
    try:
        return llm_.__class__(model=model_name, streaming=False, callbacks=[], **kwargs)
    except TypeError:
        return llm_.__class__(model_name=model_name, streaming=False, callbacks=[], **kwargs)


def _resolve_turn(llm: Any, user_text: str, history_msgs: List[BaseMessage]) -> Dict[str, Any]:
    """Stage 1: decide whether this turn is history-only and rewrite the task.

    Returns a dict with at least ``use_history_only`` (bool) and
    ``resolved_task`` (str). The resolver is optional, so if building or
    calling it fails, a default plan that allows retrieval is returned.
    """
    default_plan = {"use_history_only": False, "resolved_task": user_text}

    resolver_msgs: List[BaseMessage] = [SystemMessage(content=_RESOLVER_SYS)]
    last_ai = _last_ai_text(history_msgs)
    if last_ai:
        resolver_msgs.append(AIMessage(content=f"[Last assistant answer]\n{last_ai}"))
    resolver_msgs.extend(history_msgs)
    resolver_msgs.append(HumanMessage(content=f"User message: {user_text}"))

    try:
        reply = _make_non_streaming_resolver(llm).invoke(resolver_msgs)
    except Exception as exc:
        print(f"[Resolver] Resolver call failed, using default plan: {exc}")
        return default_plan

    try:
        data = _safe_json_loads(_response_text(reply))
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return default_plan
    return _normalize_plan(data, user_text)


def _load_history_messages(mem: Any, limit: int = _HISTORY_WINDOW) -> List[BaseMessage]:
    """Convert the last ``limit`` LlamaIndex memory entries to LangChain messages.

    Only "user" and "ai"/"assistant" roles are kept. Other roles still count
    toward ``limit`` because the slice happens before filtering.
    """
    if not mem:
        return []
    converted: List[BaseMessage] = []
    for msg in mem.get_all()[-limit:]:
        text = msg.content or ""
        if msg.role == "user":
            converted.append(HumanMessage(content=text))
        elif msg.role in ("ai", "assistant"):
            converted.append(AIMessage(content=text))
    return converted


def _chunk_source_and_text(chunk: Any) -> Tuple[str, str]:
    """Extract ``(source basename, text)`` from one retrieved chunk.

    Accepts mappings with optional "source"/"text" keys, or Document-like
    objects exposing ``metadata["source"]`` and ``page_content``.
    """
    if hasattr(chunk, "get"):
        source, text = chunk.get("source"), chunk.get("text")
    else:
        metadata = getattr(chunk, "metadata", None) or {}
        source, text = metadata.get("source"), getattr(chunk, "page_content", None)
    return os.path.basename(source or "unknown"), text or ""


def _build_context(retrieved_chunks: Any, graph_context: str, graphRAG: bool) -> str:
    """Render retrieved chunks (and optional graph context) as one context string."""
    sections = [
        f"Source {index} ({source_name}):\n{text}"
        for index, (source_name, text) in enumerate(
            map(_chunk_source_and_text, retrieved_chunks or []), start=1
        )
    ]
    if graphRAG and graph_context:
        sections.append("[Graph context]\n" + graph_context)
    return "\n\n".join(sections)


def _build_prompt(
    prompt_message: str,
    resolved_task: str,
    history_messages: List[BaseMessage],
    context_for_llm: str,
    use_history_only: bool,
) -> List[BaseMessage]:
    """Assemble System rules -> (optional) context -> history -> Human (last)."""
    system_text = _BASE_RULES + (_HISTORY_ONLY_RULE if use_history_only else "")
    parts: List[BaseMessage] = [SystemMessage(content=system_text)]

    if not use_history_only and context_for_llm.strip():
        parts.append(SystemMessage(content=_CONTEXT_PREFIX + context_for_llm))

    # History comes after the retrieval context so it is the more recent input.
    if history_messages:
        parts.append(SystemMessage(content="🕘 Chat History (most recent last):"))
        parts.extend(history_messages)

    # Include both the original and the resolved request so the model can
    # resolve pronouns/references.
    parts.append(
        HumanMessage(
            content=(
                "User request (original):\n"
                f"{prompt_message}\n\n"
                "Resolved task (use this when pronouns/references appear):\n"
                f"{resolved_task}"
            )
        )
    )
    return parts


def _has_source_citation(text: str) -> bool:
    """Return True if ``text`` contains at least one "Source:"/"Sources:" line."""
    return _SOURCE_PATTERN.search(text) is not None


def _persist_turn(mem: Any, prompt_message: str, context_for_llm: str, answer_text: str) -> None:
    """Append the user message, any non-empty context, and the answer to memory."""
    if not mem:
        return
    mem.put(ChatMessage(role="user", content=prompt_message))
    # Skip empty placeholders: they would waste token budget and history slots.
    if context_for_llm.strip():
        mem.put(
            ChatMessage(
                role="assistant",
                content=f"The context was: [start context] {context_for_llm} [end context]",
            )
        )
    mem.put(ChatMessage(role="assistant", content=answer_text))


def generate_RAG(
    prompt_message, llm, retrieved_chunks, graph_context="", graphRAG=False, info=True
):
    """Answer ``prompt_message`` with retrieval-augmented generation.

    Two-stage flow:

    1. Resolver (non-streaming, no callbacks): a variant of ``llm`` decides
       whether this turn should rely only on chat history, and produces
       ``resolved_task``. If the resolver fails, retrieval stays allowed and
       the original message is used as the task.
    2. Answer: ``llm.invoke`` is called. Retrieved context is included only
       when allowed; otherwise the system prompt forbids using it. Any token
       streaming comes from how ``llm`` itself is configured.

    Message order (to favor history for follow-ups):
        System rules -> (optional) System message with Retrieved Context
        -> History -> Human (last)

    If context was supplied but the answer has no "Source:"/"Sources:" line,
    the answer is regenerated once with an extra citation reminder.

    Args:
        prompt_message: The user's current message.
        llm: LangChain chat model used for the answer (and, as a
            non-streaming variant, for the resolver).
        retrieved_chunks: Iterable of chunks (mappings with optional
            "source"/"text" keys, or Document-like objects), or None.
        graph_context: Extra graph-derived context text.
        graphRAG: Whether to append ``graph_context`` to the context.
        info: Print diagnostic output.

    Returns:
        The message returned by ``llm.invoke`` (the retry's, if one ran).
    """
    if info:
        print("Generate RAG with", prompt_message, llm)

    history_messages = _load_history_messages(memory)

    # ---------- Stage 1: resolve (non-streaming) ----------
    plan = _resolve_turn(llm, prompt_message, history_messages)
    use_history_only = plan["use_history_only"]
    resolved_task = plan["resolved_task"]
    if info:
        print("[Resolver]", plan)

    context_for_llm = (
        "" if use_history_only else _build_context(retrieved_chunks, graph_context, graphRAG)
    )
    prompt_parts = _build_prompt(
        prompt_message, resolved_task, history_messages, context_for_llm, use_history_only
    )

    # ---------- Stage 2: answer ----------
    if info:
        print(f"[Info] The final prompt is the following: {prompt_parts}")
    response = llm.invoke(prompt_parts)
    if info:
        print(f"[Info] The final response is the following: {response}")

    # Only demand citations when there was actually something to cite; with no
    # context, forcing a retry would push the model to fabricate sources.
    if context_for_llm.strip() and not _has_source_citation(_response_text(response)):
        print(
            "[Validation] Source structure check failed: "
            "Missing any 'Source:' structure in the answer."
        )
        # Insert the reminder before the final human turn so the user message
        # stays last (some providers reject a trailing system message).
        retry_parts = [
            *prompt_parts[:-1],
            SystemMessage(content=_CITATION_RETRY_NOTE),
            prompt_parts[-1],
        ]
        response = llm.invoke(retry_parts)
        print("[Retry] Regenerated answer with source emphasis.")

    # ---------- Persist to memory ----------
    _persist_turn(memory, prompt_message, context_for_llm, _response_text(response))
    if memory and info:
        print("[Info] The following is the current memory:", memory.get_all())

    return response