Abid Ali Awan commited on
Commit
48e84df
·
1 Parent(s): ad9dfde

Refactor RegRadarAgent to simplify parameter extraction by removing LLM function calling. Update UIHandler to pass parameters directly to the regulatory query processing function. Enhance memory saving functionality by implementing background threading for improved performance.

Browse files
Files changed (4) hide show
  1. agents/reg_radar.py +23 -122
  2. agents/ui_handler.py +10 -5
  3. tools/llm.py +0 -26
  4. tools/web_tools.py +25 -38
agents/reg_radar.py CHANGED
@@ -1,7 +1,7 @@
1
  import json
2
  from typing import Dict, Tuple
3
 
4
- from tools.llm import call_llm, call_llm_with_function, stream_llm
5
  from tools.memory_tools import MemoryTools
6
  from tools.web_tools import WebTools
7
 
@@ -29,124 +29,24 @@ class RegRadarAgent:
29
  return "search", "Regulatory Search"
30
 
31
  def extract_parameters(self, message: str) -> Dict:
32
- """Extract industry, region, and keywords from the query using LLM function calling"""
33
- function_schema = {
34
- "name": "extract_parameters",
35
- "description": (
36
- "Extract industry, region, and keywords from a user query.\n"
37
- "- 'industry': If not explicitly mentioned, infer the most relevant industry from the context (e.g., if the query is about SEC regulations, infer 'fintech' or 'finance').\n"
38
- "- 'region': The country or region explicitly mentioned (e.g., US, EU, UK, Asia, Global).\n"
39
- "- 'keywords': Only the most important regulatory topics or terms (e.g., 'data privacy', 'GDPR', 'ESG compliance', 'SEC regulations'), not generic words or verbs.\n"
40
- "Examples:\n"
41
- "- 'Show me the latest SEC regulations for fintech' => industry: 'fintech', region: 'US', keywords: 'SEC regulations'\n"
42
- "- 'What are the new data privacy rules in the EU?' => industry: 'General', region: 'EU', keywords: 'data privacy'\n"
43
- "- 'Scan for healthcare regulations in the US' => industry: 'healthcare', region: 'US', keywords: 'healthcare regulations'\n"
44
- "- 'Any updates on ESG compliance for energy companies?' => industry: 'energy', region: 'US', keywords: 'ESG compliance'\n"
45
- ),
46
- "parameters": {
47
- "type": "object",
48
- "properties": {
49
- "industry": {
50
- "type": "string",
51
- "description": "The industry mentioned or implied in the query (e.g., fintech, healthcare, energy, general).",
52
- },
53
- "region": {
54
- "type": "string",
55
- "description": "The region or country explicitly mentioned in the query (e.g., US, EU, UK, Asia).",
56
- },
57
- "keywords": {
58
- "type": "string",
59
- "description": "A concise list of the most important regulatory topics or terms from the query, separated by commas. Do NOT return the full user question, generic words, or verbs.",
60
- },
61
- },
62
- "required": ["industry", "region", "keywords"],
63
- },
64
- }
65
- params = call_llm_with_function(message, function_schema)
66
- # Fallback: context-aware extraction if LLM fails
67
- if not params or not all(
68
- k in params for k in ("industry", "region", "keywords")
69
- ):
70
- import re
71
-
72
- # Infer industry from context
73
- industry = "General"
74
- industry_map = {
75
- "fintech": ["fintech", "finance", "sec", "bank", "investment"],
76
- "healthcare": ["healthcare", "medical", "pharma", "hospital"],
77
- "energy": ["energy", "oil", "gas", "renewable", "power"],
78
- "technology": ["technology", "tech", "ai", "software", "it", "cyber"],
79
- "retail": ["retail", "ecommerce", "shopping", "store"],
80
- "general": [],
81
- }
82
- for ind, keywords in industry_map.items():
83
- if any(word in message.lower() for word in keywords):
84
- industry = ind
85
- break
86
- # Extract region
87
- region_match = re.search(
88
- r"\b(EU|US|UK|Asia|Europe|America|Canada|Australia|India|China|Japan|Global)\b",
89
- message,
90
- re.IGNORECASE,
91
- )
92
- region = region_match.group(1).upper() if region_match else "US"
93
- # Extract keywords: regulatory terms and meaningful noun phrases only
94
- regulatory_terms = [
95
- "regulation",
96
- "regulations",
97
- "compliance",
98
- "GDPR",
99
- "data privacy",
100
- "SEC",
101
- "ESG",
102
- "law",
103
- "rules",
104
- "requirements",
105
- ]
106
- found_terms = [
107
- term for term in regulatory_terms if term.lower() in message.lower()
108
- ]
109
- # Multi-word capitalized noun phrases (e.g., 'data privacy', 'SEC regulations')
110
- noun_phrases = re.findall(r"([A-Z][a-z]+(?: [a-z]+)+)", message)
111
- # Remove question words and generic words
112
- question_words = {
113
- "what",
114
- "which",
115
- "who",
116
- "whom",
117
- "whose",
118
- "when",
119
- "where",
120
- "why",
121
- "how",
122
- }
123
- generic_words = {
124
- "rules",
125
- "regulation",
126
- "regulations",
127
- "requirement",
128
- "requirements",
129
- "law",
130
- "laws",
131
- }
132
- filtered_phrases = [
133
- phrase
134
- for phrase in noun_phrases
135
- if phrase.split()[0].lower() not in question_words
136
- and phrase.lower() not in generic_words
137
- ]
138
- # Combine and deduplicate
139
- keywords_set = set(found_terms + filtered_phrases)
140
- # Remove single generic words
141
- keywords_set = {
142
- kw
143
- for kw in keywords_set
144
- if kw.lower() not in question_words and kw.lower() not in generic_words
145
- }
146
- keywords = ", ".join(keywords_set)
147
- if not keywords and found_terms:
148
- keywords = found_terms[0]
149
- params = {"industry": industry, "region": region, "keywords": keywords}
150
  return params
