AJAY KASU committed on
Commit
2750cce
·
1 Parent(s): b14afd7

Refactor: Replace regex sector exclusion with LLM-based Intent Parser

Browse files
Files changed (7) hide show
  1. ai/ai_reporter.py +33 -0
  2. ai/prompts.py +30 -1
  3. api/app.py +11 -0
  4. core/schema.py +1 -0
  5. data/optimizer.py +3 -11
  6. main.py +6 -0
  7. streamlit_app.py +3 -32
ai/ai_reporter.py CHANGED
@@ -77,3 +77,36 @@ INSTRUCTION: Start your commentary exactly with the header: "Market Commentary -
77
  except Exception as e:
78
  logger.error(f"Failed to generate AI report: {e}")
79
  return "Error generating commentary. Please check API connection."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  except Exception as e:
78
  logger.error(f"Failed to generate AI report: {e}")
79
  return "Error generating commentary. Please check API connection."
80
def parse_intent(self, user_prompt: str) -> list:
    """
    Use the LLM to map a user prompt to the GICS sectors to exclude.

    The prompt is sent together with INTENT_PARSER_SYSTEM_PROMPT; the first
    well-formed JSON array found in the reply is decoded and each entry is
    matched (case-insensitively) against the 11 GICS sector names the system
    prompt enumerates.

    Args:
        user_prompt: Raw natural-language request from the user.

    Returns:
        De-duplicated list of canonical GICS sector names, in the order the
        model returned them. An empty list is returned when the LLM client
        is unavailable, when the reply contains no JSON array, or when the
        call or decoding fails (the failure is logged, never raised).
    """
    if not self.client:
        logger.warning("LLM Client unavailable for Intent Parsing. Falling back to empty list.")
        return []

    # Function-scope imports, as in the original implementation.
    import json

    from ai.prompts import INTENT_PARSER_SYSTEM_PROMPT

    # Canonical names, mirroring the list given to the model in the system
    # prompt. Downstream sector exclusion uses exact string membership, so
    # anything outside this set could never match and is dropped here.
    gics_sectors = (
        "Information Technology",
        "Health Care",
        "Financials",
        "Consumer Discretionary",
        "Communication Services",
        "Industrials",
        "Consumer Staples",
        "Energy",
        "Utilities",
        "Real Estate",
        "Materials",
    )
    canonical = {name.casefold(): name for name in gics_sectors}

    try:
        response = self.client.chat_completion(
            model=self.model_id,
            messages=[
                {"role": "system", "content": INTENT_PARSER_SYSTEM_PROMPT},
                {"role": "user", "content": f"Parse this prompt for sector exclusions: '{user_prompt}'"}
            ],
            max_tokens=100,
            temperature=0.0  # Strict output
        )

        # The message content may be None; treat that as an empty reply.
        content = (response.choices[0].message.content or "").strip()

        # Decode the first '[' that starts a valid JSON array. Unlike a greedy
        # r'\[.*\]' search, this tolerates extra bracketed text before or
        # after the list.
        decoder = json.JSONDecoder()
        parsed = None
        start = content.find("[")
        while start != -1:
            try:
                parsed, _ = decoder.raw_decode(content[start:])
                break
            except json.JSONDecodeError:
                start = content.find("[", start + 1)

        if parsed is None:
            return []

        sectors = []
        for item in parsed:
            name = canonical.get(item.strip().casefold()) if isinstance(item, str) else None
            if name is None:
                logger.warning(f"Intent Parsing ignored unrecognized sector: {item!r}")
            elif name not in sectors:
                sectors.append(name)
        return sectors

    except Exception as e:
        # Best-effort boundary: intent parsing must never break the pipeline.
        logger.error(f"Intent Parsing failed: {e}")
        return []
ai/prompts.py CHANGED
@@ -1,5 +1,34 @@
1
  # System Prompt for the Portfolio Manager Persona
2
- # System Prompt for the Portfolio Manager Persona
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  SYSTEM_PROMPT = """You are a Senior Portfolio Manager at a top-tier Asset Management firm (e.g., Goldman Sachs, BlackRock).
4
  Your goal is to write a concise, professional, and insightful performance commentary for a High Net Worth Application.
5
  Your tone should be:
 
1
  # System Prompt for the Portfolio Manager Persona
2
+ # System Prompt for the Intent Parser
3
+ INTENT_PARSER_SYSTEM_PROMPT = """You are a financial data parser.
4
+ Your task is to identify which of the following 11 GICS sectors a user wants to EXCLUDE from their portfolio based on their prompt.
5
+
6
+ GICS Sectors:
7
+ 1. Information Technology
8
+ 2. Health Care
9
+ 3. Financials
10
+ 4. Consumer Discretionary
11
+ 5. Communication Services
12
+ 6. Industrials
13
+ 7. Consumer Staples
14
+ 8. Energy
15
+ 9. Utilities
16
+ 10. Real Estate
17
+ 11. Materials
18
+
19
+ ## RULES:
20
+ 1. Return ONLY a valid JSON list of strings from the 11 GICS sectors above.
21
+ 2. If the user mentions "tech", map it to "Information Technology".
22
+ 3. If the user mentions "banks" or "finance", map it to "Financials".
23
+ 4. If the user mentions "healthcare" or "pharma", map it to "Health Care".
24
+ 5. If the user doesn't want to exclude any sectors, return [].
25
+ 6. Do NOT include any explanations or extra text.
26
+
27
+ Example:
28
+ User: "no tech or banks"
29
+ Output: ["Information Technology", "Financials"]
30
+ """
31
+
32
  SYSTEM_PROMPT = """You are a Senior Portfolio Manager at a top-tier Asset Management firm (e.g., Goldman Sachs, BlackRock).
