Spaces:
Running
Running
| """ | |
| rag_chain.py | |
| ============ | |
| Phase 6 — Full RAG Application Chain | |
| Assembles retrieval, prompt templating, and LLM generation into a single | |
| LCEL (LangChain Expression Language) chain using open-source LLMs. | |
| LLM selection (automatic; an API key is optional) | |
| ------------------------------------------------- | |
| get_llm() tries in order: | |
| 1. Google Gemini — only if the GOOGLE_API_KEY environment variable is set | |
| 2. Ollama — if a chat model is available locally | |
| Pull one first: ollama pull llama3.2:1b | |
| 3. HuggingFace — google/flan-t5-large (seq2seq, ~3 GB, runs on CPU) | |
| Architecture | |
| ------------ | |
| query ──► FinancialRetriever ──► build_context() ──► PromptTemplate | |
| ↓ | |
| Gemini / Ollama / HuggingFacePipeline | |
| ↓ | |
| StrOutputParser | |
| ↓ | |
| final answer (str) | |
| Two chain variants | |
| ------------------ | |
| build_rag_chain() | |
| Simple chain: retriever → prompt → LLM → string | |
| Input : {"query": str} | |
| Output: str | |
| build_rag_chain_with_sources() | |
| Returns answer AND source chunks for citation transparency. | |
| Input : {"query": str} | |
| Output: {"answer": str, "source_documents": list[Document], "context": str} | |
| Conversation support | |
| -------------------- | |
| build_conversational_chain() | |
| Wraps the RAG chain with sliding-window memory (last k turns). | |
| Input : {"query": str} | |
| Output: str | |
| Usage | |
| ----- | |
| from src.rag_chain import build_rag_chain, get_llm | |
| chain = build_rag_chain() | |
| answer = chain.invoke({"query": "What was Apple's revenue in FY2024?"}) | |
| """ | |
| import os | |
| import logging | |
| import subprocess | |
| from pathlib import Path | |
| from dotenv import load_dotenv | |
# Load variables from a .env file into the process environment at import time,
# so os.getenv("GOOGLE_API_KEY") in get_llm() can see a key stored there.
load_dotenv()

# ── Logging ───────────────────────────────────────────────────────────────────
# NOTE: this configures the root logger as a side effect of importing the module.
logging.basicConfig(
    level = logging.INFO,
    format = "%(asctime)s %(levelname)-8s %(message)s",
)
log = logging.getLogger(__name__)

# ── Paths ─────────────────────────────────────────────────────────────────────
# BASE_DIR is the directory two levels above this file (parent of its folder).
BASE_DIR = Path(__file__).parent.parent
# Default location passed to FinancialRetriever by the chain builders below.
VECTORSTORE_DIR = BASE_DIR / "data" / "vectorstore"