151
 
152
  def is_regulatory_query(self, message: str) -> bool:
@@ -160,13 +60,14 @@ class RegRadarAgent:
160
  intent = call_llm(intent_prompt).strip().lower()
161
  return not intent.startswith("n")
162
 
163
- def process_regulatory_query(self, message: str):
164
  """Process a regulatory query and return results"""
165
  # Determine the intended tool
166
  tool_key, tool_name = self.determine_intended_tool(message)
167
 
168
- # Extract parameters
169
- params = self.extract_parameters(message)
 
170
 
171
  # Execute tool (crawl sites)
172
  crawl_results = self.web_tools.crawl_regulatory_sites(
 
1
  import json
2
  from typing import Dict, Tuple
3
 
4
+ from tools.llm import call_llm, stream_llm
5
  from tools.memory_tools import MemoryTools
6
  from tools.web_tools import WebTools
7
 
 
29
  return "search", "Regulatory Search"
30
 
31
  def extract_parameters(self, message: str) -> Dict:
32
+ """Extract industry, region, and keywords from the query using LLM (no function calling)."""
33
+ prompt = f"""
34
+ Extract the following information from the user query below and return ONLY a valid JSON object with keys: industry, region, keywords.
35
+ - industry: The industry mentioned or implied (e.g., fintech, healthcare, energy, general).
36
+ - region: The region or country explicitly mentioned (e.g., US, EU, UK, Asia, Global).
37
+ - keywords: The most important regulatory topics or terms, separated by commas. Do NOT include generic words or verbs.
38
+
39
+ User query: {message}
40
+
41
+ Example output:
42
+ {{"industry": "fintech", "region": "US", "keywords": "SEC regulations"}}
43
+ """
44
+ response = call_llm(prompt)
45
+ try:
46
+ params = json.loads(response)
47
+ except Exception:
48
+ # fallback: return empty/defaults if parsing fails
49
+ params = {"industry": "General", "region": "US", "keywords": ""}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  return params
51
 
52
  def is_regulatory_query(self, message: str) -> bool:
 
60
  intent = call_llm(intent_prompt).strip().lower()
61
  return not intent.startswith("n")
62
 
63
+ def process_regulatory_query(self, message: str, params: dict = None):
64
  """Process a regulatory query and return results"""
65
  # Determine the intended tool
66
  tool_key, tool_name = self.determine_intended_tool(message)
67
 
68
+ # Extract parameters only if not provided
69
+ if params is None:
70
+ params = self.extract_parameters(message)
71
 
72
  # Execute tool (crawl sites)
73
  crawl_results = self.web_tools.crawl_regulatory_sites(
agents/ui_handler.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import time
2
 
3
  import gradio as gr
@@ -81,7 +82,7 @@ class UIHandler:
81
  yield history, "", gr.update(interactive=False)
82
 
83
  # Process the regulatory query
84
- results = self.agent.process_regulatory_query(message)
85
  crawl_results = results["crawl_results"]
86
  memory_results = results["memory_results"]
87
 
@@ -159,10 +160,7 @@ Found {len(memory_results)} similar past queries in memory.
159
  history[-1] = ChatMessage(role="assistant", content=streaming_content)
160
  yield history, "", gr.update(interactive=False)
161
 
162
- # Save to memory
163
- self.agent.memory_tools.save_to_memory("user", message, streaming_content)
164
-
165
- # Show completion time
166
  elapsed = time.time() - start_time
167
  history.append(
168
  ChatMessage(
@@ -172,6 +170,13 @@ Found {len(memory_results)} similar past queries in memory.
172
  # Re-enable input box at the end
173
  yield history, "", gr.update(interactive=True)
174
 
 
 
 
 
 
 
 
175
  def delayed_clear(self):
176
  time.sleep(0.1) # 100ms delay to allow generator cancellation
177
  return [], "", gr.update(interactive=True)
 
1
+ import threading
2
  import time
3
 
4
  import gradio as gr
 
82
  yield history, "", gr.update(interactive=False)
83
 
84
  # Process the regulatory query
85
+ results = self.agent.process_regulatory_query(message, params)
86
  crawl_results = results["crawl_results"]
87
  memory_results = results["memory_results"]
88
 
 
160
  history[-1] = ChatMessage(role="assistant", content=streaming_content)
161
  yield history, "", gr.update(interactive=False)
162
 
163
+ # Show completion time (before saving to memory)
 
 
 
164
  elapsed = time.time() - start_time
165
  history.append(
166
  ChatMessage(
 
170
  # Re-enable input box at the end
171
  yield history, "", gr.update(interactive=True)
172
 
173
+ # Save to memory in the background
174
+ threading.Thread(
175
+ target=self.agent.memory_tools.save_to_memory,
176
+ args=("user", message, streaming_content),
177
+ daemon=True,
178
+ ).start()
179
+
180
  def delayed_clear(self):
181
  time.sleep(0.1) # 100ms delay to allow generator cancellation
182
  return [], "", gr.update(interactive=True)
tools/llm.py CHANGED
@@ -43,29 +43,3 @@ def stream_llm(prompt: str, temperature: float = DEFAULT_LLM_TEMPERATURE):
43
  yield delta
44
  except Exception as e:
45
  yield f"Error: {str(e)}"
46
-
47
-
48
- def call_llm_with_function(
49
- user_message: str,
50
- function_schema: dict,
51
- temperature: float = DEFAULT_LLM_TEMPERATURE,
52
- ) -> dict:
53
- """Call the LLM with function calling and return extracted arguments as a dict."""
54
- try:
55
- response = client.chat.completions.create(
56
- model=DEFAULT_LLM_MODEL,
57
- messages=[{"role": "user", "content": user_message}],
58
- functions=[function_schema],
59
- function_call="auto",
60
- temperature=temperature,
61
- )
62
- function_call = response.choices[0].message.function_call
63
- if function_call and hasattr(function_call, "arguments"):
64
- import json
65
-
66
- return json.loads(function_call.arguments)
67
- else:
68
- return {}
69
- except Exception as e:
70
- print(f"LLM function call error: {e}")
71
- return {}
 
43
  yield delta
44
  except Exception as e:
45
  yield f"Error: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tools/web_tools.py CHANGED
@@ -4,11 +4,18 @@ from typing import Dict
4
  from tavily import TavilyClient
5
 
6
  from config.settings import REGULATORY_SOURCES, TAVILY_API_KEY
 
7
 
8
  # Initialize Tavily client
9
  tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
10
 
11
 
 
 
 
 
 
 
12
  class WebTools:
13
  def __init__(self):
14
  self.cached_searches = {}
@@ -78,43 +85,23 @@ class WebTools:
78
  return results
79
 
80
  def extract_parameters(self, message: str) -> Dict:
81
- """Extract industry, region, and keywords from the query using LLM function calling only"""
82
- function_schema = {
83
- "name": "extract_parameters",
84
- "description": (
85
- "Extract industry, region, and keywords from a user query.\n"
86
- "- 'industry': The main industry mentioned or implied (e.g., fintech, healthcare, energy, general).\n"
87
- "- 'region': The country or region explicitly mentioned (e.g., US, EU, UK, Asia).\n"
88
- "- 'keywords': Only the most important regulatory topics or terms (e.g., 'data privacy', 'GDPR', 'ESG compliance', 'SEC regulations'), not generic words or verbs.\n"
89
- "Examples:\n"
90
- "- 'Show me the latest SEC regulations for fintech' => industry: 'fintech', region: 'US', keywords: 'SEC regulations'\n"
91
- "- 'What are the new data privacy rules in the EU?' => industry: 'General', region: 'EU', keywords: 'data privacy'\n"
92
- "- 'Scan for healthcare regulations in the US' => industry: 'healthcare', region: 'US', keywords: 'healthcare regulations'\n"
93
- "- 'Any updates on ESG compliance for energy companies?' => industry: 'energy', region: 'US', keywords: 'ESG compliance'\n"
94
- ),
95
- "parameters": {
96
- "type": "object",
97
- "properties": {
98
- "industry": {
99
- "type": "string",
100
- "description": "The industry mentioned or implied in the query (e.g., fintech, healthcare, energy, general).",
101
- },
102
- "region": {
103
- "type": "string",
104
- "description": "The region or country explicitly mentioned in the query (e.g., US, EU, UK, Asia).",
105
- },
106
- "keywords": {
107
- "type": "string",
108
- "description": "A concise list of the most important regulatory topics or terms from the query, separated by commas. Do NOT return the full user question, generic words, or verbs.",
109
- },
110
- },
111
- "required": ["industry", "region", "keywords"],
112
- },
113
- }
114
- params = call_llm_with_function(message, function_schema)
115
- # Optionally, you can add a minimal fallback if params is None or missing keys
116
- if not params or not all(
117
- k in params for k in ("industry", "region", "keywords")
118
- ):
119
  params = {"industry": "", "region": "", "keywords": ""}
120
  return params
 
4
  from tavily import TavilyClient
5
 
6
  from config.settings import REGULATORY_SOURCES, TAVILY_API_KEY
7
+ from tools.llm import call_llm
8
 
9
  # Initialize Tavily client
10
  tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
11
 
12
 
13
+ class ChatMessage:
14
+ def __init__(self, role, content):
15
+ self.role = role
16
+ self.content = content
17
+
18
+
19
  class WebTools:
20
  def __init__(self):
21
  self.cached_searches = {}
 
85
  return results
86
 
87
  def extract_parameters(self, message: str) -> Dict:
88
+ """Extract industry, region, and keywords from the query using LLM (no function calling)."""
89
+ prompt = f"""
90
+ Extract the following information from the user query below and return ONLY a valid JSON object with keys: industry, region, keywords.
91
+ - industry: The industry mentioned or implied (e.g., fintech, healthcare, energy, general).
92
+ - region: The region or country explicitly mentioned (e.g., US, EU, UK, Asia, Global).
93
+ - keywords: The most important regulatory topics or terms, separated by commas. Do NOT include generic words or verbs.
94
+
95
+ User query: {message}
96
+
97
+ Example output:
98
+ {{"industry": "fintech", "region": "US", "keywords": "SEC regulations"}}
99
+ """
100
+ import json
101
+
102
+ response = call_llm(prompt)
103
+ try:
104
+ params = json.loads(response)
105
+ except Exception:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  params = {"industry": "", "region": "", "keywords": ""}
107
  return params