ayush2917 commited on
Commit
8483594
·
verified ·
1 Parent(s): 401dbae

Update agent/planner.py

Browse files
Files changed (1) hide show
  1. agent/planner.py +66 -31
agent/planner.py CHANGED
@@ -1,9 +1,19 @@
1
  import logging
 
2
 
3
  logger = logging.getLogger(__name__)
4
 
5
  class QueryPlanner:
6
- def plan_query(self, query: str) -> list:
 
 
 
 
 
 
 
 
 
7
  """
8
  Decompose a user query into a list of planned tasks (e.g., news search, stock price lookup).
9
  This planning logic is heuristic-based and can be expanded for more sophisticated query understanding.
@@ -15,21 +25,36 @@ class QueryPlanner:
15
 
16
  # --- News Search Task Identification ---
17
  # Keywords indicating a news-related query.
18
- news_keywords = [
19
- "trend", "news", "articles", "latest", "updates",
20
- "market sentiment", "recent", "headlines", "report"
 
21
  ]
22
 
23
- # If any news keyword is present, plan a news search task.
24
- # The original query is passed to the news search tool.
25
- if any(keyword in query_lower for keyword in news_keywords):
26
- # We can refine this to extract specific news categories if the query gets more complex
27
- tasks.append({"type": "news_search", "query": query, "category": None})
28
- logger.debug(f"Identified news search task for query: '{query}'")
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  # --- Stock Price Task Identification ---
31
- # A predefined list of common stock symbols or company names to look for.
32
- # This list can be expanded or dynamically loaded from a financial data source.
 
33
  known_stock_symbols = [
34
  "nifty", "reliance", "tcs", "infy", "hdfcbank", "sbin",
35
  "icicibank", "bhartiartl", "kotakbank", "lt", "asianpaint",
@@ -39,32 +64,27 @@ class QueryPlanner:
39
  ]
40
 
41
  found_symbols = []
42
- # Iterate through known symbols to find exact matches or strong indicators in the query.
43
  for symbol in known_stock_symbols:
44
- # Use word boundaries (spaces) to avoid partial matches (e.g., "fin" matching "bajajfinserv")
45
- # Also handle cases where symbol is at the start or end of the query, or is the entire query.
46
- if (f" {symbol.lower()} " in f" {query_lower} " or
47
- query_lower.startswith(f"{symbol.lower()} ") or
48
- query_lower.endswith(f" {symbol.lower()}") or
49
- query_lower == symbol.lower()):
50
-
51
- # Add the symbol in uppercase to ensure consistency for API calls,
52
- # but only if it hasn't been added already.
53
- if symbol.upper() not in [task.get("symbol") for task in tasks if task["type"] == "stock_price"]:
54
  found_symbols.append(symbol.upper())
55
  logger.debug(f"Found explicit stock symbol '{symbol.upper()}' in query.")
56
 
57
  # Fallback heuristic: If no known symbol is found but keywords like "stock" or "price" are present,
58
- # try to infer a symbol from the last word of the query.
59
- if not found_symbols and ("stock" in query_lower or "price" in query_lower or "quote" in query_lower):
60
  words = query_lower.split()
61
- if words:
62
- potential_symbol = words[-1].upper()
63
- # Basic validation for inferred symbol: 2-5 characters, all alphabetic.
64
- if 2 <= len(potential_symbol) <= 5 and potential_symbol.isalpha():
65
- if potential_symbol not in [task.get("symbol") for task in tasks if task["type"] == "stock_price"]:
66
  found_symbols.append(potential_symbol)
67
- logger.info(f"Inferred potential stock symbol '{potential_symbol}' from query tail.")
 
68
 
69
  # Add a stock_price task for each identified symbol.
70
  for symbol in found_symbols:
@@ -75,3 +95,18 @@ class QueryPlanner:
75
  except Exception as e:
76
  logger.error(f"Error planning query '{query}': {e}", exc_info=True)
77
  return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import logging
2
+ from typing import List, Dict, Set
3
 
4
  logger = logging.getLogger(__name__)
5
 
6
  class QueryPlanner:
7
+ def __init__(self, india_finance_keywords: List[str], category_keywords: Dict[str, Set[str]]):
8
+ """
9
+ Initialize QueryPlanner with keyword lists for financial topics and categories.
10
+ These lists are typically provided by NewsTools or a central configuration.
11
+ """
12
+ self.india_finance_keywords = set(india_finance_keywords) # Convert to set for faster lookup
13
+ self.category_keywords = {k: set(v) for k, v in category_keywords.items()} # Convert values to sets
14
+ logger.info("QueryPlanner initialized with financial keywords and categories.")
15
+
16
+ def plan_query(self, query: str) -> List[Dict]:
17
  """
