Financial_bot / src /rag_chain.py
Pushkya's picture
Upload 30 files
8299003 verified
Raw
History Blame Contribute Delete
16.5 kB
"""
rag_chain.py
============
Phase 6 – Full RAG Application Chain
Assembles retrieval, prompt templating, and LLM generation into a single
LCEL (LangChain Expression Language) chain using open-source LLMs.
LLM selection (automatic)
-------------------------
get_llm() tries in order:
1. Google Gemini — only if the GOOGLE_API_KEY environment variable is set
2. Ollama — if a chat model is available locally
   Pull one first: ollama pull llama3.2:1b
3. HuggingFace — google/flan-t5-large (seq2seq, ~3 GB, runs on CPU)
No API key is required for the Ollama and HuggingFace backends.
Architecture
------------
query ──► FinancialRetriever ──► build_context() ──► PromptTemplate
                                                          │
                                        LLM (Gemini / Ollama / HuggingFacePipeline)
                                                          │
                                                   StrOutputParser
                                                          │
                                                  final answer (str)
Two chain variants
------------------
build_rag_chain()
    Simple chain: retriever → prompt → LLM → string
Input : {"query": str}
Output: str
build_rag_chain_with_sources()
Returns answer AND source chunks for citation transparency.
Input : {"query": str}
Output: {"answer": str, "source_documents": list[Document], "context": str}
Conversation support
--------------------
build_conversational_chain()
Wraps the RAG chain with sliding-window memory (last k turns).
Input : {"query": str}
Output: str
Usage
-----
from src.rag_chain import build_rag_chain, get_llm
chain = build_rag_chain()
answer = chain.invoke({"query": "What was Apple's revenue in FY2024?"})
"""
import os
import logging
import subprocess
from pathlib import Path
from dotenv import load_dotenv
load_dotenv()
# ── Logging ────────────────────────────────────────────────────────────────────
# NOTE: this configures the *root* logger at import time; logging.basicConfig
# does nothing if the root logger already has handlers.
logging.basicConfig(format="%(asctime)s %(levelname)-8s %(message)s", level=logging.INFO)
log = logging.getLogger(__name__)

# ── Paths ──────────────────────────────────────────────────────────────────────
# Project root = parent of the directory containing this module (src/ → root).
BASE_DIR        = Path(__file__).parent.parent
VECTORSTORE_DIR = BASE_DIR.joinpath("data", "vectorstore")   # ChromaDB persistent store

# ── Defaults ───────────────────────────────────────────────────────────────────
DEFAULT_OLLAMA_MODEL = "llama3.2:1b"           # small, fast; pull with: ollama pull llama3.2:1b
DEFAULT_HF_MODEL     = "google/flan-t5-large"  # CPU fallback backend in get_llm()
DEFAULT_N_RESULTS    = 5                       # chunks retrieved per query
DEFAULT_MAX_CHARS    = 6000                    # max context characters handed to the LLM
MEMORY_WINDOW_K      = 5                       # past turns replayed by the conversational chain
# ══════════════════════════════════════════════════════════════════════════════
# PROMPT
# ══════════════════════════════════════════════════════════════════════════════
# Completion-style templates, turned into PromptTemplates by the chain builders
# and piped into whichever backend get_llm() returns.
# Placeholders: {context} — the string from FinancialRetriever.build_context();
# {query} — the user question; {history} (conversational only) — prior turns
# rendered as "Human: ... / Assistant: ..." lines.
RAG_PROMPT_TEMPLATE = """\
You are a financial analyst assistant with access to Morningstar research \
reports and Apple SEC filings.
Instructions:
- Answer ONLY from the context below. Never invent facts.
- Cite every claim with [1], [2], etc. matching the source headers.
- Reproduce financial figures exactly — do not round or paraphrase.
- If the answer is not in the context, say: "The provided documents do not \
contain enough information to answer this question."
Context:
{context}
Question: {query}
Answer:"""
CONVERSATIONAL_PROMPT_TEMPLATE = """\
You are a financial analyst assistant with access to Morningstar research \
reports and Apple SEC filings.
Instructions:
- Answer ONLY from the context below. Never invent facts.
- Cite every claim with [1], [2], etc. matching the source headers.
- Reproduce financial figures exactly.
- Use conversation history to resolve follow-up questions.
- If the answer is not in the context, say: "The provided documents do not \
contain enough information to answer this question."
Context:
{context}
Conversation history:
{history}
Question: {query}
Answer:"""
# ══════════════════════════════════════════════════════════════════════════════
# LLM FACTORY
# ══════════════════════════════════════════════════════════════════════════════
def _list_ollama_chat_models(timeout: float = 5) -> list:
    """
    Return the names of locally installed Ollama models that look like chat models.

    Parses the stdout of ``ollama list``: the first line is treated as a header
    and skipped; the first whitespace-separated column of every remaining
    non-blank line is the model name. Lines mentioning "embed" (embedding
    models) are excluded.

    Returns an empty list instead of raising when the ``ollama`` CLI is not
    installed, does not answer within ``timeout`` seconds, or exits non-zero,
    so the caller can fall through to the next backend.
    """
    try:
        result = subprocess.run(
            ["ollama", "list"],
            capture_output=True, text=True, timeout=timeout,
        )
    except (OSError, subprocess.SubprocessError) as e:
        # FileNotFoundError (CLI not installed) is an OSError;
        # TimeoutExpired is a SubprocessError.
        log.debug(f"Ollama CLI unavailable ({e})")
        return []
    if result.returncode != 0:
        log.debug(
            f"`ollama list` exited with code {result.returncode}: "
            f"{(result.stderr or '').strip()}"
        )
        return []
    return [
        line.split()[0]
        for line in result.stdout.strip().splitlines()[1:]
        if line.strip() and "embed" not in line.lower()
    ]


