Aarya003 commited on
Commit
22055b7
ยท
verified ยท
1 Parent(s): 6171369

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +135 -128
src/streamlit_app.py CHANGED
@@ -12,88 +12,81 @@ from llama_index.program.openai import OpenAIPydanticProgram
12
  from llama_index.llms.openai import OpenAI
13
  from llama_index.core.vector_stores import MetadataFilters, ExactMatchFilter
14
 
15
- # --- 1. CONFIGURATION ---
16
- st.set_page_config(page_title="Financial Agent (Strict Logic)", page_icon="๐Ÿ“ˆ", layout="wide")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  # Ensure keys exist
19
  if "OPENAI_API_KEY" not in os.environ:
20
- st.error("โŒ OPENAI_API_KEY missing.")
21
  st.stop()
22
 
23
- # --- 2. DATA MODELS (From your snippet) ---
24
  class AgentResponse(BaseModel):
25
-
26
- """The AgentResponse class is a Pydantic model designed to structure the output of the financial agent. It ensures that every response from the agent contains not just the answer, but also the supporting evidence and lineage of data used.
27
-
28
- Attributes:
29
- answer (str): The final, synthesized natural language response generated by the LLM for the user.
30
- sources (List[str]): A list of high-level source names cited in the answer (e.g., "Tesla Inc 10-K", "Real-time Market Data"). This provides immediate transparency.
31
- context_used (List[str]): A list of the actual raw text chunks or data dictionaries retrieved from the tools (RAG or Market Data) and passed to the LLM. This is crucial for auditability and debugging."""
32
- answer: str
33
- sources: List[str]
34
- context_used: List[str]
35
 
36
  class TickerExtraction(BaseModel):
37
- """List of stock tickers."""
38
-
39
  symbols: List[str] = Field(description="List of stock tickers.")
40
 
41
  class RoutePrediction(BaseModel):
42
- """Tools list"""
43
  tools: List[Literal["financial_rag", "market_data", "general_chat"]] = Field(description="Tools list")
44
 
45
  # --- 3. CACHED INITIALIZATION ---
46
  @st.cache_resource(show_spinner=False)
47
  def initialize_resources():
48
- print("๐Ÿ”Œ Initializing Agent...")
49
-
50
  Settings.llm = OpenAI(model="gpt-4o-mini", temperature=0)
51
  Settings.embed_model = OpenAIEmbedding(model="text-embedding-3-small")
52
 
53
- # --- CSV PATH FINDER ---
54
- # We check ALL possible locations
55
  possible_paths = [
56
- "nasdaq-listed.csv", # Root directory
57
- "src/nasdaq-listed.csv", # Src folder
58
- os.path.join(os.getcwd(), "nasdaq-listed.csv"), # Current Working Directory
59
- os.path.join(os.path.dirname(__file__), "nasdaq-listed.csv"), # Same folder as script
60
- "../nasdaq-listed.csv" # One level up
61
  ]
 
62
 
63
- csv_path = None
64
- for path in possible_paths:
65
- if os.path.exists(path):
66
- csv_path = path
67
- print(f"โœ… Found CSV at: {path}")
68
- break
69
-
70
  if csv_path:
71
- try:
72
- nasdaq_df = pd.read_csv(csv_path)
73
- nasdaq_df.columns = [c.strip() for c in nasdaq_df.columns]
74
- except Exception as e:
75
- st.error(f"CSV Corrupt: {e}")
76
- nasdaq_df = pd.DataFrame()
77
  else:
78
- st.error(f"โŒ CRITICAL: 'nasdaq-listed.csv' not found. I looked in: {possible_paths}")
79
  nasdaq_df = pd.DataFrame()
80
 
81
- # --- Connect to Pinecone ---
82
  try:
83
  api_key = os.environ.get("PINECONE_API_KEY")
84
  if not api_key: raise ValueError("Pinecone Key Missing")
85
-
86
  pc = Pinecone(api_key=api_key)
87
  index = VectorStoreIndex.from_vector_store(
88
  vector_store=PineconeVectorStore(pinecone_index=pc.Index("financial-rag-agent"))
89
  )
90
- except Exception as e:
91
- st.error(f"Pinecone Error: {e}")
92
- return nasdaq_df, None
93
 
94
  return nasdaq_df, index
95
 
96
- # --- 4. HELPER FUNCTIONS (From your snippet) ---
97
  def get_symbol_from_csv(query_str: str, df) -> Optional[str]:
98
  if df.empty: return None
99
  query_str = query_str.strip().upper()