18
  Decompose a user query into a list of planned tasks (e.g., news search, stock price lookup).
19
  This planning logic is heuristic-based and can be expanded for more sophisticated query understanding.
 
25
 
26
  # --- News Search Task Identification ---
27
  # Keywords indicating a news-related query.
28
+ news_trigger_keywords = [
29
+ "trend", "news", "articles", "latest", "updates", "market sentiment",
30
+ "recent", "headlines", "report", "what's happening", "current situation",
31
+ "developments"
32
  ]
33
 
34
+ # Combine general news triggers with specific financial keywords for news search
35
+ # If any news trigger keyword or general financial keyword is present, plan a news search task.
36
+ is_news_query = any(keyword in query_lower for keyword in news_trigger_keywords)
37
+ is_general_financial_query = any(keyword in query_lower for keyword in self.india_finance_keywords)
38
+
39
+ # Determine primary category for news search if possible
40
+ detected_categories = []
41
+ for category_name, keywords_set in self.category_keywords.items():
42
+ if any(keyword in query_lower for keyword in keywords_set):
43
+ detected_categories.append(category_name)
44
+
45
+ # If it's a news-related query or a general financial query, plan a news search.
46
+ # Prioritize a specific category if detected, otherwise default to a broader search.
47
+ if is_news_query or is_general_financial_query:
48
+ # If multiple categories detected, choose the most relevant or broaden the search
49
+ # For simplicity, we'll pick the first detected or keep it None for broad search.
50
+ primary_category = detected_categories[0] if detected_categories else None
51
+ tasks.append({"type": "news_search", "query": query, "category": primary_category})
52
+ logger.debug(f"Identified news search task for query: '{query}' (Category: {primary_category or 'Any'})")
53
 
54
  # --- Stock Price Task Identification ---
55
+ # Expanded list of common stock symbols or company names (should come from NewsTools/MCPServer for consistency)
56
+ # For now, let's keep a sample, but ideally, this would be dynamic or passed in.
57
+ # NOTE: For better dynamic stock symbol recognition, a more advanced NLP model might be needed.
58
  known_stock_symbols = [
59
  "nifty", "reliance", "tcs", "infy", "hdfcbank", "sbin",
60
  "icicibank", "bhartiartl", "kotakbank", "lt", "asianpaint",
 
64
  ]
65
 
66
  found_symbols = []
67
+ # Check for explicit stock symbol mentions
68
  for symbol in known_stock_symbols:
69
+ # Using regex for word boundary check (more robust than simple space check)
70
+ import re
71
+ if re.search(r'\b' + re.escape(symbol.lower()) + r'\b', query_lower):
72
+ if symbol.upper() not in found_symbols: # Ensure uniqueness
 
 
 
 
 
 
73
  found_symbols.append(symbol.upper())
74
  logger.debug(f"Found explicit stock symbol '{symbol.upper()}' in query.")
75
 
76
  # Fallback heuristic: If no known symbol is found but keywords like "stock" or "price" are present,
77
+ # try to infer a symbol from the last significant word of the query.
78
+ if not found_symbols and any(kw in query_lower for kw in ["stock", "price", "quote", "share"]):
79
  words = query_lower.split()
80
+ # Find the last word that is likely a symbol (alphabetic, short)
81
+ for i in range(len(words) -1, -1, -1):
82
+ potential_symbol = words[i].upper()
83
+ if 2 <= len(potential_symbol) <= 8 and potential_symbol.isalpha(): # Increased max length
84
+ if potential_symbol not in found_symbols:
85
  found_symbols.append(potential_symbol)
86
+ logger.info(f"Inferred potential stock symbol '{potential_symbol}' from query.")
87
+ break # Only infer one for now
88
 
89
  # Add a stock_price task for each identified symbol.
90
  for symbol in found_symbols:
 
95
  except Exception as e:
96
  logger.error(f"Error planning query '{query}': {e}", exc_info=True)
97
  return []
98
+
99
+ def generate_topic_suggestions(self) -> List[str]:
100
+ """
101
+ Generates sample financial topic suggestions for the chatbot.
102
+ Ported from Flask app.
103
+ """
104
+ return [
105
+ "What's today's Nifty 50 trend?",
106
+ "Explain RBI's latest repo rate decision",
107
+ "What's the latest on cryptocurrency in India?",
108
+ "How is the Indian economy doing?",
109
+ "What are the recent government policies affecting the market?",
110
+ "What are the current gold prices in India?",
111
+ "Tell me about recent IPOs."
112
+ ]