Spaces:
Sleeping
Sleeping
Priyanshi Saxena
committed on
Commit
·
c785b3f
1
Parent(s):
2fe0e75
fix: Add missing validate_gemini_response method and cleanup method
Browse files
- Added validate_gemini_response method to AISafetyGuard
- Added cleanup method to ChartDataTool to prevent AttributeError
- Improved Gemini tool parsing logic to handle all suggested tools
- Updated Gemini model to use 'gemini-2.0-flash-lite'
- src/agent/research_agent.py +14 -0
- src/tools/chart_data_tool.py +5 -0
- src/utils/ai_safety.py +34 -0
src/agent/research_agent.py
CHANGED
|
@@ -417,6 +417,20 @@ Respond with only the tool names, comma-separated (no explanations)."""
|
|
| 417 |
'etherscan_data', 'chart_data_provider'
|
| 418 |
}]
|
| 419 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 420 |
logger.info(f"🛠️ Gemini suggested tools: {suggested_tools}")
|
| 421 |
|
| 422 |
# Step 2: Execute tools (same logic as Ollama version)
|
|
|
|
| 417 |
'etherscan_data', 'chart_data_provider'
|
| 418 |
}]
|
| 419 |
|
| 420 |
+
# If no valid tools found, extract from response content
|
| 421 |
+
if not suggested_tools:
|
| 422 |
+
response_text = str(tool_response).lower()
|
| 423 |
+
if 'cryptocompare' in response_text:
|
| 424 |
+
suggested_tools.append('cryptocompare_data')
|
| 425 |
+
if 'coingecko' in response_text:
|
| 426 |
+
suggested_tools.append('coingecko_data')
|
| 427 |
+
if 'defillama' in response_text:
|
| 428 |
+
suggested_tools.append('defillama_data')
|
| 429 |
+
if 'etherscan' in response_text:
|
| 430 |
+
suggested_tools.append('etherscan_data')
|
| 431 |
+
if 'chart' in response_text or 'visualization' in response_text:
|
| 432 |
+
suggested_tools.append('chart_data_provider')
|
| 433 |
+
|
| 434 |
logger.info(f"🛠️ Gemini suggested tools: {suggested_tools}")
|
| 435 |
|
| 436 |
# Step 2: Execute tools (same logic as Ollama version)
|
src/tools/chart_data_tool.py
CHANGED
|
@@ -393,3 +393,8 @@ class ChartDataTool(BaseTool):
|
|
| 393 |
"1d": 1, "7d": 7, "30d": 30, "90d": 90, "365d": 365, "1y": 365
|
| 394 |
}
|
| 395 |
return timeframe_map.get(timeframe, 30)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
"1d": 1, "7d": 7, "30d": 30, "90d": 90, "365d": 365, "1y": 365
|
| 394 |
}
|
| 395 |
return timeframe_map.get(timeframe, 30)
|
| 396 |
+
|
| 397 |
+
async def cleanup(self):
    """Release any resources held by the tool (session-management hook).

    ChartDataTool doesn't maintain persistent connections, so there is
    nothing to release; this no-op exists so callers can invoke
    ``cleanup()`` uniformly across all tools.
    """
    return None
|
src/utils/ai_safety.py
CHANGED
|
@@ -130,6 +130,40 @@ class AISafetyGuard:
|
|
| 130 |
|
| 131 |
return cleaned, True, "Response is safe"
|
| 132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
def create_safe_prompt(self, user_query: str, tool_context: str) -> str:
|
| 134 |
"""Create a safety-enhanced prompt for Ollama"""
|
| 135 |
safety_instructions = """
|
|
|
|
| 130 |
|
| 131 |
return cleaned, True, "Response is safe"
|
| 132 |
|
| 133 |
+
def validate_gemini_response(self, response: str) -> Tuple[str, bool, str]:
    """
    Check a raw Gemini response for safety and sanitize it.

    Rejects empty responses and responses matching known unsafe-content
    patterns, strips <script> bodies and HTML-like tags, and caps the
    result at 10k characters.
    Returns: (cleaned_response, is_valid, reason)
    """
    # Whitespace-only or missing output is invalid outright.
    if not response or not response.strip():
        return "", False, "Empty response from Gemini"

    # Regexes describing content that must never be passed through.
    blocked_patterns = (
        r'(?i)here.*is.*how.*to.*hack',
        r'(?i)steps.*to.*exploit',
        r'(?i)bypass.*security.*by',
        r'(?i)manipulate.*market.*by',
    )
    for blocked in blocked_patterns:
        if re.search(blocked, response) is None:
            continue
        logger.warning(f"Blocked unsafe Gemini response: {blocked}")
        return "", False, "Response contains potentially unsafe content"

    # Trim whitespace, then drop <script>...</script> bodies and any
    # remaining HTML/markup tags.
    sanitized = response.strip()
    sanitized = re.sub(r'<script.*?</script>', '', sanitized,
                       flags=re.DOTALL | re.IGNORECASE)
    sanitized = re.sub(r'<[^>]+>', '', sanitized)

    # Keep the answer within a reasonable size.
    if len(sanitized) > 10000:  # 10k character limit
        sanitized = sanitized[:10000] + "\n\n[Response truncated for safety]"

    return sanitized, True, "Response is safe"
|
| 166 |
+
|
| 167 |
def create_safe_prompt(self, user_query: str, tool_context: str) -> str:
|
| 168 |
"""Create a safety-enhanced prompt for Ollama"""
|
| 169 |
safety_instructions = """
|