Spaces:
Build error
Build error
Update agent/planner.py
Browse files- agent/planner.py +66 -31
agent/planner.py
CHANGED
|
@@ -1,9 +1,19 @@
|
|
| 1 |
import logging
|
|
|
|
| 2 |
|
| 3 |
logger = logging.getLogger(__name__)
|
| 4 |
|
| 5 |
class QueryPlanner:
|
| 6 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
"""
|
| 8 |
Decompose a user query into a list of planned tasks (e.g., news search, stock price lookup).
|
| 9 |
This planning logic is heuristic-based and can be expanded for more sophisticated query understanding.
|
|
@@ -15,21 +25,36 @@ class QueryPlanner:
|
|
| 15 |
|
| 16 |
# --- News Search Task Identification ---
|
| 17 |
# Keywords indicating a news-related query.
|
| 18 |
-
|
| 19 |
-
"trend", "news", "articles", "latest", "updates",
|
| 20 |
-
"
|
|
|
|
| 21 |
]
|
| 22 |
|
| 23 |
-
#
|
| 24 |
-
#
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
# --- Stock Price Task Identification ---
|
| 31 |
-
#
|
| 32 |
-
#
|
|
|
|
| 33 |
known_stock_symbols = [
|
| 34 |
"nifty", "reliance", "tcs", "infy", "hdfcbank", "sbin",
|
| 35 |
"icicibank", "bhartiartl", "kotakbank", "lt", "asianpaint",
|
|
@@ -39,32 +64,27 @@ class QueryPlanner:
|
|
| 39 |
]
|
| 40 |
|
| 41 |
found_symbols = []
|
| 42 |
-
#
|
| 43 |
for symbol in known_stock_symbols:
|
| 44 |
-
#
|
| 45 |
-
|
| 46 |
-
if (
|
| 47 |
-
|
| 48 |
-
query_lower.endswith(f" {symbol.lower()}") or
|
| 49 |
-
query_lower == symbol.lower()):
|
| 50 |
-
|
| 51 |
-
# Add the symbol in uppercase to ensure consistency for API calls,
|
| 52 |
-
# but only if it hasn't been added already.
|
| 53 |
-
if symbol.upper() not in [task.get("symbol") for task in tasks if task["type"] == "stock_price"]:
|
| 54 |
found_symbols.append(symbol.upper())
|
| 55 |
logger.debug(f"Found explicit stock symbol '{symbol.upper()}' in query.")
|
| 56 |
|
| 57 |
# Fallback heuristic: If no known symbol is found but keywords like "stock" or "price" are present,
|
| 58 |
-
# try to infer a symbol from the last word of the query.
|
| 59 |
-
if not found_symbols and (
|
| 60 |
words = query_lower.split()
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
if 2 <= len(potential_symbol) <=
|
| 65 |
-
if potential_symbol not in
|
| 66 |
found_symbols.append(potential_symbol)
|
| 67 |
-
logger.info(f"Inferred potential stock symbol '{potential_symbol}' from query
|
|
|
|
| 68 |
|
| 69 |
# Add a stock_price task for each identified symbol.
|
| 70 |
for symbol in found_symbols:
|
|
@@ -75,3 +95,18 @@ class QueryPlanner:
|
|
| 75 |
except Exception as e:
|
| 76 |
logger.error(f"Error planning query '{query}': {e}", exc_info=True)
|
| 77 |
return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import logging
|
| 2 |
+
from typing import List, Dict, Set
|
| 3 |
|
| 4 |
logger = logging.getLogger(__name__)
|
| 5 |
|
| 6 |
class QueryPlanner:
|
| 7 |
+
def __init__(self, india_finance_keywords: List[str], category_keywords: Dict[str, Set[str]]):
|
| 8 |
+
"""
|
| 9 |
+
Initialize QueryPlanner with keyword lists for financial topics and categories.
|
| 10 |
+
These lists are typically provided by NewsTools or a central configuration.
|
| 11 |
+
"""
|
| 12 |
+
self.india_finance_keywords = set(india_finance_keywords) # Convert to set for faster lookup
|
| 13 |
+
self.category_keywords = {k: set(v) for k, v in category_keywords.items()} # Convert values to sets
|
| 14 |
+
logger.info("QueryPlanner initialized with financial keywords and categories.")
|
| 15 |
+
|
| 16 |
+
def plan_query(self, query: str) -> List[Dict]:
|
| 17 |
"""
|
| 18 |
Decompose a user query into a list of planned tasks (e.g., news search, stock price lookup).
|
| 19 |
This planning logic is heuristic-based and can be expanded for more sophisticated query understanding.
|
|
|
|
| 25 |
|
| 26 |
# --- News Search Task Identification ---
|
| 27 |
# Keywords indicating a news-related query.
|
| 28 |
+
news_trigger_keywords = [
|
| 29 |
+
"trend", "news", "articles", "latest", "updates", "market sentiment",
|
| 30 |
+
"recent", "headlines", "report", "what's happening", "current situation",
|
| 31 |
+
"developments"
|
| 32 |
]
|
| 33 |
|
| 34 |
+
# Combine general news triggers with specific financial keywords for news search
|
| 35 |
+
# If any news trigger keyword or general financial keyword is present, plan a news search task.
|
| 36 |
+
is_news_query = any(keyword in query_lower for keyword in news_trigger_keywords)
|
| 37 |
+
is_general_financial_query = any(keyword in query_lower for keyword in self.india_finance_keywords)
|
| 38 |
+
|
| 39 |
+
# Determine primary category for news search if possible
|
| 40 |
+
detected_categories = []
|
| 41 |
+
for category_name, keywords_set in self.category_keywords.items():
|
| 42 |
+
if any(keyword in query_lower for keyword in keywords_set):
|
| 43 |
+
detected_categories.append(category_name)
|
| 44 |
+
|
| 45 |
+
# If it's a news-related query or a general financial query, plan a news search.
|
| 46 |
+
# Prioritize a specific category if detected, otherwise default to a broader search.
|
| 47 |
+
if is_news_query or is_general_financial_query:
|
| 48 |
+
# If multiple categories detected, choose the most relevant or broaden the search
|
| 49 |
+
# For simplicity, we'll pick the first detected or keep it None for broad search.
|
| 50 |
+
primary_category = detected_categories[0] if detected_categories else None
|
| 51 |
+
tasks.append({"type": "news_search", "query": query, "category": primary_category})
|
| 52 |
+
logger.debug(f"Identified news search task for query: '{query}' (Category: {primary_category or 'Any'})")
|
| 53 |
|
| 54 |
# --- Stock Price Task Identification ---
|
| 55 |
+
# Expanded list of common stock symbols or company names (should come from NewsTools/MCPServer for consistency)
|
| 56 |
+
# For now, let's keep a sample, but ideally, this would be dynamic or passed in.
|
| 57 |
+
# NOTE: For better dynamic stock symbol recognition, a more advanced NLP model might be needed.
|
| 58 |
known_stock_symbols = [
|
| 59 |
"nifty", "reliance", "tcs", "infy", "hdfcbank", "sbin",
|
| 60 |
"icicibank", "bhartiartl", "kotakbank", "lt", "asianpaint",
|
|
|
|
| 64 |
]
|
| 65 |
|
| 66 |
found_symbols = []
|
| 67 |
+
# Check for explicit stock symbol mentions
|
| 68 |
for symbol in known_stock_symbols:
|
| 69 |
+
# Using regex for word boundary check (more robust than simple space check)
|
| 70 |
+
import re
|
| 71 |
+
if re.search(r'\b' + re.escape(symbol.lower()) + r'\b', query_lower):
|
| 72 |
+
if symbol.upper() not in found_symbols: # Ensure uniqueness
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
found_symbols.append(symbol.upper())
|
| 74 |
logger.debug(f"Found explicit stock symbol '{symbol.upper()}' in query.")
|
| 75 |
|
| 76 |
# Fallback heuristic: If no known symbol is found but keywords like "stock" or "price" are present,
|
| 77 |
+
# try to infer a symbol from the last significant word of the query.
|
| 78 |
+
if not found_symbols and any(kw in query_lower for kw in ["stock", "price", "quote", "share"]):
|
| 79 |
words = query_lower.split()
|
| 80 |
+
# Find the last word that is likely a symbol (alphabetic, short)
|
| 81 |
+
for i in range(len(words) -1, -1, -1):
|
| 82 |
+
potential_symbol = words[i].upper()
|
| 83 |
+
if 2 <= len(potential_symbol) <= 8 and potential_symbol.isalpha(): # Increased max length
|
| 84 |
+
if potential_symbol not in found_symbols:
|
| 85 |
found_symbols.append(potential_symbol)
|
| 86 |
+
logger.info(f"Inferred potential stock symbol '{potential_symbol}' from query.")
|
| 87 |
+
break # Only infer one for now
|
| 88 |
|
| 89 |
# Add a stock_price task for each identified symbol.
|
| 90 |
for symbol in found_symbols:
|
|
|
|
| 95 |
except Exception as e:
|
| 96 |
logger.error(f"Error planning query '{query}': {e}", exc_info=True)
|
| 97 |
return []
|
| 98 |
+
|
| 99 |
+
def generate_topic_suggestions(self) -> List[str]:
|
| 100 |
+
"""
|
| 101 |
+
Generates sample financial topic suggestions for the chatbot.
|
| 102 |
+
Ported from Flask app.
|
| 103 |
+
"""
|
| 104 |
+
return [
|
| 105 |
+
"What's today's Nifty 50 trend?",
|
| 106 |
+
"Explain RBI's latest repo rate decision",
|
| 107 |
+
"What's the latest on cryptocurrency in India?",
|
| 108 |
+
"How is the Indian economy doing?",
|
| 109 |
+
"What are the recent government policies affecting the market?",
|
| 110 |
+
"What are the current gold prices in India?",
|
| 111 |
+
"Tell me about recent IPOs."
|
| 112 |
+
]
|