import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import os
import requests
import json

# --- 1. Load the RAG Components and Caching Setup ---
# The expensive operations of loading the dataset, creating embeddings,
# and building the FAISS index are performed once when the app starts,
# not on every chat request. This acts as a form of caching.

# We use a Sentence Transformer model for creating embeddings.
embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
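
# Optional sanity check: all-MiniLM-L6-v2 produces 384-dimensional embeddings,
# so the FAISS index built below should be created with dimension 384.
# Uncomment to verify locally:
# assert embedding_model.get_sentence_embedding_dimension() == 384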

# --- 2. Load and Prepare the Dataset ---
# The dataset contains financial Q&A pairs. With streaming=True the actual
# download is deferred until the dataset is iterated over below.
print("Loading dataset...")
dataset = load_dataset("FinLang/investopedia-embedding-dataset", split="train", streaming=True)
print("Dataset loaded.")

# --- 3. Build the FAISS Index ---
# This is a highly efficient way to search for similar vectors.
print("Building FAISS index...")
# To trade recall for startup speed, embed only a subset instead:
# texts = [example['Answer'] for example in dataset.take(2000)]
texts = [example['Answer'] for example in dataset]  # Embed every answer in the dataset
embeddings = embedding_model.encode(texts)
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(np.array(embeddings).astype('float32'))
print("FAISS index built.")

# --- 4. RAG Pipeline Functions ---
def retrieve_documents(query, k=5):
    """
    Retrieves the top k most relevant documents from the FAISS index based on a query.
    """
    query_embedding = embedding_model.encode([query])[0]
    # index.search returns the distances (D) and row indices (I) of the k nearest neighbours
    D, I = index.search(np.array([query_embedding]).astype('float32'), k)
    retrieved_docs = [texts[i] for i in I[0]]
    return retrieved_docs
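
# Example usage (uncomment to test retrieval in isolation, without the LLM;
# the query string is just an illustration):
# for doc in retrieve_documents("What is compound interest?", k=3):
#     print(doc[:120], "...")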

def respond(message, chat_history):
    """
    Main function for the Gradio interface. It orchestrates the RAG process
    (retrieval and generation) and returns the bot's response.
    """
    # The ChatInterface with type="messages" sends history as a list of dicts.
    # We need to transform it for the Gemini API call.
    conversation_history = []
    # Append the chat history, mapping Gradio's "assistant" role to Gemini's "model"
    for turn in chat_history:
        role = "user" if turn["role"] == "user" else "model"
        conversation_history.append({"role": role, "parts": [{"text": turn["content"]}]})
    # Add the current user message to the history
    conversation_history.append({"role": "user", "parts": [{"text": message}]})

    # Combine the current message with recent conversation history to provide
    # more context for retrieval. This helps with vague follow-up questions
    # like "what else?".
    retrieval_query = message
    if chat_history:
        # Combine the last two turns to form a more complete query
        combined_context = " ".join([turn["content"] for turn in chat_history[-2:]])
        retrieval_query = f"{combined_context} {message}"

    # 1. Retrieve documents based on the combined query
    retrieved_docs = retrieve_documents(retrieval_query)
    context_text = "\n".join(retrieved_docs) if retrieved_docs else "no relevant context found"

    # Define the system prompt with the retrieved context
    system_prompt = (
        "You are a financial assistant for question-answering tasks, covering finance and related topics only. "
        "Do not answer questions on any other topic. "
        "Use the following pieces of retrieved context to answer the question. "
        "If you don't know the answer, say that you don't know. "
        "Use three sentences maximum and keep the answer concise. "
        "If the question is not clear, ask follow-up questions. "
        f"\n\nContext:\n{context_text}"
    )

    # API endpoint and payload for the Gemini Flash model
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        return "GEMINI_API_KEY environment variable not set. Please add it to your Hugging Face Space secrets."
    api_url = f'https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash-preview-05-20:generateContent?key={api_key}'
    payload = {
        "contents": conversation_history,
        "systemInstruction": {
            "parts": [{"text": system_prompt}]
        }
    }
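
    # On success, the generateContent endpoint returns JSON shaped roughly like:
    # {"candidates": [{"content": {"role": "model", "parts": [{"text": "..."}]},
    #                  "finishReason": "STOP"}], ...}
    # The parsing below walks this structure defensively.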
    try:
        response = requests.post(api_url, json=payload, timeout=60)
        response.raise_for_status()
        result = response.json()
        # Extract the text response, guarding against missing or empty fields
        candidates = result.get('candidates', [])
        if candidates and candidates[0].get('content', {}).get('parts'):
            bot_response = candidates[0]['content']['parts'][0]['text']
        else:
            bot_response = "I couldn't provide an answer based on the available information. Please try rephrasing your question or ask about a different topic."
    except requests.exceptions.RequestException as e:
        print(f"API request failed: {e}")
        bot_response = "An error occurred while connecting to the AI model. Please try again later."
    except (json.JSONDecodeError, KeyError, IndexError):
        bot_response = "An error occurred while parsing the API response."

    # Return ONLY the bot's response as a string. Gradio handles adding it to the history.
    return bot_response
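
# With type="messages", Gradio passes chat_history to respond() as a list of
# dicts, for example:
# [{"role": "user", "content": "What is a bond?"},
#  {"role": "assistant", "content": "A bond is a debt security that ..."}]
# and expects the function to return the assistant's reply as a plain string.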

# --- 5. Create the Gradio Interface ---
# The Gradio ChatInterface component provides a user-friendly way to
# interact with our RAG pipeline.
demo = gr.ChatInterface(
    fn=respond,
    title="Financial RAG Chatbot",
    description="Ask me a question about financial topics.",
    # Explicitly set the type to "messages" to avoid future deprecation warnings
    type="messages",
    # Specify the textbox component to prevent multimodal behavior
    textbox=gr.Textbox(placeholder="Ask me a question...", container=False, scale=7)
)

if __name__ == "__main__":
    demo.launch()