def get_llm(model_name: str = None):
    """
    Return a LangChain LLM using the best available backend.

    Priority:
      1. Google Gemini (if the GOOGLE_API_KEY env var is set)
         Free tier available (quotas change — check aistudio.google.com).
      2. Ollama (if a local chat model is installed)
         Setup: ollama pull llama3.2:1b
         Model: ``model_name`` if given, else DEFAULT_OLLAMA_MODEL when it is
         installed, else the first chat model reported by ``ollama list``.
      3. HuggingFace (google/flan-t5-large by default, CPU-only fallback)

    Args:
        model_name : override the default model name for whichever backend is
                     chosen. NOTE: if an earlier backend fails, the same name is
                     passed on, so e.g. a Gemini model name may reach Ollama.

    Returns:
        LangChain LLM / chat model (ChatGoogleGenerativeAI, Ollama or
        HuggingFacePipeline).

    Raises:
        Anything raised by the HuggingFace fallback (e.g. ImportError when
        ``transformers`` / ``langchain_huggingface`` are missing). It is the
        last resort and is deliberately not guarded.
    """
    # ── 1. Google Gemini (cloud, free tier) ────────────────────────────────────
    gemini_key = os.getenv("GOOGLE_API_KEY")
    if gemini_key:
        try:
            from langchain_google_genai import ChatGoogleGenerativeAI
            model = model_name or "gemini-2.5-flash"
            log.info(f"LLM backend: Google Gemini model={model}")
            return ChatGoogleGenerativeAI(
                model          = model,
                temperature    = 0,
                google_api_key = gemini_key,
            )
        except Exception as e:
            # Best-effort: any failure (missing package, bad key format, …)
            # falls through to the local backends.
            log.warning(f"Gemini unavailable ({e}) — falling back to Ollama")

    # ── 2. Ollama (local) ──────────────────────────────────────────────────────
    chat_models = _list_ollama_chat_models()
    if chat_models:
        if model_name:
            chosen = model_name
        elif DEFAULT_OLLAMA_MODEL in chat_models:
            # Prefer the documented default when the user has pulled it.
            chosen = DEFAULT_OLLAMA_MODEL
        else:
            chosen = chat_models[0]
        try:
            from langchain_community.llms import Ollama
            ollama_llm = Ollama(model=chosen)
        except Exception as e:
            log.warning(f"Ollama unavailable ({e}) — falling back to HuggingFace")
        else:
            log.info(f"LLM backend: Ollama model={chosen}")
            return ollama_llm

    # ── 3. HuggingFace (CPU fallback) ─────────────────────────────────────────
    hf_model = model_name or DEFAULT_HF_MODEL
    log.info(f"LLM backend: HuggingFace model={hf_model}")
    from transformers import pipeline as hf_pipeline
    from langchain_huggingface import HuggingFacePipeline
    pipe = hf_pipeline(
        "text2text-generation",
        model          = hf_model,
        max_new_tokens = 512,
    )
    return HuggingFacePipeline(pipeline=pipe)
def llm_info(llm) -> str:
    """
    Return a short human-readable label for ``llm``, e.g. ``"Ollama(llama3.2:1b)"``.

    The model name is looked up in two places, in order:
      * a non-None ``llm.model`` attribute (get_llm() passes ``model=`` to the
        Gemini and Ollama wrappers);
      * ``llm.pipeline.model.name_or_path`` (a HuggingFacePipeline wrapping a
        transformers pipeline).
    When neither is available, the bare class name is returned. The function
    never raises on unfamiliar LLM objects, because it is only used for logging.
    """
    name = type(llm).__name__
    model = getattr(llm, "model", None)
    if model is not None:
        return f"{name}({model})"
    pipe = getattr(llm, "pipeline", None)
    path = getattr(getattr(pipe, "model", None), "name_or_path", None)
    if path:
        return f"{name}({path})"
    return name