@@ -115,7 +108,7 @@ def get_tickers_from_query(query: str, index, df) -> List[str]:
115
  if not ticker and len(entity) <= 5: ticker = entity.upper()
116
  if ticker: valid_tickers.append(ticker)
117
 
118
- if not valid_tickers:
119
  try:
120
  nodes = index.as_retriever(similarity_top_k=1).retrieve(query)
121
  if nodes and nodes[0].metadata.get("ticker"):
@@ -123,7 +116,7 @@ def get_tickers_from_query(query: str, index, df) -> List[str]:
123
  except: pass
124
  return list(set(valid_tickers))
125
 
126
- # --- 5. TOOLS (From your snippet) ---
127
  def get_market_data(query: str, index, df):
128
  tickers = get_tickers_from_query(query, index, df)
129
  if not tickers: return "No companies found."
@@ -138,9 +131,7 @@ def get_market_data(query: str, index, df):
138
  "Market Cap": info.get('marketCap', 'N/A'),
139
  "PE Ratio": info.get('trailingPE', 'N/A'),
140
  "52w High": info.get('fiftyTwoWeekHigh', 'N/A'),
141
- "52w Low": info.get('fiftyTwoWeekLow', 'N/A'),
142
  "Volume": info.get('volume', 'N/A'),
143
- "Currency": info.get('currency', 'USD')
144
  }
145
  results.append(str(data))
146
  except Exception as e:
@@ -158,7 +149,6 @@ def get_financial_rag(query: str, index, df):
158
  continue
159
 
160
  filters = MetadataFilters(filters=[ExactMatchFilter(key="ticker", value=ticker)])
161
- # Using logic from your snippet (similarity_top_k=3)
162
  engine = index.as_query_engine(similarity_top_k=3, filters=filters)
163
  resp = engine.query(query)
164
 
@@ -169,29 +159,15 @@ def get_financial_rag(query: str, index, df):
169
 
170
  return payload
171
 
172
- # --- 6. AGENT LOGIC (From your snippet) ---
173
  def run_agent(user_query: str, index, df) -> AgentResponse:
174
- # THE STRICT PROMPT YOU PROVIDED
175
  router_prompt = """
176
  Route the user query to the correct tool based on these strict definitions:
177
-
178
- 1. "financial_rag":
179
- - Use for ANY question about a specific company's internal details.
180
- - INCLUDES: Revenue, Profit, Income, CEO, Board Members, Risks, Strategy, Competitors, Legal Issues, History.
181
- - Key Trigger: If the answer would be found in a PDF report or Wikipedia page, use this.
182
-
183
- 2. "market_data":
184
- - Use ONLY for Real-Time Trading Metrics.
185
- - INCLUDES: Current Price, Market Cap, PE Ratio, Trading Volume, 52-Week High/Low.
186
- - EXCLUDES: Historical revenue or annual profit (Use financial_rag for those).
187
-
188
- 3. "general_chat":
189
- - Use ONLY for non-business questions (e.g. "Hi", "Help").
190
- - NEVER use this if a specific company (Tesla, Apple, Nvidia) is mentioned.
191
-
192
  Query: {query_str}
193
  """
