Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import logging | |
| from typing import Tuple, Optional, List, Dict, Any | |
| from data_loader import load_finqa_dataset | |
| from vector_store import init_model, build_embeddings, build_faiss_index, build_bm25 | |
| from pipeline import init_reranker, init_groq, rag_pipeline | |
# --- Logging configuration --------------------------------------------------
# App-wide logger. Attach the console handler only once: module-level code
# re-runs on re-import / Gradio hot-reload, and without this guard every
# reload would add another StreamHandler and duplicate each log line.
logger = logging.getLogger("rag_app")
logger.setLevel(logging.INFO)
if not logger.handlers:
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
# --- Global cached resources ------------------------------------------------
# Populated once by init_once() and reused across requests. All are None
# until the first call; _resources_initialized flips to True after a
# successful load so later calls take the fast path.
_resources_initialized: bool = False
_documents: Optional[List[Dict[str, Any]]] = None  # FinQA documents
_model = None            # embedding model (see init_model)
_faiss_index = None      # dense vector index over document embeddings
_bm25_index = None       # sparse keyword index over the same documents
_reranker_model = None   # cross-encoder used to rerank retrieved docs
# Dataset split and default retrieval depths passed to rag_pipeline.
FINQA_SPLIT = "train"
DEFAULT_RETRIEVE_K = 30  # candidates fetched by the retriever
DEFAULT_RERANK_K = 5     # documents kept after cross-encoder reranking
# Lazy Initialization
def init_once() -> Tuple[
    List[Dict[str, Any]], Any, Any, Any, Any
]:
    """Build and cache all heavy resources on first use.

    Returns the tuple (documents, model, faiss_index, bm25_index,
    reranker_model). Everything is constructed exactly once and stored in
    module-level globals; later calls return the cached objects immediately.
    """
    global _resources_initialized, _documents, _model, _faiss_index, _bm25_index, _reranker_model

    # Fast path: a previous call already built everything.
    if _resources_initialized:
        return _documents, _model, _faiss_index, _bm25_index, _reranker_model

    # 1) Corpus
    logger.info("Loading FinQA dataset…")
    _documents = load_finqa_dataset(split=FINQA_SPLIT)
    logger.info(f"Loaded {len(_documents)} documents.")

    # 2) Dense retrieval: embedding model -> embeddings -> FAISS index
    logger.info("Initializing embedding model…")
    _model = init_model()
    logger.info("Building embeddings (loading or creating)…")
    doc_embeddings = build_embeddings(_documents, _model)
    logger.info("Initializing FAISS index…")
    _faiss_index = build_faiss_index(doc_embeddings)

    # 3) Sparse retrieval: BM25 over the same documents
    logger.info("Building BM25 index…")
    _bm25_index = build_bm25(_documents)

    # 4) Cross-encoder reranker
    logger.info("Loading cross-encoder reranker…")
    _reranker_model = init_reranker()

    _resources_initialized = True
    return _documents, _model, _faiss_index, _bm25_index, _reranker_model
# Main RAG Handler
def answer_question(
    llm_key: str,
    query: str,
    search_mode: str
) -> Tuple[str, str]:
    """Answer a financial question through the RAG pipeline.

    Parameters:
        llm_key: Groq API key entered by the user.
        query: natural-language question.
        search_mode: retrieval strategy ("hybrid", "bm25" or "semantic").

    Returns:
        (answer, docs_preview) — the LLM answer and a markdown preview of
        the reranked context documents. On validation failure the first
        element carries the error message and the preview is empty.
    """
    # Guard clauses: reject blank inputs before touching heavy resources.
    if not query.strip():
        return "Please enter a question.", ""
    if not llm_key.strip():
        return "Please provide a valid GROQ API key.", ""

    # Load (or reuse) the cached dataset, indexes and models.
    documents, model, faiss_idx, bm25_idx, reranker = init_once()

    # Verify the API key by constructing the client; init_groq is expected
    # to raise on a malformed key — TODO confirm it actually validates.
    try:
        llm_client = init_groq(api_key=llm_key)
    except Exception:
        return "Invalid GROQ API key.", ""

    # Retrieval -> rerank -> generation.
    result = rag_pipeline(
        query=query,
        reranker_model=reranker,
        llm_client=llm_client,
        documents=documents,
        model=model,
        faiss_index=faiss_idx,
        bm25_index=bm25_idx,
        retrieve_mode=search_mode,
        retrieve_k=DEFAULT_RETRIEVE_K,
        rerank_k=DEFAULT_RERANK_K
    )

    # Log the reranked context and build the markdown preview in one pass.
    # Lazy %-style logging args, and str.join instead of += accumulation.
    logger.info("=== RERANKED CONTEXT ===")
    preview_parts = []
    for d in result["reranked_docs"]:
        logger.info("ID: %s | score=%.4f", d["id"], d["rerank_score"])
        text = d["text"]
        # Fix: only append an ellipsis when the text was actually truncated
        # (the previous code added "..." even to short documents).
        snippet = text[:500] + "..." if len(text) > 500 else text
        preview_parts.append(f"### Document ID: {d['id']}\n{snippet}\n\n")
    docs_preview = "".join(preview_parts)

    return result["answer"], docs_preview
# --- Gradio UI ---------------------------------------------------------------
# Layout: intro markdown, three inputs (key, question, retrieval mode),
# a run button, and two markdown outputs (answer + context preview).
with gr.Blocks(title="Financial RAG QA System") as demo:
    gr.Markdown("""
# 📘 Financial RAG Question Answering System
This app answers financial questions using the **FinQA dataset**.
### 🔧 How it works
1. Retrieves relevant passages using BM25 / Embeddings / Hybrid
2. Reranks them using a Cross-Encoder
3. Sends only the best documents to the LLM (Groq Llama 3.3)
4. Produces an accurate, grounded answer
### 🧠 Best for:
- Financial statement analysis
- Annual report question answering
- Extracting metrics (revenue, cash flow, expenses)
- Multi-step reasoning on financial texts
""")

    # Inputs: masked API key, free-text question, retrieval mode selector.
    llm_key = gr.Textbox(
        label="GROQ API Key",
        placeholder="Enter your Groq API key",
        type="password"
    )
    query = gr.Textbox(
        label="Your Question",
        placeholder="Example: *What was total revenue in 2019 for Apple?*",
        lines=2
    )
    search_mode = gr.Radio(
        choices=["hybrid", "bm25", "semantic"],
        value="hybrid",
        label="Retrieval Mode"
    )
    gr.Markdown("""
**Hybrid** → Best overall (semantic + keyword)
**BM25** → Best for exact numeric queries (years, dollars, names)
**Semantic** → Best for conceptual queries
""")

    run_button = gr.Button("Ask")
    answer_box = gr.Markdown()  # LLM answer
    docs_box = gr.Markdown()    # preview of the reranked context documents

    # Wire the button to the RAG handler.
    run_button.click(
        fn=answer_question,
        inputs=[llm_key, query, search_mode],
        outputs=[answer_box, docs_box]
    )

# Launch only when executed as a script, so importing this module
# (e.g. in tests or tooling) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()