# ══════════════════════════════════════════════════════════════════════════════
# SIMPLE RAG CHAIN
# ══════════════════════════════════════════════════════════════════════════════
def build_rag_chain(
    vectorstore_dir : Path = VECTORSTORE_DIR,
    rerank          : bool = True,
    n_results       : int  = DEFAULT_N_RESULTS,
    max_chars       : int  = DEFAULT_MAX_CHARS,
    filters         : dict = None,
    llm                    = None,
    model_name      : str  = None,
):
    """
    Assemble the plain RAG pipeline: retrieve → build context → prompt → LLM → text.

    Args:
        vectorstore_dir : ChromaDB persistent-storage directory for the retriever
        rerank          : turn on the retriever's cross-encoder reranking
        n_results       : how many chunks to retrieve per query
        max_chars       : upper bound on the context string handed to the LLM
        filters         : default ChromaDB ``where`` filter; a ``"filters"`` key
                          in the invoke input overrides it for that call
        llm             : ready-made LLM; when falsy, get_llm(model_name) is used
        model_name      : forwarded to get_llm() when ``llm`` is not supplied

    Returns:
        LCEL Runnable. Call ``.invoke({"query": str})``; the result is the
        answer as a plain string.
    """
    from langchain_core.output_parsers import StrOutputParser
    from langchain_core.prompts import PromptTemplate
    from langchain_core.runnables import RunnableLambda
    from src.retriever import FinancialRetriever

    retriever = FinancialRetriever(vectorstore_dir=vectorstore_dir, rerank=rerank)
    model = llm or get_llm(model_name)
    log.info(f"RAG chain LLM: {llm_info(model)}")

    # Name kept stable: RunnableLambda uses the function name as its run name.
    def retrieve_and_build_context(payload: dict) -> dict:
        question = payload["query"]
        hits = retriever.retrieve(
            question,
            n_results=n_results,
            filters=payload.get("filters", filters),
        )
        context = retriever.build_context(hits, max_chars=max_chars)
        log.info(f"Retrieved {len(hits)} chunks for: {question[:60]!r}")
        return {"query": question, "context": context}

    return (
        RunnableLambda(retrieve_and_build_context)
        | PromptTemplate.from_template(RAG_PROMPT_TEMPLATE)
        | model
        | StrOutputParser()
    )
# ══════════════════════════════════════════════════════════════════════════════
# RAG CHAIN WITH SOURCES
# ══════════════════════════════════════════════════════════════════════════════
def build_rag_chain_with_sources(
    vectorstore_dir : Path = VECTORSTORE_DIR,
    rerank          : bool = True,
    n_results       : int  = DEFAULT_N_RESULTS,
    max_chars       : int  = DEFAULT_MAX_CHARS,
    filters         : dict = None,
    llm                    = None,
    model_name      : str  = None,
):
    """
    Build a RAG runnable that returns the answer together with its evidence.

    Args:
        vectorstore_dir : ChromaDB persistent-storage directory for the retriever
        rerank          : turn on the retriever's cross-encoder reranking
        n_results       : how many chunks to retrieve per query
        max_chars       : upper bound on the context string handed to the LLM
        filters         : default ChromaDB ``where`` filter; a ``"filters"`` key
                          in the invoke input overrides it for that call
        llm             : ready-made LLM; when falsy, get_llm(model_name) is used
        model_name      : forwarded to get_llm() when ``llm`` is not supplied

    Returns:
        RunnableLambda. Invoke with ``{"query": str}``. Output:
            {
                "answer"          : str,
                "source_documents": list[Document],  # one per retrieved chunk,
                                                      # metadata plus a "score" key
                "context"         : str,              # exact context given to the LLM
            }
    """
    from langchain_core.documents import Document
    from langchain_core.output_parsers import StrOutputParser
    from langchain_core.prompts import PromptTemplate
    from langchain_core.runnables import RunnableLambda
    from src.retriever import FinancialRetriever

    retriever = FinancialRetriever(vectorstore_dir=vectorstore_dir, rerank=rerank)
    model = llm or get_llm(model_name)
    generate = PromptTemplate.from_template(RAG_PROMPT_TEMPLATE) | model | StrOutputParser()

    # Name kept stable: RunnableLambda uses the function name as its run name.
    def run(payload: dict) -> dict:
        question = payload["query"]
        hits = retriever.retrieve(
            question,
            n_results=n_results,
            filters=payload.get("filters", filters),
        )
        context = retriever.build_context(hits, max_chars=max_chars)

        documents = []
        for hit in hits:
            enriched = {**hit["metadata"], "score": hit["score"]}
            documents.append(Document(page_content=hit["text"], metadata=enriched))

        reply = generate.invoke({"query": question, "context": context})
        return {
            "answer"           : reply,
            "source_documents" : documents,
            "context"          : context,
        }

    return RunnableLambda(run)