# ── Defaults ──────────────────────────────────────────────────────────────────
# NOTE(review): DEFAULT_OLLAMA_MODEL is not referenced by get_llm() in the
# visible part of this file (it picks the first listed chat model) — confirm
# whether it is used elsewhere.
DEFAULT_OLLAMA_MODEL = "llama3.2:1b" # small, fast; pull with: ollama pull llama3.2:1b
DEFAULT_HF_MODEL = "google/flan-t5-large"  # HuggingFace fallback used by get_llm()
DEFAULT_N_RESULTS = 5      # default number of chunks requested from the retriever
DEFAULT_MAX_CHARS = 6000   # default max_chars passed to retriever.build_context()
MEMORY_WINDOW_K = 5        # default number of past turns included in the prompt
| # ────────────────────────────────────────────────────────────────────────────── | |
| # PROMPT | |
| # ────────────────────────────────────────────────────────────────────────────── | |
# Single completion-style PromptTemplate shared by every backend returned by
# get_llm(). Placeholders: {context} (retrieved chunks) and {query} (user
# question). The trailing backslashes are line continuations inside the
# string, so the wrapped sentences stay on one line in the rendered prompt.
# Fix: the dash in the "Reproduce financial figures" instruction was a
# mis-encoded character ("β"); it is restored to an em dash so the model is
# not shown a stray Greek letter.
RAG_PROMPT_TEMPLATE = """\
You are a financial analyst assistant with access to Morningstar research \
reports and Apple SEC filings.
Instructions:
- Answer ONLY from the context below. Never invent facts.
- Cite every claim with [1], [2], etc. matching the source headers.
- Reproduce financial figures exactly — do not round or paraphrase.
- If the answer is not in the context, say: "The provided documents do not \
contain enough information to answer this question."
Context:
{context}
Question: {query}
Answer:"""
# Prompt used by ConversationalRAGChain. Same rules as the single-turn prompt,
# plus a {history} slot holding the formatted previous turns.
# Placeholders: {context}, {history}, {query}.
CONVERSATIONAL_PROMPT_TEMPLATE = "\n".join([
    "You are a financial analyst assistant with access to Morningstar research "
    "reports and Apple SEC filings.",
    "Instructions:",
    "- Answer ONLY from the context below. Never invent facts.",
    "- Cite every claim with [1], [2], etc. matching the source headers.",
    "- Reproduce financial figures exactly.",
    "- Use conversation history to resolve follow-up questions.",
    '- If the answer is not in the context, say: "The provided documents do not '
    'contain enough information to answer this question."',
    "Context:",
    "{context}",
    "Conversation history:",
    "{history}",
    "Question: {query}",
    "Answer:",
])
| # ────────────────────────────────────────────────────────────────────────────── | |
| # LLM FACTORY | |
| # ────────────────────────────────────────────────────────────────────────────── | |
def get_llm(model_name: str = None):
    """
    Return a LangChain LLM using the best available backend.

    Priority:
        1. Google Gemini (only if the GOOGLE_API_KEY env var is set)
        2. Ollama (if `ollama list` reports at least one non-embedding model)
           Setup: ollama pull llama3.2:1b
        3. HuggingFace (google/flan-t5-large by default; last-resort fallback)

    Every failed attempt is logged with its reason before moving on to the
    next backend, so it is visible why a given backend was (not) selected.

    Args:
        model_name : override the default model name for whichever backend
                     ends up being chosen

    Returns:
        LangChain LLM / chat model instance

    Raises:
        Whatever the HuggingFace imports or pipeline construction raise; that
        last fallback is not guarded.
    """
    # ── 1. Google Gemini (cloud) ──────────────────────────────────────────────
    gemini_key = os.getenv("GOOGLE_API_KEY")
    if gemini_key:
        try:
            from langchain_google_genai import ChatGoogleGenerativeAI
            model = model_name or "gemini-2.5-flash"
            log.info(f"LLM backend: Google Gemini model={model}")
            return ChatGoogleGenerativeAI(
                model = model,
                temperature = 0,
                google_api_key = gemini_key,
            )
        except Exception as e:
            # Best-effort: a missing package or bad config must not stop us
            # from trying the local backends.
            log.warning(f"Gemini unavailable ({e}) — falling back to Ollama")

    # ── 2. Ollama (local) ─────────────────────────────────────────────────────
    try:
        result = subprocess.run(
            ["ollama", "list"],
            capture_output=True, text=True, timeout=5
        )
        if result.returncode != 0:
            # A failed `ollama list` (e.g. daemon not running) has no usable
            # model table; do not try to parse its output.
            log.info(
                f"Ollama unavailable (`ollama list` exited with "
                f"{result.returncode}) — falling back to HuggingFace"
            )
        else:
            # First line is the table header; first column is the model name.
            # Embedding-only models cannot answer questions, so skip them.
            chat_models = [
                line.split()[0]
                for line in result.stdout.strip().splitlines()[1:]
                if line.strip() and "embed" not in line.lower()
            ]
            if chat_models:
                chosen = model_name or chat_models[0]
                from langchain_community.llms import Ollama
                log.info(f"LLM backend: Ollama model={chosen}")
                return Ollama(model=chosen)
            log.info("Ollama has no local chat model — falling back to HuggingFace")
    except Exception as e:
        # Still best-effort (binary missing, timeout, import error, ...), but
        # record the reason instead of swallowing it silently.
        log.info(f"Ollama unavailable ({e}) — falling back to HuggingFace")

    # ── 3. HuggingFace (fallback) ─────────────────────────────────────────────
    hf_model = model_name or DEFAULT_HF_MODEL
    log.info(f"LLM backend: HuggingFace model={hf_model}")
    from transformers import pipeline as hf_pipeline
    from langchain_huggingface import HuggingFacePipeline
    pipe = hf_pipeline(
        "text2text-generation",
        model = hf_model,
        max_new_tokens = 512,
    )
    return HuggingFacePipeline(pipeline=pipe)
def llm_info(llm) -> str:
    """
    Describe an LLM object as "ClassName(model)" for log output.

    The model identifier comes from `llm.model` when that attribute exists,
    otherwise from `llm.pipeline.model.name_or_path` when `llm.pipeline`
    exists. With neither attribute, only the class name is returned.
    """
    cls_name = type(llm).__name__
    if hasattr(llm, "model"):
        detail = llm.model
    elif hasattr(llm, "pipeline"):
        detail = llm.pipeline.model.name_or_path
    else:
        return cls_name
    return f"{cls_name}({detail})"
