Aarya003 commited on
Commit
ff4a3f0
Β·
verified Β·
1 Parent(s): 52c767e

Upload app.py

Browse files
Files changed (1) hide show
  1. src/app.py +270 -0
src/app.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import pandas as pd
4
+ import yfinance as yf
5
+ from pydantic import BaseModel, Field
6
+ from typing import List, Literal, Optional
7
+ from llama_index.core import VectorStoreIndex, Settings
8
+ from llama_index.vector_stores.pinecone import PineconeVectorStore
9
+ from pinecone import Pinecone
10
+ from llama_index.embeddings.openai import OpenAIEmbedding
11
+ from llama_index.program.openai import OpenAIPydanticProgram
12
+ from llama_index.llms.openai import OpenAI
13
+ from llama_index.core.vector_stores import MetadataFilters, ExactMatchFilter
14
+
15
+ # --- 1. CONFIGURATION ---
16
+ st.set_page_config(page_title="Financial Agent (Strict Logic)", page_icon="πŸ“ˆ", layout="wide")
17
+
18
+ # Ensure keys exist
19
+ if "OPENAI_API_KEY" not in os.environ:
20
+ st.error("❌ OPENAI_API_KEY missing.")
21
+ st.stop()
22
+
23
+ # --- 2. DATA MODELS (From your snippet) ---
24
+ class AgentResponse(BaseModel):
25
+ answer: str
26
+ sources: List[str]
27
+ context_used: List[str]
28
+
29
+ class TickerExtraction(BaseModel):
30
+ symbols: List[str] = Field(description="List of stock tickers.")
31
+
32
+ class RoutePrediction(BaseModel):
33
+ tools: List[Literal["financial_rag", "market_data", "general_chat"]] = Field(description="Tools list")
34
+
35
+ # --- 3. CACHED INITIALIZATION ---
36
+ @st.cache_resource(show_spinner=False)
37
+ def initialize_resources():
38
+ print("πŸ”Œ Initializing Strict-Boundary Agent...")
39
+
40
+ # Setup LlamaIndex Settings
41
+ Settings.llm = OpenAI(model="gpt-4o-mini", temperature=0)
42
+ Settings.embed_model = OpenAIEmbedding(model="text-embedding-3-small")
43
+
44
+ # Load CSV
45
+ try:
46
+ nasdaq_df = pd.read_csv('nasdaq-listed.csv')
47
+ nasdaq_df.columns = [c.strip() for c in nasdaq_df.columns]
48
+ except:
49
+ nasdaq_df = pd.DataFrame()
50
+
51
+ # Connect to Pinecone
52
+ api_key = os.environ.get("PINECONE_API_KEY")
53
+ if not api_key: raise ValueError("Pinecone Key Missing")
54
+
55
+ pc = Pinecone(api_key=api_key)
56
+ index = VectorStoreIndex.from_vector_store(
57
+ vector_store=PineconeVectorStore(pinecone_index=pc.Index("financial-rag-agent"))
58
+ )
59
+
60
+ return nasdaq_df, index
61
+
62
+ # --- 4. HELPER FUNCTIONS (From your snippet) ---
63
+ def get_symbol_from_csv(query_str: str, df) -> Optional[str]:
64
+ if df.empty: return None
65
+ query_str = query_str.strip().upper()
66
+ if query_str in df['Symbol'].values: return query_str
67
+ matches = df[df['Security Name'].str.upper().str.contains(query_str, na=False)]
68
+ if not matches.empty: return matches.loc[matches['Symbol'].str.len().idxmin()]['Symbol']
69
+ return None
70
+
71
+ def get_tickers_from_query(query: str, index, df) -> List[str]:
72
+ program = OpenAIPydanticProgram.from_defaults(
73
+ output_cls=TickerExtraction,
74
+ prompt_template_str="Identify all companies in query: {query_str}. Return list.",
75
+ llm=Settings.llm
76
+ )
77
+ raw_entities = program(query_str=query).symbols
78
+ valid_tickers = []
79
+ for entity in raw_entities:
80
+ ticker = get_symbol_from_csv(entity, df)
81
+ if not ticker and len(entity) <= 5: ticker = entity.upper()
82
+ if ticker: valid_tickers.append(ticker)
83
+
84
+ if not valid_tickers:
85
+ try:
86
+ nodes = index.as_retriever(similarity_top_k=1).retrieve(query)
87
+ if nodes and nodes[0].metadata.get("ticker"):
88
+ valid_tickers.append(nodes[0].metadata.get("ticker"))
89
+ except: pass
90
+ return list(set(valid_tickers))
91
+
92
+ # --- 5. TOOLS (From your snippet) ---
93
+ def get_market_data(query: str, index, df):
94
+ tickers = get_tickers_from_query(query, index, df)
95
+ if not tickers: return "No companies found."
96
+ results = []
97
+ for ticker in tickers:
98
+ try:
99
+ stock = yf.Ticker(ticker)
100
+ info = stock.info
101
+ data = {
102
+ "Ticker": ticker,
103
+ "Price": info.get('currentPrice', 'N/A'),
104
+ "Market Cap": info.get('marketCap', 'N/A'),
105
+ "PE Ratio": info.get('trailingPE', 'N/A'),
106
+ "52w High": info.get('fiftyTwoWeekHigh', 'N/A'),
107
+ "52w Low": info.get('fiftyTwoWeekLow', 'N/A'),
108
+ "Volume": info.get('volume', 'N/A'),
109
+ "Currency": info.get('currency', 'USD')
110
+ }
111
+ results.append(str(data))
112
+ except Exception as e:
113
+ results.append(f"{ticker}: Data Error ({e})")
114
+ return "\n".join(results)
115
+
116
+ def get_financial_rag(query: str, index, df):
117
+ target_tickers = get_tickers_from_query(query, index, df)
118
+ SUPPORTED = ["AAPL", "TSLA", "NVDA"]
119
+ payload = {"content": "", "sources": [], "raw_nodes": []}
120
+
121
+ for ticker in target_tickers:
122
+ if ticker not in SUPPORTED:
123
+ payload["content"] += f"\n[NOTE: No 10-K report available for {ticker}.]\n"
124
+ continue
125
+
126
+ filters = MetadataFilters(filters=[ExactMatchFilter(key="ticker", value=ticker)])
127
+ # Using logic from your snippet (similarity_top_k=3)
128
+ engine = index.as_query_engine(similarity_top_k=3, filters=filters)
129
+ resp = engine.query(query)
130
+
131
+ payload["content"] += f"\n--- {ticker} 10-K Data ---\n{resp.response}\n"
132
+ for n in resp.source_nodes:
133
+ payload["sources"].append(f"{n.metadata.get('company')} 10-K")
134
+ payload["raw_nodes"].append(n.node.get_text())
135
+
136
+ return payload
137
+
138
+ # --- 6. AGENT LOGIC (From your snippet) ---
139
+ def run_agent(user_query: str, index, df) -> AgentResponse:
140
+ # THE STRICT PROMPT YOU PROVIDED
141
+ router_prompt = """
142
+ Route the user query to the correct tool based on these strict definitions:
143
+
144
+ 1. "financial_rag":
145
+ - Use for ANY question about a specific company's internal details.
146
+ - INCLUDES: Revenue, Profit, Income, CEO, Board Members, Risks, Strategy, Competitors, Legal Issues, History.
147
+ - Key Trigger: If the answer would be found in a PDF report or Wikipedia page, use this.
148
+
149
+ 2. "market_data":
150
+ - Use ONLY for Real-Time Trading Metrics.
151
+ - INCLUDES: Current Price, Market Cap, PE Ratio, Trading Volume, 52-Week High/Low.
152
+ - EXCLUDES: Historical revenue or annual profit (Use financial_rag for those).
153
+
154
+ 3. "general_chat":
155
+ - Use ONLY for non-business questions (e.g. "Hi", "Help").
156
+ - NEVER use this if a specific company (Tesla, Apple, Nvidia) is mentioned.
157
+
158
+ Query: {query_str}
159
+ """
160
+
161
+ router = OpenAIPydanticProgram.from_defaults(
162
+ output_cls=RoutePrediction,
163
+ prompt_template_str=router_prompt,
164
+ llm=Settings.llm
165
+ )
166
+ tools = router(query_str=user_query).tools
167
+
168
+ results = {}
169
+ sources = []
170
+ context_used = []
171
+
172
+ if "market_data" in tools:
173
+ res = get_market_data(user_query, index, df)
174
+ results["market_data"] = res
175
+ context_used.append(res)
176
+ sources.append("Real-time Market Data")
177
+
178
+ if "financial_rag" in tools:
179
+ res = get_financial_rag(user_query, index, df)
180
+ results["financial_rag"] = res["content"]
181
+ sources.extend(res["sources"])
182
+ context_used.extend(res["raw_nodes"])
183
+
184
+ final_prompt = f"""
185
+ You are a Wall Street Financial Analyst. Answer the user request using the provided context.
186
+
187
+ Context Data:
188
+ {results}
189
+
190
+ Instructions:
191
+ 1. Compare Metrics if multiple companies are listed.
192
+ 2. Synthesize qualitative (Risks) and quantitative (Price) data.
193
+ 3. Explicitly state if a report is missing.
194
+ 4. Cite sources.
195
+
196
+ User Query: {user_query}
197
+ """
198
+
199
+ response_text = Settings.llm.complete(final_prompt).text
200
+
201
+ return AgentResponse(
202
+ answer=response_text,
203
+ sources=list(set(sources)),
204
+ context_used=context_used
205
+ )
206
+
207
+ # --- 7. STREAMLIT UI ---
208
+ # Initialize Logic
209
+ with st.sidebar:
210
+ st.title("πŸ”§ System Status")
211
+ with st.spinner("Initializing Strict-Boundary Agent..."):
212
+ try:
213
+ nasdaq_df, pinecone_index = initialize_resources()
214
+ st.success("βœ… Brain Loaded")
215
+ st.success(f"βœ… {len(nasdaq_df)} Tickers Indexed")
216
+ except Exception as e:
217
+ st.error(f"Initialization Failed: {e}")
218
+ st.stop()
219
+
220
+ st.markdown("---")
221
+ st.markdown("### 🎯 RAG Coverage")
222
+ st.code("AAPL\nTSLA\nNVDA")
223
+
224
+ st.title("πŸ“ˆ Financial Agent (Strict Logic)")
225
+
226
+ if "messages" not in st.session_state:
227
+ st.session_state.messages = []
228
+
229
+ # Display History
230
+ for message in st.session_state.messages:
231
+ with st.chat_message(message["role"]):
232
+ st.markdown(message["content"])
233
+ if "sources" in message:
234
+ with st.expander("πŸ“š Sources & Context"):
235
+ st.write(message["sources"])
236
+ for i, c in enumerate(message["context"][:3]): # Limit preview
237
+ st.text(f"Snippet {i+1}: {str(c)[:300]}...")
238
+
239
+ # Input Handler
240
+ if prompt := st.chat_input("Enter query..."):
241
+ st.session_state.messages.append({"role": "user", "content": prompt})
242
+ with st.chat_message("user"):
243
+ st.markdown(prompt)
244
+
245
+ with st.chat_message("assistant"):
246
+ with st.status("🧠 Analyst is thinking...", expanded=True) as status:
247
+ try:
248
+ # RUN THE SAVED LOGIC
249
+ response = run_agent(prompt, pinecone_index, nasdaq_df)
250
+
251
+ status.update(label="βœ… Complete", state="complete", expanded=False)
252
+ st.markdown(response.answer)
253
+
254
+ # Audit Trail
255
+ with st.expander("πŸ” Audit Trail (Full Context)"):
256
+ st.write("**Sources:**", response.sources)
257
+ st.write("**Raw Retrieval:**")
258
+ for ctx in response.context_used:
259
+ st.text(str(ctx))
260
+
261
+ st.session_state.messages.append({
262
+ "role": "assistant",
263
+ "content": response.answer,
264
+ "sources": response.sources,
265
+ "context": response.context_used
266
+ })
267
+
268
+ except Exception as e:
269
+ st.error(f"Error: {e}")
270
+ status.update(label="❌ Error", state="error")