194
-
195
  router = OpenAIPydanticProgram.from_defaults(
196
  output_cls=RoutePrediction,
197
  prompt_template_str=router_prompt,
@@ -216,47 +192,69 @@ def run_agent(user_query: str, index, df) -> AgentResponse:
216
  context_used.extend(res["raw_nodes"])
217
 
218
  final_prompt = f"""
219
- You are a Wall Street Financial Analyst. Answer the user request using the provided context.
220
-
221
- Context Data:
222
- {results}
223
-
224
  Instructions:
225
  1. Compare Metrics if multiple companies are listed.
226
  2. Synthesize qualitative (Risks) and quantitative (Price) data.
227
- 3. Explicitly state if a report is missing.
228
- 4. Cite sources.
229
-
230
  User Query: {user_query}
231
  """
232
-
233
  response_text = Settings.llm.complete(final_prompt).text
234
-
235
- return AgentResponse(
236
- answer=response_text,
237
- sources=list(set(sources)),
238
- context_used=context_used
239
- )
240
 
241
- # --- 7. STREAMLIT UI ---
242
- # Initialize Logic
243
  with st.sidebar:
244
- st.title("๐Ÿ”ง System Status")
245
- with st.spinner("Initializing Strict-Boundary Agent..."):
246
- try:
247
- nasdaq_df, pinecone_index = initialize_resources()
248
- st.success("โœ… Brain Loaded")
249
- st.success(f"โœ… {len(nasdaq_df)} Tickers Indexed")
250
- except Exception as e:
251
- st.error(f"Initialization Failed: {e}")
252
- st.stop()
253
-
 
 
 
 
 
 
254
  st.markdown("---")
255
- st.markdown("### ๐ŸŽฏ RAG Coverage")
256
- st.code("AAPL\nTSLA\nNVDA")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
 
258
- st.title("๐Ÿ“ˆ Financial Agent (Strict Logic)")
 
 
 
 
 
 
 
 
 
259
 
 
260
  if "messages" not in st.session_state:
261
  st.session_state.messages = []
262
 
@@ -265,40 +263,49 @@ for message in st.session_state.messages:
265
  with st.chat_message(message["role"]):
266
  st.markdown(message["content"])
267
  if "sources" in message:
268
- with st.expander("๐Ÿ“š Sources & Context"):
269
  st.write(message["sources"])
270
- for i, c in enumerate(message["context"][:3]): # Limit preview
271
- st.text(f"Snippet {i+1}: {str(c)[:300]}...")
 
 
272
 
273
- # Input Handler
274
- if prompt := st.chat_input("Enter query..."):
275
- st.session_state.messages.append({"role": "user", "content": prompt})
 
 
 
276
  with st.chat_message("user"):
277
- st.markdown(prompt)
278
 
279
  with st.chat_message("assistant"):
280
- with st.status("๐Ÿง  Analyst is thinking...", expanded=True) as status:
 
281
  try:
282
- # RUN THE SAVED LOGIC
283
- response = run_agent(prompt, pinecone_index, nasdaq_df)
284
-
285
- status.update(label="โœ… Complete", state="complete", expanded=False)
286
- st.markdown(response.answer)
287
-
288
- # Audit Trail
289
- with st.expander("๐Ÿ” Audit Trail (Full Context)"):
290
- st.write("**Sources:**", response.sources)
291
- st.write("**Raw Retrieval:**")
292
- for ctx in response.context_used:
293
- st.text(str(ctx))
294
-
295
- st.session_state.messages.append({
296
- "role": "assistant",
297
- "content": response.answer,
298
- "sources": response.sources,
299
- "context": response.context_used
300
- })
301
-
302
  except Exception as e:
303
  st.error(f"Error: {e}")
304
- status.update(label="โŒ Error", state="error")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  from llama_index.llms.openai import OpenAI
13
  from llama_index.core.vector_stores import MetadataFilters, ExactMatchFilter
14
 
15
+ # --- 1. PAGE CONFIGURATION ---
16
+ st.set_page_config(
17
+ page_title="Wall St. AI Analyst",
18
+ page_icon="๐Ÿ›๏ธ",
19
+ layout="wide",
20
+ initial_sidebar_state="expanded"
21
+ )
22
+
23
+ # Custom CSS for a cleaner look
24
+ st.markdown("""
25
+ <style>
26
+ .stButton>button {
27
+ width: 100%;
28
+ border-radius: 5px;
29
+ height: 3em;
30
+ background-color: #f0f2f6;
31
+ }
32
+ .reportview-container {
33
+ background: #ffffff;
34
+ }
35
+ </style>
36
+ """, unsafe_allow_html=True)
37
 
38
  # Ensure keys exist
39
  if "OPENAI_API_KEY" not in os.environ:
40
+ st.error("โŒ OPENAI_API_KEY missing. Please check Space Settings.")
41
  st.stop()
42
 
43
+ # --- 2. DATA MODELS ---
44
  class AgentResponse(BaseModel):
45
+ answer: str
46
+ sources: List[str]
47
+ context_used: List[str]
 
 
 
 
 
 
 
48
 
49
  class TickerExtraction(BaseModel):
 
 
50
  symbols: List[str] = Field(description="List of stock tickers.")
51
 
52
  class RoutePrediction(BaseModel):
 
53
  tools: List[Literal["financial_rag", "market_data", "general_chat"]] = Field(description="Tools list")
54
 
55
  # --- 3. CACHED INITIALIZATION ---
56
  @st.cache_resource(show_spinner=False)
57
  def initialize_resources():
 
 
58
  Settings.llm = OpenAI(model="gpt-4o-mini", temperature=0)
59
  Settings.embed_model = OpenAIEmbedding(model="text-embedding-3-small")
60
 
61
+ # Locate CSV
 
62
  possible_paths = [
63
+ "nasdaq-listed.csv", "src/nasdaq-listed.csv",
64
+ os.path.join(os.getcwd(), "nasdaq-listed.csv"),
65
+ os.path.join(os.path.dirname(__file__), "nasdaq-listed.csv"),
66
+ "../nasdaq-listed.csv"
 
67
  ]
68
+ csv_path = next((p for p in possible_paths if os.path.exists(p)), None)
69
 
 
 
 
 
 
 
 
70
  if csv_path:
71
+ nasdaq_df = pd.read_csv(csv_path)
72
+ nasdaq_df.columns = [c.strip() for c in nasdaq_df.columns]
 
 
 
 
73
  else:
 
74
  nasdaq_df = pd.DataFrame()
75
 
76
+ # Connect to Pinecone
77
  try:
78
  api_key = os.environ.get("PINECONE_API_KEY")
79
  if not api_key: raise ValueError("Pinecone Key Missing")
 
80
  pc = Pinecone(api_key=api_key)
81
  index = VectorStoreIndex.from_vector_store(
82
  vector_store=PineconeVectorStore(pinecone_index=pc.Index("financial-rag-agent"))
83
  )
84
+ except:
85
+ index = None
 
86
 
87
  return nasdaq_df, index
88
 
89
+ # --- 4. HELPER FUNCTIONS ---
90
  def get_symbol_from_csv(query_str: str, df) -> Optional[str]:
91
  if df.empty: return None
92
  query_str = query_str.strip().upper()
 
108
  if not ticker and len(entity) <= 5: ticker = entity.upper()
109
  if ticker: valid_tickers.append(ticker)
110
 
111
+ if not valid_tickers and index:
112
  try:
113
  nodes = index.as_retriever(similarity_top_k=1).retrieve(query)
114
  if nodes and nodes[0].metadata.get("ticker"):
 
116
  except: pass
117
  return list(set(valid_tickers))
118
 
119
+ # --- 5. TOOLS ---
120
  def get_market_data(query: str, index, df):
121
  tickers = get_tickers_from_query(query, index, df)
122
  if not tickers: return "No companies found."
 
131
  "Market Cap": info.get('marketCap', 'N/A'),
132
  "PE Ratio": info.get('trailingPE', 'N/A'),
133
  "52w High": info.get('fiftyTwoWeekHigh', 'N/A'),
 
134
  "Volume": info.get('volume', 'N/A'),
 
135
  }