| # ────────────────────────────────────────────────────────────────────────────── | |
| # SIMPLE RAG CHAIN | |
| # ────────────────────────────────────────────────────────────────────────────── | |
def build_rag_chain(
    vectorstore_dir: Path = VECTORSTORE_DIR,
    rerank: bool = True,
    n_results: int = DEFAULT_N_RESULTS,
    max_chars: int = DEFAULT_MAX_CHARS,
    filters: dict = None,
    llm=None,
    model_name: str = None,
):
    """
    Assemble the single-turn RAG pipeline: retrieve → context → prompt → LLM → str.

    Args:
        vectorstore_dir : storage path handed to FinancialRetriever
        rerank          : forwarded to FinancialRetriever(rerank=...)
        n_results       : number of chunks requested per query
        max_chars       : max_chars passed to retriever.build_context()
        filters         : default retrieval filter; a "filters" key in the
                          invoke() input overrides it for that call
        llm             : ready-made LLM; when falsy, get_llm(model_name) is used
        model_name      : passed to get_llm() when `llm` is not supplied

    Returns:
        LCEL Runnable. Invoke with {"query": str}; the result is a str.
    """
    from langchain_core.prompts import PromptTemplate
    from langchain_core.output_parsers import StrOutputParser
    from langchain_core.runnables import RunnableLambda
    from src.retriever import FinancialRetriever

    retriever = FinancialRetriever(vectorstore_dir=vectorstore_dir, rerank=rerank)
    _llm = llm or get_llm(model_name)
    prompt = PromptTemplate.from_template(RAG_PROMPT_TEMPLATE)
    parser = StrOutputParser()
    log.info(f"RAG chain LLM: {llm_info(_llm)}")

    def retrieve_and_build_context(inputs: dict) -> dict:
        # Produces exactly the two variables the prompt template expects.
        query = inputs["query"]
        active_filters = inputs.get("filters", filters)
        chunks = retriever.retrieve(
            query,
            n_results=n_results,
            filters=active_filters,
        )
        prompt_context = retriever.build_context(chunks, max_chars=max_chars)
        log.info(f"Retrieved {len(chunks)} chunks for: {query[:60]!r}")
        return {"query": query, "context": prompt_context}

    retrieval_step = RunnableLambda(retrieve_and_build_context)
    return retrieval_step | prompt | _llm | parser
| # ────────────────────────────────────────────────────────────────────────────── | |
| # RAG CHAIN WITH SOURCES | |
| # ────────────────────────────────────────────────────────────────────────────── | |
def build_rag_chain_with_sources(
    vectorstore_dir: Path = VECTORSTORE_DIR,
    rerank: bool = True,
    n_results: int = DEFAULT_N_RESULTS,
    max_chars: int = DEFAULT_MAX_CHARS,
    filters: dict = None,
    llm=None,
    model_name: str = None,
):
    """
    Build a RAG chain that returns the answer together with its evidence.

    Takes the same arguments as build_rag_chain().

    Returns:
        LCEL Runnable. Invoke with {"query": str}; an optional "filters" key
        overrides the `filters` given here.
        Output: {
            "answer"           : str,
            "source_documents" : list[Document],  # one per retrieved chunk
            "context"          : str,             # text sent to the LLM
        }
    """
    from langchain_core.prompts import PromptTemplate
    from langchain_core.output_parsers import StrOutputParser
    from langchain_core.runnables import RunnableLambda
    from langchain_core.documents import Document
    from src.retriever import FinancialRetriever

    retriever = FinancialRetriever(vectorstore_dir=vectorstore_dir, rerank=rerank)
    _llm = llm or get_llm(model_name)
    answer_chain = (
        PromptTemplate.from_template(RAG_PROMPT_TEMPLATE)
        | _llm
        | StrOutputParser()
    )

    def run(inputs: dict) -> dict:
        question = inputs["query"]
        active_filters = inputs.get("filters", filters)
        hits = retriever.retrieve(
            question,
            n_results=n_results,
            filters=active_filters,
        )
        prompt_context = retriever.build_context(hits, max_chars=max_chars)

        # Wrap each chunk as a Document, carrying its score in the metadata.
        source_docs = []
        for hit in hits:
            source_docs.append(
                Document(
                    page_content=hit["text"],
                    metadata={**hit["metadata"], "score": hit["score"]},
                )
            )

        result = {
            "answer": answer_chain.invoke(
                {"query": question, "context": prompt_context}
            ),
            "source_documents": source_docs,
            "context": prompt_context,
        }
        return result

    return RunnableLambda(run)
