# app.py — Streamlit financial agent (uploaded to Hugging Face, commit 52c767e)
import streamlit as st
import os
import pandas as pd
import yfinance as yf
from pydantic import BaseModel, Field
from typing import List, Literal, Optional
from llama_index.core import VectorStoreIndex, Settings
from llama_index.vector_stores.pinecone import PineconeVectorStore
from pinecone import Pinecone
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.program.openai import OpenAIPydanticProgram
from llama_index.llms.openai import OpenAI
from llama_index.core.vector_stores import MetadataFilters, ExactMatchFilter
# --- 1. CONFIGURATION ---
st.set_page_config(page_title="Financial Agent (Strict Logic)", page_icon="📈", layout="wide")

# Ensure keys exist — fail fast with a visible error instead of crashing later
# inside an OpenAI call.
if "OPENAI_API_KEY" not in os.environ:
    st.error("❌ OPENAI_API_KEY missing.")
    st.stop()
# --- 2. DATA MODELS (From your snippet) ---
class AgentResponse(BaseModel):
    """Structured result produced by run_agent()."""

    answer: str               # final synthesized answer text
    sources: List[str]        # de-duplicated source labels shown to the user
    context_used: List[str]   # raw context snippets that backed the answer
class TickerExtraction(BaseModel):
    """LLM extraction target: candidate company/ticker strings found in a query."""

    symbols: List[str] = Field(description="List of stock tickers.")
class RoutePrediction(BaseModel):
    """LLM routing target: which tools should handle the query."""

    tools: List[Literal["financial_rag", "market_data", "general_chat"]] = Field(description="Tools list")
# --- 3. CACHED INITIALIZATION ---
@st.cache_resource(show_spinner=False)
def initialize_resources():
    """Configure LlamaIndex globals, load the NASDAQ CSV, and connect Pinecone.

    Cached by Streamlit so it runs once per process.

    Returns:
        tuple: (nasdaq_df, index) where nasdaq_df is a DataFrame of NASDAQ
        listings (empty if the CSV is missing/unreadable) and index is a
        VectorStoreIndex backed by the "financial-rag-agent" Pinecone index.

    Raises:
        ValueError: if the PINECONE_API_KEY environment variable is not set.
    """
    print("🔌 Initializing Strict-Boundary Agent...")

    # Setup LlamaIndex Settings — shared by all downstream retrieval/LLM calls.
    Settings.llm = OpenAI(model="gpt-4o-mini", temperature=0)
    Settings.embed_model = OpenAIEmbedding(model="text-embedding-3-small")

    # Load CSV; degrade gracefully to an empty frame so the app still starts.
    # NOTE: was a bare `except:` — narrowed so SystemExit/KeyboardInterrupt
    # are not swallowed.
    try:
        nasdaq_df = pd.read_csv('nasdaq-listed.csv')
        nasdaq_df.columns = [c.strip() for c in nasdaq_df.columns]
    except Exception:
        nasdaq_df = pd.DataFrame()

    # Connect to Pinecone
    api_key = os.environ.get("PINECONE_API_KEY")
    if not api_key:
        raise ValueError("Pinecone Key Missing")
    pc = Pinecone(api_key=api_key)
    index = VectorStoreIndex.from_vector_store(
        vector_store=PineconeVectorStore(pinecone_index=pc.Index("financial-rag-agent"))
    )
    return nasdaq_df, index
# --- 4. HELPER FUNCTIONS (From your snippet) ---
def get_symbol_from_csv(query_str: str, df: pd.DataFrame) -> Optional[str]:
    """Resolve a company name or ticker to a NASDAQ symbol via the listing CSV.

    Matching order:
      1. Exact (case-insensitive) match against the 'Symbol' column.
      2. Substring match against 'Security Name'; among the matches the
         shortest symbol wins (heuristic for the primary listing).

    Args:
        query_str: Raw company name or ticker from the user/LLM.
        df: NASDAQ listings with 'Symbol' and 'Security Name' columns.

    Returns:
        The matched ticker, or None when df is empty or nothing matches.
    """
    if df.empty:
        return None
    query_str = query_str.strip().upper()
    if query_str in df['Symbol'].values:
        return query_str
    # regex=False: entity text may contain regex metacharacters (e.g. "(")
    # which would otherwise raise re.error or silently mis-match.
    matches = df[df['Security Name'].str.upper().str.contains(query_str, na=False, regex=False)]
    if not matches.empty:
        return matches.loc[matches['Symbol'].str.len().idxmin()]['Symbol']
    return None
