Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import os | |
| import pandas as pd | |
| import yfinance as yf | |
| from pydantic import BaseModel, Field | |
| from typing import List, Literal, Optional | |
| from llama_index.core import VectorStoreIndex, Settings | |
| from llama_index.vector_stores.pinecone import PineconeVectorStore | |
| from pinecone import Pinecone | |
| from llama_index.embeddings.openai import OpenAIEmbedding | |
| from llama_index.program.openai import OpenAIPydanticProgram | |
| from llama_index.llms.openai import OpenAI | |
| from llama_index.core.vector_stores import MetadataFilters, ExactMatchFilter | |
# --- 1. CONFIGURATION ---
# Page chrome must be set before any other Streamlit call in the script.
# NOTE(review): the "π"/"β" glyphs here and below look like mojibake'd emoji
# from the original source — verify against the upstream file.
st.set_page_config(page_title="Financial Agent (Strict Logic)", page_icon="π", layout="wide")

# Ensure keys exist
# Fail fast: every downstream component (router, embeddings, synthesis) needs
# OpenAI access, so halt the whole script run if the key is absent.
if "OPENAI_API_KEY" not in os.environ:
    st.error("β OPENAI_API_KEY missing.")
    st.stop()
| # --- 2. DATA MODELS (From your snippet) --- | |
class AgentResponse(BaseModel):
    """Structured result returned by run_agent to the Streamlit UI layer."""
    # Synthesized natural-language answer rendered in the chat window.
    answer: str
    # De-duplicated, human-readable source labels (e.g. "Real-time Market Data").
    sources: List[str]
    # Raw context snippets that were fed to the LLM; shown in the audit trail.
    context_used: List[str]
class TickerExtraction(BaseModel):
    """Structured LLM output: company/ticker strings detected in a user query."""
    symbols: List[str] = Field(description="List of stock tickers.")
class RoutePrediction(BaseModel):
    """Structured LLM routing decision: which tool(s) should serve the query."""
    tools: List[Literal["financial_rag", "market_data", "general_chat"]] = Field(description="Tools list")
| # --- 3. CACHED INITIALIZATION --- | |
| def initialize_resources(): | |
| print("π Initializing Strict-Boundary Agent...") | |
| # Setup LlamaIndex Settings | |
| Settings.llm = OpenAI(model="gpt-4o-mini", temperature=0) | |
| Settings.embed_model = OpenAIEmbedding(model="text-embedding-3-small") | |
| # Load CSV | |
| try: | |
| nasdaq_df = pd.read_csv('nasdaq-listed.csv') | |
| nasdaq_df.columns = [c.strip() for c in nasdaq_df.columns] | |
| except: | |
| nasdaq_df = pd.DataFrame() | |
| # Connect to Pinecone | |
| api_key = os.environ.get("PINECONE_API_KEY") | |
| if not api_key: raise ValueError("Pinecone Key Missing") | |
| pc = Pinecone(api_key=api_key) | |
| index = VectorStoreIndex.from_vector_store( | |
| vector_store=PineconeVectorStore(pinecone_index=pc.Index("financial-rag-agent")) | |
| ) | |
| return nasdaq_df, index | |
| # --- 4. HELPER FUNCTIONS (From your snippet) --- | |
def get_symbol_from_csv(query_str: str, df) -> Optional[str]:
    """Resolve a company name or ticker to a NASDAQ symbol via the listings CSV.

    Args:
        query_str: User-supplied entity (ticker like "AAPL" or name like "Apple").
        df: DataFrame with at least 'Symbol' and 'Security Name' columns.

    Returns:
        The matched symbol, or None when *df* is empty or nothing matches.
    """
    if df.empty:
        return None
    query_str = query_str.strip().upper()
    # Exact ticker match wins outright.
    if query_str in df['Symbol'].values:
        return query_str
    # Substring match on the security name. regex=False so queries containing
    # regex metacharacters ("AT&T (CL A)", "C++") are matched literally
    # instead of raising re.error.
    matches = df[df['Security Name'].str.upper().str.contains(query_str, na=False, regex=False)]
    if not matches.empty:
        # Ambiguous name → prefer the shortest symbol (usually the primary listing).
        return matches.loc[matches['Symbol'].str.len().idxmin()]['Symbol']
    return None
