Financial_RAG / app.py
amaherovskyi's picture
Update app.py
fdcc774 verified
import gradio as gr
import logging
from typing import Tuple, Optional, List, Dict, Any
from data_loader import load_finqa_dataset
from vector_store import init_model, build_embeddings, build_faiss_index, build_bm25
from pipeline import init_reranker, init_groq, rag_pipeline
# Logging Configuration
# Application-wide logger. logging.getLogger returns the same logger object
# for the same name on every call, so this configuration must be idempotent.
logger = logging.getLogger("rag_app")
logger.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
# Attach the console handler only once: if this module is executed again
# (e.g. importlib.reload or an auto-reloading dev server) an unconditional
# addHandler would stack handlers and print every message multiple times.
if not logger.handlers:
    logger.addHandler(console_handler)
# Global cached resources
# Retrieval configuration.
FINQA_SPLIT: str = "train"      # split passed to load_finqa_dataset
DEFAULT_RETRIEVE_K: int = 30    # retrieve_k passed to rag_pipeline
DEFAULT_RERANK_K: int = 5       # rerank_k passed to rag_pipeline

# Process-wide cache populated by init_once(); every slot stays None (and the
# flag stays False) until the first successful initialization.
_resources_initialized: bool = False
_documents: Optional[List[Dict[str, Any]]] = None
_model: Any = None
_faiss_index: Any = None
_bm25_index: Any = None
_reranker_model: Any = None
# Lazy Initialization
def init_once() -> Tuple[
    List[Dict[str, Any]], Any, Any, Any, Any
]:
    """Build the shared retrieval resources on first use and return them.

    The first call loads the FinQA documents, creates the embedding model,
    builds the embeddings, the FAISS index and the BM25 index, and loads the
    reranker, storing each in a module-level global. Later calls skip all of
    that work and hand back the cached objects.

    Returns:
        A tuple ``(documents, model, faiss_index, bm25_index, reranker_model)``.
    """
    global _resources_initialized, _documents, _model, _faiss_index, _bm25_index, _reranker_model

    if not _resources_initialized:
        logger.info("Loading FinQA dataset…")
        _documents = load_finqa_dataset(split=FINQA_SPLIT)
        logger.info(f"Loaded {len(_documents)} documents.")

        logger.info("Initializing embedding model…")
        _model = init_model()

        logger.info("Building embeddings (loading or creating)…")
        doc_embeddings = build_embeddings(_documents, _model)

        logger.info("Initializing FAISS index…")
        _faiss_index = build_faiss_index(doc_embeddings)

        logger.info("Building BM25 index…")
        _bm25_index = build_bm25(_documents)

        logger.info("Loading cross-encoder reranker…")
        _reranker_model = init_reranker()

        # Flip the flag last so a failure above leaves the cache marked as
        # uninitialized and the next call retries from scratch.
        _resources_initialized = True

    cached = (_documents, _model, _faiss_index, _bm25_index, _reranker_model)
    return cached
# Main RAG Handler
def answer_question(
    llm_key: str,
    query: str,
    search_mode: str
) -> Tuple[str, str]:
    """Answer a question with the RAG pipeline and preview the used context.

    Args:
        llm_key: Groq API key entered by the user.
        query: The user's question.
        search_mode: Retrieval mode forwarded to ``rag_pipeline`` as
            ``retrieve_mode``.

    Returns:
        ``(answer, docs_preview)``. On invalid input or when the Groq client
        cannot be created, ``answer`` is an explanatory message and
        ``docs_preview`` is an empty string.
    """
    # Treat a missing value (None) like an empty string instead of crashing
    # on ``None.strip()``.
    if not (query or "").strip():
        return "Please enter a question.", ""
    if not (llm_key or "").strip():
        return "Please provide a valid GROQ API key.", ""

    # load cached resources
    documents, model, faiss_idx, bm25_idx, reranker = init_once()

    # verify API key
    try:
        # Surrounding whitespace from copy/paste is never part of a key.
        llm_client = init_groq(api_key=llm_key.strip())
    except Exception as exc:
        # Log only the exception type: the message could echo the secret key.
        logger.warning(
            "Groq client initialization failed: %s", type(exc).__name__
        )
        return "Invalid GROQ API key.", ""

    # run RAG pipeline
    result = rag_pipeline(
        query=query,
        reranker_model=reranker,
        llm_client=llm_client,
        documents=documents,
        model=model,
        faiss_index=faiss_idx,
        bm25_index=bm25_idx,
        retrieve_mode=search_mode,
        retrieve_k=DEFAULT_RETRIEVE_K,
        rerank_k=DEFAULT_RERANK_K
    )
    reranked_docs = result["reranked_docs"]

    # logs
    logger.info("=== RERANKED CONTEXT ===")
    for d in reranked_docs:
        logger.info(f"ID: {d['id']} | score={d['rerank_score']:.4f}")

    return result["answer"], _build_docs_preview(reranked_docs)


def _build_docs_preview(docs: List[Dict[str, Any]], max_chars: int = 500) -> str:
    """Render documents as a markdown preview, cutting texts at ``max_chars``."""
    sections = []
    for doc in docs:
        text = doc["text"]
        # Append the ellipsis only when text was actually cut off; otherwise
        # the preview would wrongly suggest that more content exists.
        snippet = f"{text[:max_chars]}..." if len(text) > max_chars else text
        sections.append(f"### Document ID: {doc['id']}\n{snippet}\n\n")
    return "".join(sections)
# Gradio UI
# Declarative UI definition. ``demo`` is kept as a module-level name so
# external tooling can import it.
with gr.Blocks(title="Financial RAG QA System") as demo:
    # Introductory description shown at the top of the page.
    gr.Markdown("""
# 📘 Financial RAG Question Answering System
This app answers financial questions using the **FinQA dataset**.
### 🔧 How it works
1. Retrieves relevant passages using BM25 / Embeddings / Hybrid
2. Reranks them using a Cross-Encoder
3. Sends only the best documents to the LLM (Groq Llama 3.3)
4. Produces an accurate, grounded answer
### 🧠 Best for:
- Financial statement analysis
- Annual report question answering
- Extracting metrics (revenue, cash flow, expenses)
- Multi-step reasoning on financial texts
""")

    # Inputs: API key (masked), the question, and the retrieval strategy.
    llm_key = gr.Textbox(
        label="GROQ API Key",
        placeholder="Enter your Groq API key",
        type="password"
    )
    query = gr.Textbox(
        label="Your Question",
        placeholder="Example: *What was total revenue in 2019 for Apple?*",
        lines=2
    )
    search_mode = gr.Radio(
        choices=["hybrid", "bm25", "semantic"],
        value="hybrid",
        label="Retrieval Mode"
    )
    gr.Markdown("""
**Hybrid** → Best overall (semantic + keyword)
**BM25** → Best for exact numeric queries (years, dollars, names)
**Semantic** → Best for conceptual queries
""")

    run_button = gr.Button("Ask")

    # Outputs: the generated answer and a preview of the reranked documents.
    answer_box = gr.Markdown()
    docs_box = gr.Markdown()

    # Wire the button to the RAG handler; the order of inputs/outputs matches
    # answer_question's parameters and its (answer, docs_preview) return value.
    run_button.click(
        fn=answer_question,
        inputs=[llm_key, query, search_mode],
        outputs=[answer_box, docs_box]
    )

# Start the web server only when this file is executed as a script, so that
# importing the module (tests, tooling) does not launch a server as a side
# effect of the import.
if __name__ == "__main__":
    demo.launch()