Spaces:
Sleeping
Sleeping
Abid Ali Awan
commited on
Commit
·
48e84df
1
Parent(s):
ad9dfde
Refactor RegRadarAgent to simplify parameter extraction by removing LLM function calling. Update UIHandler to pass parameters directly to the regulatory query processing function. Enhance memory saving functionality by implementing background threading for improved performance.
Browse files- agents/reg_radar.py +23 -122
- agents/ui_handler.py +10 -5
- tools/llm.py +0 -26
- tools/web_tools.py +25 -38
agents/reg_radar.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
import json
|
| 2 |
from typing import Dict, Tuple
|
| 3 |
|
| 4 |
-
from tools.llm import call_llm,
|
| 5 |
from tools.memory_tools import MemoryTools
|
| 6 |
from tools.web_tools import WebTools
|
| 7 |
|
|
@@ -29,124 +29,24 @@ class RegRadarAgent:
|
|
| 29 |
return "search", "Regulatory Search"
|
| 30 |
|
| 31 |
def extract_parameters(self, message: str) -> Dict:
|
| 32 |
-
"""Extract industry, region, and keywords from the query using LLM function calling"""
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
"type": "string",
|
| 51 |
-
"description": "The industry mentioned or implied in the query (e.g., fintech, healthcare, energy, general).",
|
| 52 |
-
},
|
| 53 |
-
"region": {
|
| 54 |
-
"type": "string",
|
| 55 |
-
"description": "The region or country explicitly mentioned in the query (e.g., US, EU, UK, Asia).",
|
| 56 |
-
},
|
| 57 |
-
"keywords": {
|
| 58 |
-
"type": "string",
|
| 59 |
-
"description": "A concise list of the most important regulatory topics or terms from the query, separated by commas. Do NOT return the full user question, generic words, or verbs.",
|
| 60 |
-
},
|
| 61 |
-
},
|
| 62 |
-
"required": ["industry", "region", "keywords"],
|
| 63 |
-
},
|
| 64 |
-
}
|
| 65 |
-
params = call_llm_with_function(message, function_schema)
|
| 66 |
-
# Fallback: context-aware extraction if LLM fails
|
| 67 |
-
if not params or not all(
|
| 68 |
-
k in params for k in ("industry", "region", "keywords")
|
| 69 |
-
):
|
| 70 |
-
import re
|
| 71 |
-
|
| 72 |
-
# Infer industry from context
|
| 73 |
-
industry = "General"
|
| 74 |
-
industry_map = {
|
| 75 |
-
"fintech": ["fintech", "finance", "sec", "bank", "investment"],
|
| 76 |
-
"healthcare": ["healthcare", "medical", "pharma", "hospital"],
|
| 77 |
-
"energy": ["energy", "oil", "gas", "renewable", "power"],
|
| 78 |
-
"technology": ["technology", "tech", "ai", "software", "it", "cyber"],
|
| 79 |
-
"retail": ["retail", "ecommerce", "shopping", "store"],
|
| 80 |
-
"general": [],
|
| 81 |
-
}
|
| 82 |
-
for ind, keywords in industry_map.items():
|
| 83 |
-
if any(word in message.lower() for word in keywords):
|
| 84 |
-
industry = ind
|
| 85 |
-
break
|
| 86 |
-
# Extract region
|
| 87 |
-
region_match = re.search(
|
| 88 |
-
r"\b(EU|US|UK|Asia|Europe|America|Canada|Australia|India|China|Japan|Global)\b",
|
| 89 |
-
message,
|
| 90 |
-
re.IGNORECASE,
|
| 91 |
-
)
|
| 92 |
-
region = region_match.group(1).upper() if region_match else "US"
|
| 93 |
-
# Extract keywords: regulatory terms and meaningful noun phrases only
|
| 94 |
-
regulatory_terms = [
|
| 95 |
-
"regulation",
|
| 96 |
-
"regulations",
|
| 97 |
-
"compliance",
|
| 98 |
-
"GDPR",
|
| 99 |
-
"data privacy",
|
| 100 |
-
"SEC",
|
| 101 |
-
"ESG",
|
| 102 |
-
"law",
|
| 103 |
-
"rules",
|
| 104 |
-
"requirements",
|
| 105 |
-
]
|
| 106 |
-
found_terms = [
|
| 107 |
-
term for term in regulatory_terms if term.lower() in message.lower()
|
| 108 |
-
]
|
| 109 |
-
# Multi-word capitalized noun phrases (e.g., 'data privacy', 'SEC regulations')
|
| 110 |
-
noun_phrases = re.findall(r"([A-Z][a-z]+(?: [a-z]+)+)", message)
|
| 111 |
-
# Remove question words and generic words
|
| 112 |
-
question_words = {
|
| 113 |
-
"what",
|
| 114 |
-
"which",
|
| 115 |
-
"who",
|
| 116 |
-
"whom",
|
| 117 |
-
"whose",
|
| 118 |
-
"when",
|
| 119 |
-
"where",
|
| 120 |
-
"why",
|
| 121 |
-
"how",
|
| 122 |
-
}
|
| 123 |
-
generic_words = {
|
| 124 |
-
"rules",
|
| 125 |
-
"regulation",
|
| 126 |
-
"regulations",
|
| 127 |
-
"requirement",
|
| 128 |
-
"requirements",
|
| 129 |
-
"law",
|
| 130 |
-
"laws",
|
| 131 |
-
}
|
| 132 |
-
filtered_phrases = [
|
| 133 |
-
phrase
|
| 134 |
-
for phrase in noun_phrases
|
| 135 |
-
if phrase.split()[0].lower() not in question_words
|
| 136 |
-
and phrase.lower() not in generic_words
|
| 137 |
-
]
|
| 138 |
-
# Combine and deduplicate
|
| 139 |
-
keywords_set = set(found_terms + filtered_phrases)
|
| 140 |
-
# Remove single generic words
|
| 141 |
-
keywords_set = {
|
| 142 |
-
kw
|
| 143 |
-
for kw in keywords_set
|
| 144 |
-
if kw.lower() not in question_words and kw.lower() not in generic_words
|
| 145 |
-
}
|
| 146 |
-
keywords = ", ".join(keywords_set)
|
| 147 |
-
if not keywords and found_terms:
|
| 148 |
-
keywords = found_terms[0]
|
| 149 |
-
params = {"industry": industry, "region": region, "keywords": keywords}
|
| 150 |
return params
|
| 151 |
|
| 152 |
def is_regulatory_query(self, message: str) -> bool:
|
|
@@ -160,13 +60,14 @@ class RegRadarAgent:
|
|
| 160 |
intent = call_llm(intent_prompt).strip().lower()
|
| 161 |
return not intent.startswith("n")
|
| 162 |
|
| 163 |
-
def process_regulatory_query(self, message: str):
|
| 164 |
"""Process a regulatory query and return results"""
|
| 165 |
# Determine the intended tool
|
| 166 |
tool_key, tool_name = self.determine_intended_tool(message)
|
| 167 |
|
| 168 |
-
# Extract parameters
|
| 169 |
-
params
|
|
|
|
| 170 |
|
| 171 |
# Execute tool (crawl sites)
|
| 172 |
crawl_results = self.web_tools.crawl_regulatory_sites(
|
|
|
|
| 1 |
import json
|
| 2 |
from typing import Dict, Tuple
|
| 3 |
|
| 4 |
+
from tools.llm import call_llm, stream_llm
|
| 5 |
from tools.memory_tools import MemoryTools
|
| 6 |
from tools.web_tools import WebTools
|
| 7 |
|
|
|
|
| 29 |
return "search", "Regulatory Search"
|
| 30 |
|
| 31 |
def extract_parameters(self, message: str) -> Dict:
|
| 32 |
+
"""Extract industry, region, and keywords from the query using LLM (no function calling)."""
|
| 33 |
+
prompt = f"""
|
| 34 |
+
Extract the following information from the user query below and return ONLY a valid JSON object with keys: industry, region, keywords.
|
| 35 |
+
- industry: The industry mentioned or implied (e.g., fintech, healthcare, energy, general).
|
| 36 |
+
- region: The region or country explicitly mentioned (e.g., US, EU, UK, Asia, Global).
|
| 37 |
+
- keywords: The most important regulatory topics or terms, separated by commas. Do NOT include generic words or verbs.
|
| 38 |
+
|
| 39 |
+
User query: {message}
|
| 40 |
+
|
| 41 |
+
Example output:
|
| 42 |
+
{{"industry": "fintech", "region": "US", "keywords": "SEC regulations"}}
|
| 43 |
+
"""
|
| 44 |
+
response = call_llm(prompt)
|
| 45 |
+
try:
|
| 46 |
+
params = json.loads(response)
|
| 47 |
+
except Exception:
|
| 48 |
+
# fallback: return empty/defaults if parsing fails
|
| 49 |
+
params = {"industry": "General", "region": "US", "keywords": ""}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
return params
|
| 51 |
|
| 52 |
def is_regulatory_query(self, message: str) -> bool:
|
|
|
|
| 60 |
intent = call_llm(intent_prompt).strip().lower()
|
| 61 |
return not intent.startswith("n")
|
| 62 |
|
| 63 |
+
def process_regulatory_query(self, message: str, params: dict = None):
|
| 64 |
"""Process a regulatory query and return results"""
|
| 65 |
# Determine the intended tool
|
| 66 |
tool_key, tool_name = self.determine_intended_tool(message)
|
| 67 |
|
| 68 |
+
# Extract parameters only if not provided
|
| 69 |
+
if params is None:
|
| 70 |
+
params = self.extract_parameters(message)
|
| 71 |
|
| 72 |
# Execute tool (crawl sites)
|
| 73 |
crawl_results = self.web_tools.crawl_regulatory_sites(
|
agents/ui_handler.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import time
|
| 2 |
|
| 3 |
import gradio as gr
|
|
@@ -81,7 +82,7 @@ class UIHandler:
|
|
| 81 |
yield history, "", gr.update(interactive=False)
|
| 82 |
|
| 83 |
# Process the regulatory query
|
| 84 |
-
results = self.agent.process_regulatory_query(message)
|
| 85 |
crawl_results = results["crawl_results"]
|
| 86 |
memory_results = results["memory_results"]
|
| 87 |
|
|
@@ -159,10 +160,7 @@ Found {len(memory_results)} similar past queries in memory.
|
|
| 159 |
history[-1] = ChatMessage(role="assistant", content=streaming_content)
|
| 160 |
yield history, "", gr.update(interactive=False)
|
| 161 |
|
| 162 |
-
#
|
| 163 |
-
self.agent.memory_tools.save_to_memory("user", message, streaming_content)
|
| 164 |
-
|
| 165 |
-
# Show completion time
|
| 166 |
elapsed = time.time() - start_time
|
| 167 |
history.append(
|
| 168 |
ChatMessage(
|
|
@@ -172,6 +170,13 @@ Found {len(memory_results)} similar past queries in memory.
|
|
| 172 |
# Re-enable input box at the end
|
| 173 |
yield history, "", gr.update(interactive=True)
|
| 174 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
def delayed_clear(self):
|
| 176 |
time.sleep(0.1) # 100ms delay to allow generator cancellation
|
| 177 |
return [], "", gr.update(interactive=True)
|
|
|
|
| 1 |
+
import threading
|
| 2 |
import time
|
| 3 |
|
| 4 |
import gradio as gr
|
|
|
|
| 82 |
yield history, "", gr.update(interactive=False)
|
| 83 |
|
| 84 |
# Process the regulatory query
|
| 85 |
+
results = self.agent.process_regulatory_query(message, params)
|
| 86 |
crawl_results = results["crawl_results"]
|
| 87 |
memory_results = results["memory_results"]
|
| 88 |
|
|
|
|
| 160 |
history[-1] = ChatMessage(role="assistant", content=streaming_content)
|
| 161 |
yield history, "", gr.update(interactive=False)
|
| 162 |
|
| 163 |
+
# Show completion time (before saving to memory)
|
|
|
|
|
|
|
|
|
|
| 164 |
elapsed = time.time() - start_time
|
| 165 |
history.append(
|
| 166 |
ChatMessage(
|
|
|
|
| 170 |
# Re-enable input box at the end
|
| 171 |
yield history, "", gr.update(interactive=True)
|
| 172 |
|
| 173 |
+
# Save to memory in the background
|
| 174 |
+
threading.Thread(
|
| 175 |
+
target=self.agent.memory_tools.save_to_memory,
|
| 176 |
+
args=("user", message, streaming_content),
|
| 177 |
+
daemon=True,
|
| 178 |
+
).start()
|
| 179 |
+
|
| 180 |
def delayed_clear(self):
|
| 181 |
time.sleep(0.1) # 100ms delay to allow generator cancellation
|
| 182 |
return [], "", gr.update(interactive=True)
|
tools/llm.py
CHANGED
|
@@ -43,29 +43,3 @@ def stream_llm(prompt: str, temperature: float = DEFAULT_LLM_TEMPERATURE):
|
|
| 43 |
yield delta
|
| 44 |
except Exception as e:
|
| 45 |
yield f"Error: {str(e)}"
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
def call_llm_with_function(
|
| 49 |
-
user_message: str,
|
| 50 |
-
function_schema: dict,
|
| 51 |
-
temperature: float = DEFAULT_LLM_TEMPERATURE,
|
| 52 |
-
) -> dict:
|
| 53 |
-
"""Call the LLM with function calling and return extracted arguments as a dict."""
|
| 54 |
-
try:
|
| 55 |
-
response = client.chat.completions.create(
|
| 56 |
-
model=DEFAULT_LLM_MODEL,
|
| 57 |
-
messages=[{"role": "user", "content": user_message}],
|
| 58 |
-
functions=[function_schema],
|
| 59 |
-
function_call="auto",
|
| 60 |
-
temperature=temperature,
|
| 61 |
-
)
|
| 62 |
-
function_call = response.choices[0].message.function_call
|
| 63 |
-
if function_call and hasattr(function_call, "arguments"):
|
| 64 |
-
import json
|
| 65 |
-
|
| 66 |
-
return json.loads(function_call.arguments)
|
| 67 |
-
else:
|
| 68 |
-
return {}
|
| 69 |
-
except Exception as e:
|
| 70 |
-
print(f"LLM function call error: {e}")
|
| 71 |
-
return {}
|
|
|
|
| 43 |
yield delta
|
| 44 |
except Exception as e:
|
| 45 |
yield f"Error: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tools/web_tools.py
CHANGED
|
@@ -4,11 +4,18 @@ from typing import Dict
|
|
| 4 |
from tavily import TavilyClient
|
| 5 |
|
| 6 |
from config.settings import REGULATORY_SOURCES, TAVILY_API_KEY
|
|
|
|
| 7 |
|
| 8 |
# Initialize Tavily client
|
| 9 |
tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
|
| 10 |
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
class WebTools:
|
| 13 |
def __init__(self):
|
| 14 |
self.cached_searches = {}
|
|
@@ -78,43 +85,23 @@ class WebTools:
|
|
| 78 |
return results
|
| 79 |
|
| 80 |
def extract_parameters(self, message: str) -> Dict:
|
| 81 |
-
"""Extract industry, region, and keywords from the query using LLM function calling
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
"type": "string",
|
| 100 |
-
"description": "The industry mentioned or implied in the query (e.g., fintech, healthcare, energy, general).",
|
| 101 |
-
},
|
| 102 |
-
"region": {
|
| 103 |
-
"type": "string",
|
| 104 |
-
"description": "The region or country explicitly mentioned in the query (e.g., US, EU, UK, Asia).",
|
| 105 |
-
},
|
| 106 |
-
"keywords": {
|
| 107 |
-
"type": "string",
|
| 108 |
-
"description": "A concise list of the most important regulatory topics or terms from the query, separated by commas. Do NOT return the full user question, generic words, or verbs.",
|
| 109 |
-
},
|
| 110 |
-
},
|
| 111 |
-
"required": ["industry", "region", "keywords"],
|
| 112 |
-
},
|
| 113 |
-
}
|
| 114 |
-
params = call_llm_with_function(message, function_schema)
|
| 115 |
-
# Optionally, you can add a minimal fallback if params is None or missing keys
|
| 116 |
-
if not params or not all(
|
| 117 |
-
k in params for k in ("industry", "region", "keywords")
|
| 118 |
-
):
|
| 119 |
params = {"industry": "", "region": "", "keywords": ""}
|
| 120 |
return params
|
|
|
|
| 4 |
from tavily import TavilyClient
|
| 5 |
|
| 6 |
from config.settings import REGULATORY_SOURCES, TAVILY_API_KEY
|
| 7 |
+
from tools.llm import call_llm
|
| 8 |
|
| 9 |
# Initialize Tavily client
|
| 10 |
tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
|
| 11 |
|
| 12 |
|
| 13 |
+
class ChatMessage:
|
| 14 |
+
def __init__(self, role, content):
|
| 15 |
+
self.role = role
|
| 16 |
+
self.content = content
|
| 17 |
+
|
| 18 |
+
|
| 19 |
class WebTools:
|
| 20 |
def __init__(self):
|
| 21 |
self.cached_searches = {}
|
|
|
|
| 85 |
return results
|
| 86 |
|
| 87 |
def extract_parameters(self, message: str) -> Dict:
|
| 88 |
+
"""Extract industry, region, and keywords from the query using LLM (no function calling)."""
|
| 89 |
+
prompt = f"""
|
| 90 |
+
Extract the following information from the user query below and return ONLY a valid JSON object with keys: industry, region, keywords.
|
| 91 |
+
- industry: The industry mentioned or implied (e.g., fintech, healthcare, energy, general).
|
| 92 |
+
- region: The region or country explicitly mentioned (e.g., US, EU, UK, Asia, Global).
|
| 93 |
+
- keywords: The most important regulatory topics or terms, separated by commas. Do NOT include generic words or verbs.
|
| 94 |
+
|
| 95 |
+
User query: {message}
|
| 96 |
+
|
| 97 |
+
Example output:
|
| 98 |
+
{{"industry": "fintech", "region": "US", "keywords": "SEC regulations"}}
|
| 99 |
+
"""
|
| 100 |
+
import json
|
| 101 |
+
|
| 102 |
+
response = call_llm(prompt)
|
| 103 |
+
try:
|
| 104 |
+
params = json.loads(response)
|
| 105 |
+
except Exception:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
params = {"industry": "", "region": "", "keywords": ""}
|
| 107 |
return params
|