def get_tickers_from_query(query: str, index, df) -> List[str]:
    """Extract and validate the stock tickers mentioned in *query*.

    Pipeline: LLM entity extraction → CSV validation → short-string heuristic
    → vector-index metadata fallback.

    Args:
        query: Raw user query.
        index: Vector index whose node metadata may carry a "ticker" key.
        df: NASDAQ listings DataFrame for symbol validation.

    Returns:
        De-duplicated list of ticker symbols (possibly empty). Order is not
        guaranteed because of the final set() pass.
    """
    program = OpenAIPydanticProgram.from_defaults(
        output_cls=TickerExtraction,
        prompt_template_str="Identify all companies in query: {query_str}. Return list.",
        llm=Settings.llm
    )
    raw_entities = program(query_str=query).symbols

    valid_tickers = []
    for entity in raw_entities:
        ticker = get_symbol_from_csv(entity, df)
        # Heuristic: a short unmatched entity is probably already a ticker.
        if not ticker and len(entity) <= 5:
            ticker = entity.upper()
        if ticker:
            valid_tickers.append(ticker)

    if not valid_tickers:
        # Last resort: semantic retrieval, reading the ticker off node metadata.
        # Best-effort by design — retrieval failures must not break extraction —
        # but narrowed from a bare `except:` so SystemExit/KeyboardInterrupt
        # still propagate.
        try:
            nodes = index.as_retriever(similarity_top_k=1).retrieve(query)
            if nodes and nodes[0].metadata.get("ticker"):
                valid_tickers.append(nodes[0].metadata.get("ticker"))
        except Exception:
            pass

    return list(set(valid_tickers))
| # --- 5. TOOLS (From your snippet) --- | |
def get_market_data(query: str, index, df):
    """Fetch real-time quote metrics from Yahoo Finance for every ticker
    found in *query*; returns one stringified record per line, or a short
    notice when no company could be identified."""
    symbols = get_tickers_from_query(query, index, df)
    if not symbols:
        return "No companies found."

    lines = []
    for ticker in symbols:
        try:
            quote = yf.Ticker(ticker).info
            record = {
                "Ticker": ticker,
                "Price": quote.get('currentPrice', 'N/A'),
                "Market Cap": quote.get('marketCap', 'N/A'),
                "PE Ratio": quote.get('trailingPE', 'N/A'),
                "52w High": quote.get('fiftyTwoWeekHigh', 'N/A'),
                "52w Low": quote.get('fiftyTwoWeekLow', 'N/A'),
                "Volume": quote.get('volume', 'N/A'),
                "Currency": quote.get('currency', 'USD'),
            }
            lines.append(str(record))
        except Exception as e:
            # One bad ticker must not sink the whole report.
            lines.append(f"{ticker}: Data Error ({e})")
    return "\n".join(lines)
def get_financial_rag(query: str, index, df):
    """Answer *query* from indexed 10-K filings, running one metadata-filtered
    sub-query per recognized ticker. Returns a payload dict with the combined
    answer text, source labels, and raw retrieved node texts."""
    covered = ("AAPL", "TSLA", "NVDA")  # tickers with 10-K documents in the index
    payload = {"content": "", "sources": [], "raw_nodes": []}

    for ticker in get_tickers_from_query(query, index, df):
        if ticker not in covered:
            # Surface the gap explicitly so the synthesis step can mention it.
            payload["content"] += f"\n[NOTE: No 10-K report available for {ticker}.]\n"
            continue
        ticker_filter = MetadataFilters(filters=[ExactMatchFilter(key="ticker", value=ticker)])
        # Using logic from your snippet (similarity_top_k=3)
        answer = index.as_query_engine(similarity_top_k=3, filters=ticker_filter).query(query)
        payload["content"] += f"\n--- {ticker} 10-K Data ---\n{answer.response}\n"
        for node_with_score in answer.source_nodes:
            payload["sources"].append(f"{node_with_score.metadata.get('company')} 10-K")
            payload["raw_nodes"].append(node_with_score.node.get_text())

    return payload
