import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import os
import requests
import json
# --- 1. Load the RAG Components and Caching Setup ---
# The expensive operations of loading the dataset, creating embeddings,
# and building the FAISS index are performed once when the app starts,
# not on every chat request. This acts as a form of caching.
# We use a Sentence Transformer model for creating embeddings.
embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
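# Note: all-MiniLM-L6-v2 produces 384-dimensional embeddings and is small enough
# to run comfortably on a CPU-only Space.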
# --- 2. Load and Prepare the Dataset ---
# The dataset contains financial Q&A pairs.
print("Loading dataset...")
dataset = load_dataset("FinLang/investopedia-embedding-dataset", split="train", streaming=True)
print("Dataset loaded.")
# --- 3. Build the FAISS Index ---
# This is a highly efficient way to search for similar vectors.
print("Building FAISS index...")
# texts = [example['Answer'] for example in dataset.take(2000)]  # Option: embed only a subset for faster startup
texts = [example['Answer'] for example in dataset]  # Embed every answer in the dataset (slower startup, full coverage)
embeddings = embedding_model.encode(texts)
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(np.array(embeddings).astype('float32'))
print("FAISS index built.")
# --- 4. RAG Pipeline Functions ---
def retrieve_documents(query, k=5):
    """
    Retrieves the top k most relevant documents from the FAISS index based on a query.
    """
    query_embedding = embedding_model.encode([query])[0]
    # D holds the L2 distances, I the row indices of the k nearest neighbours.
    D, I = index.search(np.array([query_embedding]).astype('float32'), k)
    retrieved_docs = [texts[i] for i in I[0]]
    return retrieved_docs
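# Example usage (hypothetical query):
#   retrieve_documents("What is a P/E ratio?", k=3)
# returns the three stored answers whose embeddings are closest (by L2 distance)
# to the query embedding.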
def respond(message, chat_history):
    """
    Main function for the Gradio interface. It orchestrates the RAG process
    (retrieval and generation) and returns the bot's response.
    """
    # The ChatInterface with type="messages" sends history as a list of dicts.
    # We need to transform it for the Gemini API call.
    conversation_history = []
    # Append the chat history, mapping Gradio's "assistant" role to Gemini's "model"
    for turn in chat_history:
        role = "user" if turn["role"] == "user" else "model"
        conversation_history.append({"role": role, "parts": [{"text": turn["content"]}]})
    # Add the current user message to the history
    conversation_history.append({"role": "user", "parts": [{"text": message}]})
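    # At this point conversation_history looks like, e.g.:
    #   [{"role": "user",  "parts": [{"text": "What is a bond?"}]},
    #    {"role": "model", "parts": [{"text": "A bond is ..."}]},
    #    {"role": "user",  "parts": [{"text": "<current message>"}]}]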
    # Combine the current message with the conversation history to provide more context for retrieval.
    # This helps with vague follow-up questions like "what else?".
    retrieval_query = message
    if chat_history:
        # Combine the last two turns to form a more complete query
        combined_context = " ".join([turn["content"] for turn in chat_history[-2:]])
        retrieval_query = f"{combined_context} {message}"
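    # e.g. for a history ending ["What is a bond?", "A bond is a debt security ..."]
    # and the message "what else?", the retrieval query becomes
    # "What is a bond? A bond is a debt security ... what else?".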
    # 1. Retrieve documents based on the combined query
    retrieved_docs = retrieve_documents(retrieval_query)
    context_text = "\n".join(retrieved_docs) if retrieved_docs else "no relevant context found"
    # Define the system prompt with the retrieved context
    system_prompt = (
        "You are a financial assistant for question-answering tasks related to finance and closely related topics only. "
        "Do not answer questions on any topic other than finance. "
        "Use the following pieces of retrieved context to answer the question. If you don't know the answer, say that you don't know. "
        "Use three sentences maximum and keep the answer concise. "
        "If the question is not clear, ask follow-up questions. "
        f"\n\nContext:\n{context_text}"
    )
    # API endpoint and payload for the Gemini Flash model
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        return "GEMINI_API_KEY environment variable not set. Please add it to your Hugging Face Space secrets."
    api_url = f'https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash-preview-05-20:generateContent?key={api_key}'
    payload = {
        "contents": conversation_history,
        "systemInstruction": {
            "parts": [{"text": system_prompt}]
        }
    }
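    # Optional knobs could be added to the payload here, e.g. a "generationConfig"
    # dict with fields like "temperature" or "maxOutputTokens" (assumption:
    # standard generateContent fields; check the current API reference).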
    try:
        # A timeout keeps a hung request from blocking the Gradio worker forever.
        response = requests.post(api_url, json=payload, timeout=60)
        response.raise_for_status()
        result = response.json()
        # Extract the text response, guarding against responses with no
        # candidates (e.g. safety-blocked requests) or an unexpected shape.
        try:
            bot_response = result['candidates'][0]['content']['parts'][0]['text']
        except (KeyError, IndexError, TypeError):
            bot_response = "I couldn't provide an answer based on the available information. Please try rephrasing your question or ask about a different topic."
    except requests.exceptions.RequestException as e:
        print(f"API request failed: {e}")
        bot_response = "An error occurred while connecting to the AI model. Please try again later."
    except json.JSONDecodeError:
        bot_response = "An error occurred while parsing the API response."
    # Return ONLY the bot's response as a string. Gradio handles adding it to the history.
    return bot_response
# --- 5. Create the Gradio Interface ---
# The Gradio ChatInterface component provides a user-friendly way to
# interact with our RAG pipeline.
demo = gr.ChatInterface(
    fn=respond,
    title="Financial RAG Chatbot",
    description="Ask me a question about financial topics.",
    # Explicitly set the type to "messages" to avoid future deprecation warnings
    type="messages",
    # Specify the textbox component to prevent multimodal behavior
    textbox=gr.Textbox(placeholder="Ask me a question...", container=False, scale=7),
)
if __name__ == "__main__":
    demo.launch()