33
  Your goal is to write a concise, professional, and insightful performance commentary for a High Net Worth Application.
34
  Your tone should be:
api/app.py CHANGED
@@ -26,12 +26,23 @@ def root():
26
  def health_check():
27
  return {"status": "healthy", "service": "QuantScale AI Direct Indexing"}
28
 
 
 
 
 
 
 
 
29
  @app.post("/optimize", response_model=dict)
30
  def optimize_portfolio(request: OptimizationRequest):
31
  """
32
  Optimizes a portfolio based on exclusions and generates an AI Attribution report.
33
  """
34
  try:
 
 
 
 
35
  result = system.run_pipeline(request)
36
  if not result:
37
  raise HTTPException(status_code=500, detail="Pipeline failed to execute.")
 
26
  def health_check():
27
  return {"status": "healthy", "service": "QuantScale AI Direct Indexing"}
28
 
29
def parse_constraints_with_llm(user_prompt: str) -> list:
    """
    API-layer entry point for turning a natural-language prompt into sector exclusions.

    Hands the prompt to the AI reporter owned by the module-level ``system``
    object and returns whatever its ``parse_intent`` produces.
    """
    reporter = system.ai_reporter
    return reporter.parse_intent(user_prompt)
35
+
36
  @app.post("/optimize", response_model=dict)
37
  def optimize_portfolio(request: OptimizationRequest):
38
  """
39
  Optimizes a portfolio based on exclusions and generates an AI Attribution report.
40
  """
41
  try:
42
+ # If the request contains a raw prompt but no sectors, parse it here
43
+ if request.user_prompt and not request.excluded_sectors:
44
+ request.excluded_sectors = parse_constraints_with_llm(request.user_prompt)
45
+
46
  result = system.run_pipeline(request)
47
  if not result:
48
  raise HTTPException(status_code=500, detail="Pipeline failed to execute.")
core/schema.py CHANGED
@@ -30,6 +30,7 @@ class OptimizationRequest(BaseModel):
30
  strategy: Optional[str] = Field(None, description="Global Filter Strategy: 'smallest_market_cap' or 'largest_market_cap'")
31
  top_n: Optional[int] = Field(None, description="Number of assets to select for strategy (e.g. 50)")
32
  benchmark: str = "^GSPC"
 
33
 
34
  class Config:
35
  json_schema_extra = {
 
30
  strategy: Optional[str] = Field(None, description="Global Filter Strategy: 'smallest_market_cap' or 'largest_market_cap'")
31
  top_n: Optional[int] = Field(None, description="Number of assets to select for strategy (e.g. 50)")
32
  benchmark: str = "^GSPC"
33
+ user_prompt: Optional[str] = Field(None, description="Raw user input for LLM intent parsing")
34
 
35
  class Config:
36
  json_schema_extra = {
data/optimizer.py CHANGED
@@ -74,17 +74,9 @@ class PortfolioOptimizer:
74
  logger.info(f"Applying Sector Exclusion Validation for: {excluded_sectors}")
75
  for i, ticker in enumerate(tickers):
76
  sector = sector_map.get(ticker, "Unknown")
77
- # Normalize both for robust matching (e.g., "Health Care" vs "Healthcare")
78
- sector_norm = sector.lower().replace(" ", "").replace("-", "")
79
-
80
- for excl in excluded_sectors:
81
- excl_norm = excl.lower().replace(" ", "").replace("-", "")
82
-
83
- # Match if normalized strings are equal OR special mapping for Tech
84
- if excl_norm == sector_norm or (excl_norm == "tech" and sector_norm == "informationtechnology"):
85
- excluded_indices.append(i)
86
- mask_vector[i] = 1
87
- break
88
 
89
  # Ticker Exclusions (NEW)
90
  if excluded_tickers:
 
74
  logger.info(f"Applying Sector Exclusion Validation for: {excluded_sectors}")
75
  for i, ticker in enumerate(tickers):
76
  sector = sector_map.get(ticker, "Unknown")
77
+ if sector in excluded_sectors:
78
+ excluded_indices.append(i)
79
+ mask_vector[i] = 1
 
 
 
 
 
 
 
 
80
 
81
  # Ticker Exclusions (NEW)
82
  if excluded_tickers:
main.py CHANGED
@@ -27,6 +27,12 @@ class QuantScaleSystem:
27
  def run_pipeline(self, request: OptimizationRequest):
28
  logger.info(f"Starting pipeline for Client {request.client_id}...")
29
 
 
 
 
 
 
 
30
  # 1. Fetch Universe (S&P 500)
31
  tickers = self.data_engine.fetch_sp500_tickers()
32
 
 
27
  def run_pipeline(self, request: OptimizationRequest):
28
  logger.info(f"Starting pipeline for Client {request.client_id}...")
29
 
30
+ # 0. LLM Intent Parsing (New)
31
+ if request.user_prompt and not request.excluded_sectors:
32
+ logger.info(f"Parsing user intent: '{request.user_prompt}'")
33
+ request.excluded_sectors = self.ai_reporter.parse_intent(request.user_prompt)
34
+ logger.info(f"LLM Mapped Exclusions: {request.excluded_sectors}")
35
+
36
  # 1. Fetch Universe (S&P 500)
37
  tickers = self.data_engine.fetch_sp500_tickers()
38
 
streamlit_app.py CHANGED
@@ -61,21 +61,6 @@ st.markdown("""
61
  </style>
62
  """, unsafe_allow_html=True)
63
 
64
- # --- Constants ---
65
- SECTOR_KEYWORDS = {
66
- "Energy": ["energy", "oil", "gas"],
67
- "Technology": ["technology", "tech", "software", "it"],
68
- "Financials": ["financials", "finance", "banks"],
69
- "Healthcare": ["healthcare", "health", "pharma"],
70
- "Utilities": ["utilities", "utility"],
71
- "Materials": ["materials", "mining"],
72
- "Consumer Discretionary": ["consumer", "retail", "discretionary"],
73
- "Real Estate": ["real estate", "reit"],
74
- "Communication Services": ["communication", "media", "telecom"]
75
- }
76
- INCLUDE_KEYWORDS = ["keep", "include", "with", "stay", "portfolio", "only"]
77
-
78
-
79
  # --- Parsers ---
80
  def parse_investment_amount(text: str) -> float:
81
  text = text.replace(",", "")
@@ -89,20 +74,6 @@ def parse_investment_amount(text: str) -> float:
89
  return 100_000.0
90
 
91
 
92
- def parse_excluded_sectors(text: str) -> list:
93
- lower = text.lower()
94
- excluded = []
95
- for sector, keywords in SECTOR_KEYWORDS.items():
96
- if any(k in lower for k in keywords):
97
- inc_pattern = re.compile(
98
- rf'({"|".join(INCLUDE_KEYWORDS)})\s+(the\s+)?({"|".join([sector.lower()] + keywords)})',
99
- re.IGNORECASE
100
- )
101
- if not inc_pattern.search(lower):
102
- excluded.append(sector)
103
- return excluded
104
-
105
-
106
  def parse_strategy(text: str):
107
  lower = text.lower()
108
  strategy, top_n = None, None
@@ -148,17 +119,17 @@ run_btn = st.button("🚀 Generate Portfolio Strategy", use_container_width=True
148
 
149
  if run_btn and user_input:
150
  investment_amount = parse_investment_amount(user_input)
151
- excluded_sectors = parse_excluded_sectors(user_input)
152
  strategy, top_n = parse_strategy(user_input)
153
 
154
  request = OptimizationRequest(
155
  client_id="StreamlitUser",
156
  initial_investment=investment_amount,
157
- excluded_sectors=excluded_sectors,
158
  excluded_tickers=[],
159
  strategy=strategy,
160
  top_n=top_n,
161
- benchmark="^GSPC"
 
162
  )
163
 
164
  with st.spinner("⚙️ Running Convex Optimization & AI Analysis..."):
 
61
  </style>
62
  """, unsafe_allow_html=True)
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  # --- Parsers ---
65
  def parse_investment_amount(text: str) -> float:
66
  text = text.replace(",", "")
 
74
  return 100_000.0
75
 
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  def parse_strategy(text: str):
78
  lower = text.lower()
79
  strategy, top_n = None, None
 
119
 
120
  if run_btn and user_input:
121
  investment_amount = parse_investment_amount(user_input)
 
122
  strategy, top_n = parse_strategy(user_input)
123
 
124
  request = OptimizationRequest(
125
  client_id="StreamlitUser",
126
  initial_investment=investment_amount,
127
+ excluded_sectors=[], # Let the LLM derive this
128
  excluded_tickers=[],
129
  strategy=strategy,
130
  top_n=top_n,
131
+ benchmark="^GSPC",
132
+ user_prompt=user_input
133
  )
134
 
135
  with st.spinner("⚙️ Running Convex Optimization & AI Analysis..."):