| # --- 6. AGENT LOGIC (From your snippet) --- | |
def run_agent(user_query: str, index, df) -> AgentResponse:
    """Orchestrate one chat turn: route the query, run the selected tool(s),
    then synthesize a single analyst-style answer.

    Args:
        user_query: Raw chat input from the user.
        index: Pinecone-backed VectorStoreIndex for RAG and ticker fallback.
        df: NASDAQ listings DataFrame used for ticker resolution.

    Returns:
        AgentResponse with the synthesized answer, de-duplicated source
        labels, and the raw context snippets that were used.
    """
    # THE STRICT PROMPT YOU PROVIDED
    router_prompt = """
    Route the user query to the correct tool based on these strict definitions:
    1. "financial_rag":
    - Use for ANY question about a specific company's internal details.
    - INCLUDES: Revenue, Profit, Income, CEO, Board Members, Risks, Strategy, Competitors, Legal Issues, History.
    - Key Trigger: If the answer would be found in a PDF report or Wikipedia page, use this.
    2. "market_data":
    - Use ONLY for Real-Time Trading Metrics.
    - INCLUDES: Current Price, Market Cap, PE Ratio, Trading Volume, 52-Week High/Low.
    - EXCLUDES: Historical revenue or annual profit (Use financial_rag for those).
    3. "general_chat":
    - Use ONLY for non-business questions (e.g. "Hi", "Help").
    - NEVER use this if a specific company (Tesla, Apple, Nvidia) is mentioned.
    Query: {query_str}
    """
    router = OpenAIPydanticProgram.from_defaults(
        output_cls=RoutePrediction,
        prompt_template_str=router_prompt,
        llm=Settings.llm
    )
    tools = router(query_str=user_query).tools

    results = {}
    sources = []
    context_used = []
    # "general_chat" triggers no tool call: it simply falls through to the
    # synthesis prompt with an empty results dict.
    if "market_data" in tools:
        res = get_market_data(user_query, index, df)
        results["market_data"] = res
        context_used.append(res)
        sources.append("Real-time Market Data")
    if "financial_rag" in tools:
        res = get_financial_rag(user_query, index, df)
        results["financial_rag"] = res["content"]
        sources.extend(res["sources"])
        context_used.extend(res["raw_nodes"])

    # NOTE: {results} interpolates the dict's repr() directly into the prompt.
    final_prompt = f"""
    You are a Wall Street Financial Analyst. Answer the user request using the provided context.
    Context Data:
    {results}
    Instructions:
    1. Compare Metrics if multiple companies are listed.
    2. Synthesize qualitative (Risks) and quantitative (Price) data.
    3. Explicitly state if a report is missing.
    4. Cite sources.
    User Query: {user_query}
    """
    response_text = Settings.llm.complete(final_prompt).text
    # set() removes duplicate source labels but does not preserve order.
    return AgentResponse(
        answer=response_text,
        sources=list(set(sources)),
        context_used=context_used
    )
| # --- 7. STREAMLIT UI --- | |
| # Initialize Logic | |
# Sidebar: build the heavyweight resources and show system status. On any
# initialization failure the script halts here (st.stop), so the variables
# bound below are guaranteed to exist for the chat handler further down.
with st.sidebar:
    st.title("π§ System Status")
    with st.spinner("Initializing Strict-Boundary Agent..."):
        try:
            nasdaq_df, pinecone_index = initialize_resources()
            st.success("β Brain Loaded")
            st.success(f"β {len(nasdaq_df)} Tickers Indexed")
        except Exception as e:
            st.error(f"Initialization Failed: {e}")
            st.stop()
    st.markdown("---")
    st.markdown("### π― RAG Coverage")
    st.code("AAPL\nTSLA\nNVDA")

st.title("π Financial Agent (Strict Logic)")

# Chat transcript persists across Streamlit reruns via session state.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display History
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
        # Only assistant messages carry "sources"/"context" keys (see below).
        if "sources" in message:
            with st.expander("π Sources & Context"):
                st.write(message["sources"])
                for i, c in enumerate(message["context"][:3]):  # Limit preview
                    st.text(f"Snippet {i+1}: {str(c)[:300]}...")

# Input Handler — runs once per submitted prompt (walrus binds the text).
if prompt := st.chat_input("Enter query..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
    with st.chat_message("assistant"):
        with st.status("π§ Analyst is thinking...", expanded=True) as status:
            try:
                # RUN THE SAVED LOGIC
                response = run_agent(prompt, pinecone_index, nasdaq_df)
                status.update(label="β Complete", state="complete", expanded=False)
                st.markdown(response.answer)
                # Audit Trail
                with st.expander("π Audit Trail (Full Context)"):
                    st.write("**Sources:**", response.sources)
                    st.write("**Raw Retrieval:**")
                    for ctx in response.context_used:
                        st.text(str(ctx))
                # Persist the turn so the history loop above can re-render it.
                st.session_state.messages.append({
                    "role": "assistant",
                    "content": response.answer,
                    "sources": response.sources,
                    "context": response.context_used
                })
            except Exception as e:
                st.error(f"Error: {e}")
                status.update(label="β Error", state="error")