136
  results.append(str(data))
137
  except Exception as e:
 
149
  continue
150
 
151
  filters = MetadataFilters(filters=[ExactMatchFilter(key="ticker", value=ticker)])
 
152
  engine = index.as_query_engine(similarity_top_k=3, filters=filters)
153
  resp = engine.query(query)
154
 
 
159
 
160
  return payload
161
 
162
+ # --- 6. AGENT LOGIC ---
163
  def run_agent(user_query: str, index, df) -> AgentResponse:
 
164
  router_prompt = """
165
  Route the user query to the correct tool based on these strict definitions:
166
+ 1. "financial_rag": Company internal details (Revenue, Risks, Strategy, CEO).
167
+ 2. "market_data": Real-Time Trading Metrics (Price, PE, Volume) ONLY.
168
+ 3. "general_chat": Non-business questions.
 
 
 
 
 
 
 
 
 
 
 
 
169
  Query: {query_str}
170
  """
 
171
  router = OpenAIPydanticProgram.from_defaults(
172
  output_cls=RoutePrediction,
173
  prompt_template_str=router_prompt,
 
192
  context_used.extend(res["raw_nodes"])
193
 
194
  final_prompt = f"""
195
+ You are a Wall Street Financial Analyst. Answer using the provided context.
196
+ Context Data: {results}
 
 
 
197
  Instructions:
198
  1. Compare Metrics if multiple companies are listed.
199
  2. Synthesize qualitative (Risks) and quantitative (Price) data.
200
+ 3. Cite sources.
 
 
201
  User Query: {user_query}
202
  """
 
203
  response_text = Settings.llm.complete(final_prompt).text
204
+ return AgentResponse(answer=response_text, sources=list(set(sources)), context_used=context_used)
 
 
 
 
 
205
 
206
+ # --- 7. UI LOGIC ---
 
207
  with st.sidebar:
208
+ st.image("https://img.icons8.com/color/96/000000/bullish.png", width=80)
209
+ st.title("System Status")
210
+
211
+ with st.spinner("Connecting to Wall St..."):
212
+ nasdaq_df, pinecone_index = initialize_resources()
213
+
214
+ if not nasdaq_df.empty:
215
+ st.success(f"โœ… Market Data: {len(nasdaq_df):,} Tickers")
216
+ else:
217
+ st.warning("โš ๏ธ Market Data: Offline")
218
+
219
+ if pinecone_index:
220
+ st.success("โœ… Knowledge Base: Online")
221
+ else:
222
+ st.error("โŒ Knowledge Base: Offline")
223
+
224
  st.markdown("---")
225
+ st.markdown("### ๐Ÿง  Capabilities")
226
+
227
+ st.info("**Deep Dive (10-K Reports)**")
228
+ st.markdown("- ๐ŸŽ Apple (AAPL)\n- ๐Ÿš— Tesla (TSLA)\n- ๐ŸŽฎ Nvidia (NVDA)")
229
+ st.caption("*Ask about Strategy, Risks, Revenue*")
230
+
231
+ st.info("**Live Market Data**")
232
+ st.markdown("- ๐ŸŒ All NASDAQ Companies")
233
+ st.caption("*Ask about Price, PE Ratio, Volume*")
234
+
235
+ st.markdown("---")
236
+ if st.button("๐Ÿงน Clear Conversation"):
237
+ st.session_state.messages = []
238
+ st.rerun()
239
+
240
+ # Main Hero Section
241
+ st.title("๐Ÿ›๏ธ Wall St. AI Analyst")
242
+ st.markdown("""
243
+ **Your Hybrid Financial Assistant.** I bridge the gap between **Real-Time Market Data** and **Deep 10-K Analysis**.
244
+ """)
245
 
246
+ # Quick Start Buttons
247
+ col1, col2, col3 = st.columns(3)
248
+ if col1.button("๐Ÿ†š Compare Risks"):
249
+ prompt = "Compare the supply chain risks of Apple and Tesla."
250
+ elif col2.button("๐Ÿ“Š Apple vs Nvidia Revenue"):
251
+ prompt = "Compare the revenue growth of Apple and Nvidia."
252
+ elif col3.button("๐Ÿ“ˆ Tesla PE & Price"):
253
+ prompt = "What is the current price and PE ratio of Tesla?"
254
+ else:
255
+ prompt = None
256
 
257
+ # Chat State
258
  if "messages" not in st.session_state:
259
  st.session_state.messages = []
260
 
 
263
  with st.chat_message(message["role"]):
264
  st.markdown(message["content"])
265
  if "sources" in message:
266
+ with st.expander("๐Ÿ“š Data Sources & Citations"):
267
  st.write(message["sources"])
268
+ st.divider()
269
+ for i, c in enumerate(message["context"][:2]):
270
+ st.caption(f"**Context Fragment {i+1}:**")
271
+ st.text(str(c)[:500] + "...")
272
 
273
+ # Handle Input (Button or Text)
274
+ if user_input := st.chat_input("Ask a financial question...") or prompt:
275
+ # If button was clicked, override text input
276
+ final_query = prompt if prompt else user_input
277
+
278
+ st.session_state.messages.append({"role": "user", "content": final_query})
279
  with st.chat_message("user"):
280
+ st.markdown(final_query)
281
 
282
  with st.chat_message("assistant"):
283
+ # Status container (collapsible)
284
+ with st.status("๐Ÿง  Analyzing 10-Ks and Market Data...", expanded=True) as status:
285
  try:
286
+ response = run_agent(final_query, pinecone_index, nasdaq_df)
287
+ status.update(label="โœ… Analysis Complete", state="complete", expanded=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
  except Exception as e:
289
  st.error(f"Error: {e}")
290
+ status.update(label="โŒ Error", state="error")
291
+ st.stop()
292
+
293
+ # ANSWER DISPLAY (Now OUTSIDE the status block so it auto-shows)
294
+ st.markdown(response.answer)
295
+
296
+ # Sources (Collapsible)
297
+ with st.expander("๐Ÿ” Audit Trail (Read the Source Data)"):
298
+ st.markdown("### ๐Ÿ“š Cited Sources")
299
+ st.write(response.sources)
300
+ st.divider()
301
+ st.markdown("### ๐Ÿ“„ Raw Context Snippets")
302
+ for ctx in response.context_used:
303
+ st.text(str(ctx))
304
+
305
+ # Save to history
306
+ st.session_state.messages.append({
307
+ "role": "assistant",
308
+ "content": response.answer,
309
+ "sources": response.sources,
310
+ "context": response.context_used
311
+ })