def get_tickers_from_query(query: str, index, df) -> List[str]:
    """Extract and validate stock tickers mentioned in a natural-language query.

    Pipeline:
      1. An LLM program extracts candidate company/ticker strings.
      2. Each candidate is resolved against the NASDAQ CSV; short unresolved
         strings (<= 5 chars) are assumed to already be tickers.
      3. Fallback: if nothing validated, take the 'ticker' metadata of the top
         vector-store hit (best effort).

    Returns:
        De-duplicated list of ticker symbols (order not guaranteed).
    """
    program = OpenAIPydanticProgram.from_defaults(
        output_cls=TickerExtraction,
        prompt_template_str="Identify all companies in query: {query_str}. Return list.",
        llm=Settings.llm
    )
    raw_entities = program(query_str=query).symbols

    valid_tickers = []
    for entity in raw_entities:
        ticker = get_symbol_from_csv(entity, df)
        # Heuristic: a short, unrecognized string is probably already a ticker.
        if not ticker and len(entity) <= 5:
            ticker = entity.upper()
        if ticker:
            valid_tickers.append(ticker)

    if not valid_tickers:
        # NOTE: was a bare `except: pass` — narrowed to Exception so process
        # signals are not swallowed; retrieval errors must not break the agent.
        try:
            nodes = index.as_retriever(similarity_top_k=1).retrieve(query)
            if nodes and nodes[0].metadata.get("ticker"):
                valid_tickers.append(nodes[0].metadata.get("ticker"))
        except Exception:
            pass

    return list(set(valid_tickers))
# --- 5. TOOLS (From your snippet) ---
def get_market_data(query: str, index, df):
    """Fetch real-time trading metrics via yfinance for tickers in the query.

    Returns a newline-joined string: one stringified metrics dict (or an error
    note) per ticker, or "No companies found." if no ticker was resolved.
    """
    tickers = get_tickers_from_query(query, index, df)
    if not tickers:
        return "No companies found."

    lines = []
    for symbol in tickers:
        try:
            info = yf.Ticker(symbol).info
            snapshot = {
                "Ticker": symbol,
                "Price": info.get('currentPrice', 'N/A'),
                "Market Cap": info.get('marketCap', 'N/A'),
                "PE Ratio": info.get('trailingPE', 'N/A'),
                "52w High": info.get('fiftyTwoWeekHigh', 'N/A'),
                "52w Low": info.get('fiftyTwoWeekLow', 'N/A'),
                "Volume": info.get('volume', 'N/A'),
                "Currency": info.get('currency', 'USD'),
            }
            lines.append(str(snapshot))
        except Exception as err:
            lines.append(f"{symbol}: Data Error ({err})")
    return "\n".join(lines)
def get_financial_rag(query: str, index, df):
    """Query the 10-K vector store for each ticker resolved from the query.

    Returns:
        dict with keys:
          'content'   — concatenated per-ticker answers (or missing-report notes),
          'sources'   — citation labels for every retrieved node,
          'raw_nodes' — raw text of every retrieved chunk.
    """
    payload = {"content": "", "sources": [], "raw_nodes": []}
    SUPPORTED = ["AAPL", "TSLA", "NVDA"]

    for ticker in get_tickers_from_query(query, index, df):
        if ticker not in SUPPORTED:
            # No indexed 10-K for this company — surface that explicitly.
            payload["content"] += f"\n[NOTE: No 10-K report available for {ticker}.]\n"
            continue

        # Restrict retrieval to this ticker's chunks only.
        filters = MetadataFilters(filters=[ExactMatchFilter(key="ticker", value=ticker)])
        # Using logic from your snippet (similarity_top_k=3)
        engine = index.as_query_engine(similarity_top_k=3, filters=filters)
        resp = engine.query(query)

        payload["content"] += f"\n--- {ticker} 10-K Data ---\n{resp.response}\n"
        for node in resp.source_nodes:
            payload["sources"].append(f"{node.metadata.get('company')} 10-K")
            payload["raw_nodes"].append(node.node.get_text())
    return payload