| # ────────────────────────────────────────────────────────────────────────────── | |
| # CONVERSATIONAL RAG CHAIN | |
| # ────────────────────────────────────────────────────────────────────────────── | |
def build_conversational_chain(
    vectorstore_dir: Path = VECTORSTORE_DIR,
    rerank: bool = True,
    n_results: int = DEFAULT_N_RESULTS,
    max_chars: int = DEFAULT_MAX_CHARS,
    filters: dict = None,
    llm=None,
    model_name: str = None,
    memory_k: int = MEMORY_WINDOW_K,
):
    """
    Build a stateful, multi-turn RAG chain with sliding-window memory.

    Args:
        vectorstore_dir : storage path handed to FinancialRetriever
        rerank          : forwarded to FinancialRetriever(rerank=...)
        n_results       : number of chunks requested per query
        max_chars       : max_chars passed to retriever.build_context()
        filters         : default retrieval filter
        llm             : ready-made LLM; when falsy, get_llm(model_name) is used
        model_name      : passed to get_llm() when `llm` is not supplied
        memory_k        : number of most recent turns placed in the prompt

    Returns:
        ConversationalRAGChain exposing:
            .invoke({"query": str}) → str
            .clear_history()        → resets the conversation
            .history()              → list of (question, answer) tuples
    """
    from langchain_core.prompts import PromptTemplate
    from langchain_core.output_parsers import StrOutputParser
    from src.retriever import FinancialRetriever

    retriever = FinancialRetriever(vectorstore_dir=vectorstore_dir, rerank=rerank)
    _llm = llm or get_llm(model_name)

    chain_parts = {
        "retriever": retriever,
        "llm": _llm,
        "prompt": PromptTemplate.from_template(CONVERSATIONAL_PROMPT_TEMPLATE),
        "parser": StrOutputParser(),
        "n_results": n_results,
        "max_chars": max_chars,
        "filters": filters,
        "memory_k": memory_k,
    }
    return ConversationalRAGChain(**chain_parts)
class ConversationalRAGChain:
    """
    Stateful RAG chain that feeds the last `memory_k` turns back into the prompt.

    Every call to invoke() retrieves context for the query, renders the prompt
    with that context plus the recent conversation, runs the LLM, and records
    the (question, answer) pair. All turns are kept in memory; `memory_k` only
    limits how many of them are shown to the model.
    """

    def __init__(self, retriever, llm, prompt, parser, n_results,
                 max_chars, filters, memory_k):
        """
        Args:
            retriever : object providing retrieve(query, n_results=, filters=)
                        and build_context(chunks, max_chars=)
            llm       : LLM runnable placed between prompt and parser
            prompt    : prompt template with query / context / history variables
            parser    : output parser applied to the LLM result
            n_results : number of chunks to retrieve per query
            max_chars : max_chars passed to retriever.build_context()
            filters   : default retrieval filter (overridable per invoke call)
            memory_k  : number of most recent turns included in the prompt;
                        a value <= 0 disables history in the prompt
        """
        self._retriever = retriever
        self._llm = llm
        self._prompt = prompt
        self._parser = parser
        self._n_results = n_results
        self._max_chars = max_chars
        self._filters = filters
        self._memory_k = memory_k
        self._history: list[tuple[str, str]] = []

    def history(self) -> list[tuple[str, str]]:
        """Return a copy of all recorded (question, answer) turns, oldest first."""
        return list(self._history)

    def clear_history(self):
        """Forget every recorded turn."""
        self._history = []
        log.info("Conversation history cleared.")

    def _format_history(self) -> str:
        """Render the last `memory_k` turns as alternating Human/Assistant lines."""
        # A non-positive window means "no memory". Without this guard,
        # memory_k == 0 would slice with [-0:], i.e. [0:], and put the ENTIRE
        # history in the prompt; negative values would drop the oldest turns.
        if not self._history or self._memory_k <= 0:
            return "(no prior conversation)"
        lines = []
        for q, a in self._history[-self._memory_k:]:
            lines += [f"Human: {q}", f"Assistant: {a}"]
        return "\n".join(lines)

    def invoke(self, inputs: dict) -> str:
        """
        Answer one question and append the turn to the history.

        Args:
            inputs : {"query": str}; an optional "filters" key overrides the
                     filters given at construction for this call only

        Returns:
            The parsed LLM answer.
        """
        query = inputs["query"]
        f = inputs.get("filters", self._filters)
        chunks = self._retriever.retrieve(query, n_results=self._n_results, filters=f)
        context = self._retriever.build_context(chunks, max_chars=self._max_chars)
        chain = self._prompt | self._llm | self._parser
        answer = chain.invoke({
            "query": query,
            "context": context,
            "history": self._format_history(),
        })
        # Recorded only after a successful answer, so a failed call leaves
        # the history untouched.
        self._history.append((query, answer))
        return answer