# ══════════════════════════════════════════════════════════════════════════════
# CONVERSATIONAL RAG CHAIN
# ══════════════════════════════════════════════════════════════════════════════
def build_conversational_chain(
    vectorstore_dir : Path = VECTORSTORE_DIR,
    rerank          : bool = True,
    n_results       : int  = DEFAULT_N_RESULTS,
    max_chars       : int  = DEFAULT_MAX_CHARS,
    filters         : dict = None,
    llm                    = None,
    model_name      : str  = None,
    memory_k        : int  = MEMORY_WINDOW_K,
):
    """
    Create a ConversationalRAGChain that replays recent turns into the prompt.

    Args:
        vectorstore_dir : ChromaDB persistent-storage directory for the retriever
        rerank          : turn on the retriever's cross-encoder reranking
        n_results       : how many chunks to retrieve per query
        max_chars       : upper bound on the context string handed to the LLM
        filters         : default ChromaDB ``where`` filter (overridable per call
                          via a ``"filters"`` key in the invoke input)
        llm             : ready-made LLM; when falsy, get_llm(model_name) is used
        model_name      : forwarded to get_llm() when ``llm`` is not supplied
        memory_k        : number of most recent (question, answer) turns that
                          are rendered into the prompt

    Returns:
        ConversationalRAGChain exposing
            .invoke({"query": str}) -> str
            .clear_history()        -> resets the conversation
            .history                -> list of (question, answer) tuples
    """
    from langchain_core.output_parsers import StrOutputParser
    from langchain_core.prompts import PromptTemplate
    from src.retriever import FinancialRetriever

    # Built in the same order as the constructor arguments below.
    components = dict(
        retriever = FinancialRetriever(vectorstore_dir=vectorstore_dir, rerank=rerank),
        llm       = llm or get_llm(model_name),
        prompt    = PromptTemplate.from_template(CONVERSATIONAL_PROMPT_TEMPLATE),
        parser    = StrOutputParser(),
    )
    return ConversationalRAGChain(
        n_results = n_results,
        max_chars = max_chars,
        filters   = filters,
        memory_k  = memory_k,
        **components,
    )
class ConversationalRAGChain:
    """
    Stateful RAG chain that feeds the last ``memory_k`` turns back into the prompt.

    Every ``invoke`` retrieves context for the *current* query only (history is
    not used for retrieval). It then fills the prompt with context, rendered
    history and query, runs ``prompt | llm | parser``, and appends the
    (query, answer) turn to the history. All turns are kept in memory; only the
    most recent ``memory_k`` are rendered into the prompt, and ``memory_k <= 0``
    disables history entirely.
    """

    def __init__(self, retriever, llm, prompt, parser, n_results,
                 max_chars, filters, memory_k):
        self._retriever = retriever
        self._llm       = llm
        self._prompt    = prompt
        self._parser    = parser
        self._n_results = n_results
        self._max_chars = max_chars
        self._filters   = filters   # default filter; per-call "filters" key overrides
        self._memory_k  = memory_k
        self._history : list[tuple[str, str]] = []

    @property
    def history(self) -> list[tuple[str, str]]:
        """Copy of all recorded (question, answer) turns, oldest first."""
        return list(self._history)

    def clear_history(self):
        """Forget all previous turns."""
        self._history = []
        log.info("Conversation history cleared.")

    def _format_history(self) -> str:
        """Render the last ``memory_k`` turns as "Human:/Assistant:" lines."""
        k = max(self._memory_k, 0)
        # history[-0:] would return the WHOLE list, so k == 0 must be special-cased.
        window = self._history[-k:] if k else []
        if not window:
            return "(no prior conversation)"
        lines = []
        for q, a in window:
            lines += [f"Human: {q}", f"Assistant: {a}"]
        return "\n".join(lines)

    def invoke(self, inputs: dict) -> str:
        """
        Answer ``inputs["query"]`` using retrieved context and recent history.

        An optional ``inputs["filters"]`` overrides the default retrieval filter
        for this call. The turn is recorded in the history before returning.
        """
        query = inputs["query"]
        f = inputs.get("filters", self._filters)
        chunks = self._retriever.retrieve(query, n_results=self._n_results, filters=f)
        context = self._retriever.build_context(chunks, max_chars=self._max_chars)
        chain = self._prompt | self._llm | self._parser
        answer = chain.invoke({
            "query"  : query,
            "context": context,
            "history": self._format_history(),
        })
        self._history.append((query, answer))
        return answer