# --- 6. AGENT LOGIC (From your snippet) ---
def run_agent(user_query: str, index, df) -> AgentResponse:
    """Route the query to the right tools, gather context, synthesize an answer.

    Steps: (1) an LLM router picks tools, (2) each selected tool contributes
    context and sources, (3) a final LLM call writes the analyst answer.
    """
    # THE STRICT PROMPT YOU PROVIDED
    router_prompt = """
Route the user query to the correct tool based on these strict definitions:
1. "financial_rag":
- Use for ANY question about a specific company's internal details.
- INCLUDES: Revenue, Profit, Income, CEO, Board Members, Risks, Strategy, Competitors, Legal Issues, History.
- Key Trigger: If the answer would be found in a PDF report or Wikipedia page, use this.
2. "market_data":
- Use ONLY for Real-Time Trading Metrics.
- INCLUDES: Current Price, Market Cap, PE Ratio, Trading Volume, 52-Week High/Low.
- EXCLUDES: Historical revenue or annual profit (Use financial_rag for those).
3. "general_chat":
- Use ONLY for non-business questions (e.g. "Hi", "Help").
- NEVER use this if a specific company (Tesla, Apple, Nvidia) is mentioned.
Query: {query_str}
"""
    router = OpenAIPydanticProgram.from_defaults(
        output_cls=RoutePrediction,
        prompt_template_str=router_prompt,
        llm=Settings.llm
    )
    tools = router(query_str=user_query).tools

    results = {}
    sources = []
    context_used = []

    if "market_data" in tools:
        market_summary = get_market_data(user_query, index, df)
        results["market_data"] = market_summary
        context_used.append(market_summary)
        sources.append("Real-time Market Data")

    if "financial_rag" in tools:
        rag_payload = get_financial_rag(user_query, index, df)
        results["financial_rag"] = rag_payload["content"]
        sources.extend(rag_payload["sources"])
        context_used.extend(rag_payload["raw_nodes"])

    final_prompt = f"""
You are a Wall Street Financial Analyst. Answer the user request using the provided context.
Context Data:
{results}
Instructions:
1. Compare Metrics if multiple companies are listed.
2. Synthesize qualitative (Risks) and quantitative (Price) data.
3. Explicitly state if a report is missing.
4. Cite sources.
User Query: {user_query}
"""
    response_text = Settings.llm.complete(final_prompt).text
    return AgentResponse(
        answer=response_text,
        sources=list(set(sources)),
        context_used=context_used
    )
# --- 7. STREAMLIT UI ---
# Initialize Logic — failures surface in the sidebar and halt the script.
with st.sidebar:
    st.title("🔧 System Status")
    with st.spinner("Initializing Strict-Boundary Agent..."):
        try:
            nasdaq_df, pinecone_index = initialize_resources()
            st.success("✅ Brain Loaded")
            st.success(f"✅ {len(nasdaq_df)} Tickers Indexed")
        except Exception as e:
            st.error(f"Initialization Failed: {e}")
            st.stop()
    st.markdown("---")
    st.markdown("### 🎯 RAG Coverage")
    st.code("AAPL\nTSLA\nNVDA")
st.title("📈 Financial Agent (Strict Logic)")

if "messages" not in st.session_state:
    st.session_state.messages = []

# Display History — replay stored messages, with an expander for any
# assistant message that carries sources/context.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
        if "sources" in message:
            with st.expander("📚 Sources & Context"):
                st.write(message["sources"])
                for i, c in enumerate(message["context"][:3]):  # Limit preview
                    st.text(f"Snippet {i+1}: {str(c)[:300]}...")
# Input Handler — run the agent on each new user message and persist the turn.
if prompt := st.chat_input("Enter query..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        with st.status("🧠 Analyst is thinking...", expanded=True) as status:
            try:
                # RUN THE SAVED LOGIC
                response = run_agent(prompt, pinecone_index, nasdaq_df)
                status.update(label="✅ Complete", state="complete", expanded=False)
                st.markdown(response.answer)

                # Audit Trail
                with st.expander("🔍 Audit Trail (Full Context)"):
                    st.write("**Sources:**", response.sources)
                    st.write("**Raw Retrieval:**")
                    for ctx in response.context_used:
                        st.text(str(ctx))

                st.session_state.messages.append({
                    "role": "assistant",
                    "content": response.answer,
                    "sources": response.sources,
                    "context": response.context_used
                })
            except Exception as e:
                st.error(f"Error: {e}")
                status.update(label="❌ Error", state="error")