# app.py — RAG-based Financial Investment Advisor (Hugging Face Space)
import os
import numpy as np
import faiss
import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from groq import Groq
# ================================
# 1. Initialize Groq Client
# ================================
# The key is supplied via the Space's secrets: Settings -> Secrets -> Finance_API.
if not (GROQ_API_KEY := os.environ.get("Finance_API")):
    raise ValueError("Please set the Finance_API secret in your Hugging Face Space.")
client = Groq(api_key=GROQ_API_KEY)
# ================================
# 2. Load Datasets (SCRIPT-FREE)
# ================================
print("Loading datasets...")
# Only a small slice of the news corpus is loaded to keep startup time down.
news_ds = load_dataset("ashraq/financial-news", split="train[:500]")

# Derive a lightweight QA-style corpus from the headlines: take up to the
# first two ". "-separated fragments per headline, keeping only fragments
# long enough (> 30 chars) to carry real content.
qa_docs = [
    f"Question: What does the news say?\nAnswer: {s}"
    for item in news_ds
    for s in item["headline"].split(". ")[:2]
    if len(s) > 30
]

# Retrieval corpus: raw headlines first, then the derived QA documents.
documents = [item["headline"] for item in news_ds]
documents.extend(qa_docs)
print(f"Total documents loaded: {len(documents)}")
# ================================
# 3. Load Embedding Model
# ================================
print("Loading embedding model...")
embedder = SentenceTransformer("all-MiniLM-L6-v2")

# ================================
# 4. Create Embeddings
# ================================
print("Generating embeddings...")
# FAISS requires float32 vectors; convert in one step.
embeddings = np.asarray(
    embedder.encode(documents, show_progress_bar=True), dtype="float32"
)

# ================================
# 5. Build FAISS Index
# ================================
print("Building FAISS index...")
# Exact (brute-force) L2 search over the full corpus — fine at this scale.
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)
print("FAISS index built successfully!")
# ================================
# 6. Retrieve Relevant Documents
# ================================
def retrieve_docs(query, k=5):
    """Return the documents most similar to *query* from the FAISS index.

    Args:
        query: Free-text user query; embedded with the same model as the corpus.
        k: Maximum number of documents to retrieve (default 5).

    Returns:
        A list of up to ``k`` document strings, nearest first.
    """
    query_embedding = np.asarray(embedder.encode([query]), dtype="float32")
    _, indices = index.search(query_embedding, k)
    # FAISS pads the result with -1 when k exceeds the number of indexed
    # vectors; filter those out so we never index `documents` with -1
    # (which would silently return the last document instead of failing).
    return [documents[i] for i in indices[0] if i >= 0]
# ================================
# 7. Ask Groq LLaMA
# ================================
def ask_llama(context, question):
    """Ask Groq's LLaMA model the user's question, grounded in *context*.

    Args:
        context: Concatenated retrieved documents the answer must be based on.
        question: The user's original question.

    Returns:
        The model's response text.
    """
    prompt = f"""
You are a professional Financial Investment Advisor AI.
You must answer based ONLY on the provided context.
If the context is not enough, say clearly: "I don't have enough data."
Context:
{context}
User Question:
{question}
Answer Format:
1. Investment Suggestion
2. Risk Level (Low/Medium/High)
3. Reasoning (based on context)
4. Recommendation (where to invest more)
"""
    messages = [{"role": "user", "content": prompt}]
    completion = client.chat.completions.create(
        messages=messages,
        model="llama-3.3-70b-versatile",
    )
    return completion.choices[0].message.content
# ================================
# 8. Full RAG Pipeline
# ================================
def rag_pipeline(user_question):
    """Run retrieve -> prompt -> LLM and return the advisor's answer text."""
    context = "\n\n".join(retrieve_docs(user_question, k=5))
    return ask_llama(context, user_question)
# ================================
# 9. Gradio UI
# ================================
def chatbot(question):
    """Gradio callback: forward the user's question through the RAG pipeline."""
    return rag_pipeline(question)


interface = gr.Interface(
    fn=chatbot,
    inputs=gr.Textbox(lines=3, placeholder="Example: I have $5000, where should I invest?"),
    outputs="text",
    # Fixed mojibake in the title: "πŸ“ˆ" was the chart-emoji's UTF-8 bytes
    # mis-decoded as latin-1; restored to the intended character.
    title="📈 RAG-Based Financial Investment Advisor",
    description="Uses Financial News + FAISS + Groq LLaMA (Script-Free datasets)",
)
interface.launch()