""" rag_chain.py ============ Phase 6 – Full RAG Application Chain Assembles retrieval, prompt templating, and LLM generation into a single LCEL (LangChain Expression Language) chain using open-source LLMs. LLM selection (automatic, no API key required) ----------------------------------------------- get_llm() tries in order: 1. Ollama — if a chat model is available locally Pull one first: ollama pull llama3.2:1b 2. HuggingFace — google/flan-t5-large (seq2seq, ~3 GB, runs on CPU) Architecture ------------ query ──► FinancialRetriever ──► build_context() ──► PromptTemplate │ Ollama / HuggingFacePipeline │ StrOutputParser │ final answer (str) Two chain variants ------------------ build_rag_chain() Simple chain: retriever → prompt → LLM → string Input : {"query": str} Output: str build_rag_chain_with_sources() Returns answer AND source chunks for citation transparency. Input : {"query": str} Output: {"answer": str, "source_documents": list[Document], "context": str} Conversation support -------------------- build_conversational_chain() Wraps the RAG chain with sliding-window memory (last k turns). Input : {"query": str} Output: str Usage ----- from src.rag_chain import build_rag_chain, get_llm chain = build_rag_chain() answer = chain.invoke({"query": "What was Apple's revenue in FY2024?"}) """ import os import logging import subprocess from pathlib import Path from dotenv import load_dotenv load_dotenv() # ── Logging ──────────────────────────────────────────────────────────────────── logging.basicConfig( level = logging.INFO, format = "%(asctime)s %(levelname)-8s %(message)s", ) log = logging.getLogger(__name__) # ── Paths ────────────────────────────────────────────────────────────────────── BASE_DIR = Path(__file__).parent.parent VECTORSTORE_DIR = BASE_DIR / "data" / "vectorstore" # ── Defaults ─────────────────────────────────────────────────────────────────── DEFAULT_OLLAMA_MODEL = "llama3.2:1b" # small, fast; pull with: ollama pull llama3.2:1b DEFAULT_HF_MODEL = "google/flan-t5-large" DEFAULT_N_RESULTS = 5 DEFAULT_MAX_CHARS = 6000 MEMORY_WINDOW_K = 5 # ══════════════════════════════════════════════════════════════════════════════ # PROMPT # ══════════════════════════════════════════════════════════════════════════════ # Single PromptTemplate works for both Ollama and HuggingFace (completion-style) RAG_PROMPT_TEMPLATE = """\ You are a financial analyst assistant with access to Morningstar research \ reports and Apple SEC filings. Instructions: - Answer ONLY from the context below. Never invent facts. - Cite every claim with [1], [2], etc. matching the source headers. - Reproduce financial figures exactly — do not round or paraphrase. - If the answer is not in the context, say: "The provided documents do not \ contain enough information to answer this question." Context: {context} Question: {query} Answer:""" CONVERSATIONAL_PROMPT_TEMPLATE = """\ You are a financial analyst assistant with access to Morningstar research \ reports and Apple SEC filings. Instructions: - Answer ONLY from the context below. Never invent facts. - Cite every claim with [1], [2], etc. matching the source headers. - Reproduce financial figures exactly. - Use conversation history to resolve follow-up questions. - If the answer is not in the context, say: "The provided documents do not \ contain enough information to answer this question." Context: {context} Conversation history: {history} Question: {query} Answer:""" # ══════════════════════════════════════════════════════════════════════════════ # LLM FACTORY # ══════════════════════════════════════════════════════════════════════════════ def get_llm(model_name: str = None): """ Return a LangChain LLM using the best available backend. Priority: 1. Google Gemini (if GOOGLE_API_KEY env var is set) Free tier: 1M tokens/min, 1,500 req/day — get key at aistudio.google.com 2. Ollama (if a local chat model is available) Setup: ollama pull llama3.2:1b 3. HuggingFace (google/flan-t5-large, fallback, CPU-only) Args: model_name : override default model name for whichever backend is chosen Returns: LangChain LLM """ # ── 1. Google Gemini (cloud, free tier) ──────────────────────────────────── gemini_key = os.getenv("GOOGLE_API_KEY") if gemini_key: try: from langchain_google_genai import ChatGoogleGenerativeAI model = model_name or "gemini-2.5-flash" log.info(f"LLM backend: Google Gemini model={model}") return ChatGoogleGenerativeAI( model = model, temperature = 0, google_api_key = gemini_key, ) except Exception as e: log.warning(f"Gemini unavailable ({e}) — falling back to Ollama") # ── 2. Ollama (local) ────────────────────────────────────────────────────── try: result = subprocess.run( ["ollama", "list"], capture_output=True, text=True, timeout=5 ) chat_models = [ line.split()[0] for line in result.stdout.strip().splitlines()[1:] if line.strip() and "embed" not in line.lower() ] if chat_models: chosen = model_name or chat_models[0] from langchain_community.llms import Ollama log.info(f"LLM backend: Ollama model={chosen}") return Ollama(model=chosen) except Exception: pass # ── 3. HuggingFace (CPU fallback) ───────────────────────────────────────── hf_model = model_name or DEFAULT_HF_MODEL log.info(f"LLM backend: HuggingFace model={hf_model}") from transformers import pipeline as hf_pipeline from langchain_huggingface import HuggingFacePipeline pipe = hf_pipeline( "text2text-generation", model = hf_model, max_new_tokens = 512, ) return HuggingFacePipeline(pipeline=pipe) def llm_info(llm) -> str: """Human-readable description of the LLM in use.""" name = type(llm).__name__ if hasattr(llm, "model"): return f"{name}({llm.model})" if hasattr(llm, "pipeline"): return f"{name}({llm.pipeline.model.name_or_path})" return name # ══════════════════════════════════════════════════════════════════════════════ # SIMPLE RAG CHAIN # ══════════════════════════════════════════════════════════════════════════════ def build_rag_chain( vectorstore_dir : Path = VECTORSTORE_DIR, rerank : bool = True, n_results : int = DEFAULT_N_RESULTS, max_chars : int = DEFAULT_MAX_CHARS, filters : dict = None, llm = None, model_name : str = None, ): """ Build a simple RAG chain: retrieval → context → LLM → string. Args: vectorstore_dir : path to ChromaDB persistent storage rerank : enable cross-encoder reranking n_results : number of chunks to retrieve max_chars : max context length passed to LLM filters : optional ChromaDB where filter llm : pre-built LLM (skips get_llm() if provided) model_name : model to pass to get_llm() if llm not provided Returns: LCEL Runnable — invoke with {"query": str} Output: str """ from langchain_core.prompts import PromptTemplate from langchain_core.output_parsers import StrOutputParser from langchain_core.runnables import RunnableLambda from src.retriever import FinancialRetriever retriever = FinancialRetriever( vectorstore_dir = vectorstore_dir, rerank = rerank, ) _llm = llm or get_llm(model_name) prompt = PromptTemplate.from_template(RAG_PROMPT_TEMPLATE) parser = StrOutputParser() log.info(f"RAG chain LLM: {llm_info(_llm)}") def retrieve_and_build_context(inputs: dict) -> dict: query = inputs["query"] f = inputs.get("filters", filters) chunks = retriever.retrieve(query, n_results=n_results, filters=f) context = retriever.build_context(chunks, max_chars=max_chars) log.info(f"Retrieved {len(chunks)} chunks for: {query[:60]!r}") return {"query": query, "context": context} chain = ( RunnableLambda(retrieve_and_build_context) | prompt | _llm | parser ) return chain # ══════════════════════════════════════════════════════════════════════════════ # RAG CHAIN WITH SOURCES # ══════════════════════════════════════════════════════════════════════════════ def build_rag_chain_with_sources( vectorstore_dir : Path = VECTORSTORE_DIR, rerank : bool = True, n_results : int = DEFAULT_N_RESULTS, max_chars : int = DEFAULT_MAX_CHARS, filters : dict = None, llm = None, model_name : str = None, ): """ RAG chain that returns answer + source documents. Returns: LCEL Runnable — invoke with {"query": str} Output: { "answer" : str, "source_documents" : list[Document], "context" : str, } """ from langchain_core.prompts import PromptTemplate from langchain_core.output_parsers import StrOutputParser from langchain_core.runnables import RunnableLambda from langchain_core.documents import Document from src.retriever import FinancialRetriever retriever = FinancialRetriever( vectorstore_dir = vectorstore_dir, rerank = rerank, ) _llm = llm or get_llm(model_name) prompt = PromptTemplate.from_template(RAG_PROMPT_TEMPLATE) parser = StrOutputParser() answer_chain = prompt | _llm | parser def run(inputs: dict) -> dict: query = inputs["query"] f = inputs.get("filters", filters) chunks = retriever.retrieve(query, n_results=n_results, filters=f) context = retriever.build_context(chunks, max_chars=max_chars) source_docs = [ Document( page_content = c["text"], metadata = {**c["metadata"], "score": c["score"]}, ) for c in chunks ] answer = answer_chain.invoke({"query": query, "context": context}) return { "answer" : answer, "source_documents" : source_docs, "context" : context, } return RunnableLambda(run) # ══════════════════════════════════════════════════════════════════════════════ # CONVERSATIONAL RAG CHAIN # ══════════════════════════════════════════════════════════════════════════════ def build_conversational_chain( vectorstore_dir : Path = VECTORSTORE_DIR, rerank : bool = True, n_results : int = DEFAULT_N_RESULTS, max_chars : int = DEFAULT_MAX_CHARS, filters : dict = None, llm = None, model_name : str = None, memory_k : int = MEMORY_WINDOW_K, ): """ Conversational RAG chain with sliding-window memory. Returns ConversationalRAGChain with: .invoke({"query": str}) → str .clear_history() → resets conversation .history → list of (question, answer) tuples """ from langchain_core.prompts import PromptTemplate from langchain_core.output_parsers import StrOutputParser from src.retriever import FinancialRetriever retriever = FinancialRetriever( vectorstore_dir = vectorstore_dir, rerank = rerank, ) _llm = llm or get_llm(model_name) prompt = PromptTemplate.from_template(CONVERSATIONAL_PROMPT_TEMPLATE) parser = StrOutputParser() return ConversationalRAGChain( retriever = retriever, llm = _llm, prompt = prompt, parser = parser, n_results = n_results, max_chars = max_chars, filters = filters, memory_k = memory_k, ) class ConversationalRAGChain: """Stateful RAG chain remembering the last `memory_k` turns.""" def __init__(self, retriever, llm, prompt, parser, n_results, max_chars, filters, memory_k): self._retriever = retriever self._llm = llm self._prompt = prompt self._parser = parser self._n_results = n_results self._max_chars = max_chars self._filters = filters self._memory_k = memory_k self._history : list[tuple[str, str]] = [] @property def history(self) -> list[tuple[str, str]]: return list(self._history) def clear_history(self): self._history = [] log.info("Conversation history cleared.") def _format_history(self) -> str: if not self._history: return "(no prior conversation)" lines = [] for q, a in self._history[-self._memory_k:]: lines += [f"Human: {q}", f"Assistant: {a}"] return "\n".join(lines) def invoke(self, inputs: dict) -> str: query = inputs["query"] f = inputs.get("filters", self._filters) chunks = self._retriever.retrieve(query, n_results=self._n_results, filters=f) context = self._retriever.build_context(chunks, max_chars=self._max_chars) chain = self._prompt | self._llm | self._parser answer = chain.invoke({ "query" : query, "context": context, "history": self._format_history(), }) self._history.append((query, answer)) return answer