Spaces:
Sleeping
Sleeping
Commit
·
ef83e66
1
Parent(s):
67b7db4
Reasoning traces, smarter tools, deterministic backend tests.
Browse files- README.md +1 -1
- backend/api/models/agent.py +1 -0
- backend/api/models/redflag.py +2 -0
- backend/api/services/agent_orchestrator.py +178 -16
- backend/api/services/redflag_detector.py +53 -1
- backend/api/services/semantic_encoder.py +62 -0
- backend/api/services/tool_scoring.py +54 -0
- backend/api/services/tool_selector.py +9 -4
- backend/mcp_servers/database.py +10 -2
- backend/mcp_servers/rag_server.py +34 -1
- backend/tests/test_agent_orchestrator.py +13 -3
- frontend/.gitignore +41 -0
- frontend/README.md +36 -0
- frontend/app/favicon.ico +0 -0
- frontend/app/globals.css +26 -0
- frontend/app/layout.tsx +34 -0
- frontend/app/page.tsx +65 -0
- frontend/eslint.config.mjs +18 -0
- frontend/next.config.ts +7 -0
- frontend/package-lock.json +0 -0
- frontend/package.json +26 -0
- frontend/postcss.config.mjs +7 -0
- frontend/public/file.svg +1 -0
- frontend/public/globe.svg +1 -0
- frontend/public/next.svg +1 -0
- frontend/public/vercel.svg +1 -0
- frontend/public/window.svg +1 -0
- frontend/tsconfig.json +34 -0
- test_all.py +233 -0
README.md
CHANGED
|
@@ -357,7 +357,7 @@ Before you begin, ensure you have the following installed:
|
|
| 357 |
|
| 358 |
Create a `.env` file in the project root with the following:
|
| 359 |
```env
|
| 360 |
-
# Database
|
| 361 |
POSTGRESQL_URL=postgresql://user:password@host:port/database
|
| 362 |
SUPABASE_URL=https://your-project.supabase.co
|
| 363 |
SUPABASE_SERVICE_KEY=your_service_role_key
|
|
|
|
| 357 |
|
| 358 |
Create a `.env` file in the project root with the following:
|
| 359 |
```env
|
| 360 |
+
# Database Configurationa
|
| 361 |
POSTGRESQL_URL=postgresql://user:password@host:port/database
|
| 362 |
SUPABASE_URL=https://your-project.supabase.co
|
| 363 |
SUPABASE_SERVICE_KEY=your_service_role_key
|
backend/api/models/agent.py
CHANGED
|
@@ -21,4 +21,5 @@ class AgentResponse(BaseModel):
|
|
| 21 |
text: str
|
| 22 |
decision: AgentDecision
|
| 23 |
tool_traces: List[Dict[str, Any]] = []
|
|
|
|
| 24 |
|
|
|
|
| 21 |
text: str
|
| 22 |
decision: AgentDecision
|
| 23 |
tool_traces: List[Dict[str, Any]] = []
|
| 24 |
+
reasoning_trace: List[Dict[str, Any]] = []
|
| 25 |
|
backend/api/models/redflag.py
CHANGED
|
@@ -20,4 +20,6 @@ class RedFlagMatch:
|
|
| 20 |
severity: str
|
| 21 |
description: str
|
| 22 |
matched_text: str
|
|
|
|
|
|
|
| 23 |
|
|
|
|
| 20 |
severity: str
|
| 21 |
description: str
|
| 22 |
matched_text: str
|
| 23 |
+
confidence: float | None = None
|
| 24 |
+
explanation: str | None = None
|
| 25 |
|
backend/api/services/agent_orchestrator.py
CHANGED
|
@@ -9,6 +9,7 @@ Place at: backend/api/services/agent_orchestrator.py
|
|
| 9 |
|
| 10 |
from __future__ import annotations
|
| 11 |
|
|
|
|
| 12 |
import json
|
| 13 |
import os
|
| 14 |
from typing import List, Dict, Any, Optional
|
|
@@ -20,6 +21,7 @@ from .intent_classifier import IntentClassifier
|
|
| 20 |
from .tool_selector import ToolSelector
|
| 21 |
from .llm_client import LLMClient
|
| 22 |
from ..mcp_clients.mcp_client import MCPClient
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
class AgentOrchestrator:
|
|
@@ -37,10 +39,24 @@ class AgentOrchestrator:
|
|
| 37 |
|
| 38 |
self.intent = IntentClassifier(llm_client=self.llm)
|
| 39 |
self.selector = ToolSelector(llm_client=self.llm)
|
|
|
|
| 40 |
|
| 41 |
async def handle(self, req: AgentRequest) -> AgentResponse:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
# 1) Red-flag check (async)
|
| 43 |
matches: List[RedFlagMatch] = await self.redflag.check(req.tenant_id, req.message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
if matches:
|
| 46 |
# Notify admin asynchronously (do not await blocking the response path if you prefer)
|
|
@@ -59,11 +75,16 @@ class AgentOrchestrator:
|
|
| 59 |
return AgentResponse(
|
| 60 |
text="Your request has been blocked due to policy.",
|
| 61 |
decision=decision,
|
| 62 |
-
tool_traces=[{"redflags": [m.__dict__ for m in matches]}]
|
|
|
|
| 63 |
)
|
| 64 |
|
| 65 |
# 2) Intent classification
|
| 66 |
intent = await self.intent.classify(req.message)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
# 2.5) Pre-fetch RAG results if available (for tool selector context)
|
| 69 |
rag_prefetch = None
|
|
@@ -73,16 +94,38 @@ class AgentOrchestrator:
|
|
| 73 |
rag_prefetch = await self.mcp.call_rag(req.tenant_id, req.message)
|
| 74 |
if isinstance(rag_prefetch, dict):
|
| 75 |
rag_results = rag_prefetch.get("results") or rag_prefetch.get("hits") or []
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
# If RAG fails, continue without it
|
| 78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
# 3) Tool selection (hybrid) - pass RAG results in context
|
| 81 |
ctx = {
|
| 82 |
"tenant_id": req.tenant_id,
|
| 83 |
-
"rag_results": rag_results
|
|
|
|
| 84 |
}
|
| 85 |
decision = await self.selector.select(intent, req.message, ctx)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
tool_traces: List[Dict[str, Any]] = []
|
| 88 |
|
|
@@ -90,7 +133,14 @@ class AgentOrchestrator:
|
|
| 90 |
if decision.action == "multi_step" and decision.tool_input:
|
| 91 |
steps = decision.tool_input.get("steps", [])
|
| 92 |
if steps:
|
| 93 |
-
return await self._execute_multi_step(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
# 5) Execute single tool
|
| 96 |
if decision.action == "call_tool" and decision.tool:
|
|
@@ -98,25 +148,54 @@ class AgentOrchestrator:
|
|
| 98 |
if decision.tool == "rag":
|
| 99 |
rag_resp = await self.mcp.call_rag(req.tenant_id, decision.tool_input.get("query") if decision.tool_input else req.message)
|
| 100 |
tool_traces.append({"tool": "rag", "response": rag_resp})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
prompt = self._build_prompt_with_rag(req, rag_resp)
|
| 102 |
llm_out = await self.llm.simple_call(prompt, temperature=req.temperature)
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
if decision.tool == "web":
|
| 106 |
web_resp = await self.mcp.call_web(req.tenant_id, decision.tool_input.get("query") if decision.tool_input else req.message)
|
| 107 |
tool_traces.append({"tool": "web", "response": web_resp})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
prompt = self._build_prompt_with_web(req, web_resp)
|
| 109 |
llm_out = await self.llm.simple_call(prompt, temperature=req.temperature)
|
| 110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
if decision.tool == "admin":
|
| 113 |
admin_resp = await self.mcp.call_admin(req.tenant_id, decision.tool_input.get("query") if decision.tool_input else req.message)
|
| 114 |
tool_traces.append({"tool": "admin", "response": admin_resp})
|
| 115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
if decision.tool == "llm":
|
| 118 |
llm_out = await self.llm.simple_call(req.message, temperature=req.temperature)
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
except Exception as e:
|
| 122 |
tool_traces.append({"tool": decision.tool, "error": str(e)})
|
|
@@ -127,7 +206,12 @@ class AgentOrchestrator:
|
|
| 127 |
return AgentResponse(
|
| 128 |
text=fallback,
|
| 129 |
decision=AgentDecision(action="respond", tool=None, tool_input=None, reason=f"tool_error_fallback: {e}"),
|
| 130 |
-
tool_traces=tool_traces
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
)
|
| 132 |
|
| 133 |
# Default: direct LLM response
|
|
@@ -136,10 +220,16 @@ class AgentOrchestrator:
|
|
| 136 |
except Exception as e:
|
| 137 |
# If LLM fails, return a helpful error message
|
| 138 |
llm_out = f"I apologize, but I'm unable to process your request right now. The AI service is unavailable: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
return AgentResponse(
|
| 141 |
text=llm_out,
|
| 142 |
-
decision=AgentDecision(action="respond", tool=None, tool_input=None, reason="default_llm")
|
|
|
|
| 143 |
)
|
| 144 |
|
| 145 |
def _build_prompt_with_rag(self, req: AgentRequest, rag_resp: Dict[str, Any]) -> str:
|
|
@@ -160,6 +250,7 @@ class AgentOrchestrator:
|
|
| 160 |
|
| 161 |
async def _execute_multi_step(self, req: AgentRequest, steps: List[Dict[str, Any]],
|
| 162 |
decision: AgentDecision, tool_traces: List[Dict[str, Any]],
|
|
|
|
| 163 |
pre_fetched_rag: Optional[Dict[str, Any]] = None) -> AgentResponse:
|
| 164 |
"""
|
| 165 |
Execute multiple tools in sequence and synthesize results with LLM.
|
|
@@ -169,6 +260,14 @@ class AgentOrchestrator:
|
|
| 169 |
admin_data = None
|
| 170 |
collected_data = []
|
| 171 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
# Execute each step in sequence
|
| 173 |
for step_info in steps:
|
| 174 |
tool_name = step_info.get("tool")
|
|
@@ -178,13 +277,22 @@ class AgentOrchestrator:
|
|
| 178 |
try:
|
| 179 |
if tool_name == "rag":
|
| 180 |
# Reuse pre-fetched RAG if available, otherwise fetch
|
| 181 |
-
if pre_fetched_rag:
|
| 182 |
rag_resp = pre_fetched_rag
|
| 183 |
tool_traces.append({"tool": "rag", "response": rag_resp, "note": "used_pre_fetched"})
|
|
|
|
|
|
|
|
|
|
| 184 |
else:
|
| 185 |
rag_resp = await self.mcp.call_rag(req.tenant_id, query)
|
| 186 |
tool_traces.append({"tool": "rag", "response": rag_resp})
|
| 187 |
rag_data = rag_resp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
# Extract snippets for prompt
|
| 189 |
if isinstance(rag_resp, dict):
|
| 190 |
hits = rag_resp.get("results") or rag_resp.get("hits") or []
|
|
@@ -193,9 +301,19 @@ class AgentOrchestrator:
|
|
| 193 |
collected_data.append(f"[RAG] {txt}")
|
| 194 |
|
| 195 |
elif tool_name == "web":
|
| 196 |
-
|
| 197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
web_data = web_resp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
# Extract snippets for prompt
|
| 200 |
if isinstance(web_resp, dict):
|
| 201 |
hits = web_resp.get("results") or web_resp.get("items") or []
|
|
@@ -210,6 +328,11 @@ class AgentOrchestrator:
|
|
| 210 |
tool_traces.append({"tool": "admin", "response": admin_resp})
|
| 211 |
admin_data = admin_resp
|
| 212 |
collected_data.append(f"[ADMIN] {json.dumps(admin_resp)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
|
| 214 |
elif tool_name == "llm":
|
| 215 |
# LLM is always last - synthesize all collected data
|
|
@@ -218,6 +341,11 @@ class AgentOrchestrator:
|
|
| 218 |
except Exception as e:
|
| 219 |
tool_traces.append({"tool": tool_name, "error": str(e)})
|
| 220 |
# Continue with other tools even if one fails
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
|
| 222 |
# Build comprehensive prompt with all collected data
|
| 223 |
data_section = "\n---\n".join(collected_data) if collected_data else ""
|
|
@@ -241,7 +369,11 @@ class AgentOrchestrator:
|
|
| 241 |
return AgentResponse(
|
| 242 |
text=llm_out,
|
| 243 |
decision=decision,
|
| 244 |
-
tool_traces=tool_traces
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
)
|
| 246 |
except Exception as e:
|
| 247 |
tool_traces.append({"tool": "llm", "error": str(e)})
|
|
@@ -254,7 +386,12 @@ class AgentOrchestrator:
|
|
| 254 |
tool_input=None,
|
| 255 |
reason=f"multi_step_llm_error: {e}"
|
| 256 |
),
|
| 257 |
-
tool_traces=tool_traces
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
)
|
| 259 |
|
| 260 |
def _build_prompt_with_web(self, req: AgentRequest, web_resp: Dict[str, Any]) -> str:
|
|
@@ -273,3 +410,28 @@ class AgentOrchestrator:
|
|
| 273 |
f"User question: {req.message}\nAnswer succinctly and indicate which results you used."
|
| 274 |
)
|
| 275 |
return prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
from __future__ import annotations
|
| 11 |
|
| 12 |
+
import asyncio
|
| 13 |
import json
|
| 14 |
import os
|
| 15 |
from typing import List, Dict, Any, Optional
|
|
|
|
| 21 |
from .tool_selector import ToolSelector
|
| 22 |
from .llm_client import LLMClient
|
| 23 |
from ..mcp_clients.mcp_client import MCPClient
|
| 24 |
+
from .tool_scoring import ToolScoringService
|
| 25 |
|
| 26 |
|
| 27 |
class AgentOrchestrator:
|
|
|
|
| 39 |
|
| 40 |
self.intent = IntentClassifier(llm_client=self.llm)
|
| 41 |
self.selector = ToolSelector(llm_client=self.llm)
|
| 42 |
+
self.tool_scorer = ToolScoringService()
|
| 43 |
|
| 44 |
async def handle(self, req: AgentRequest) -> AgentResponse:
|
| 45 |
+
reasoning_trace: List[Dict[str, Any]] = []
|
| 46 |
+
reasoning_trace.append({
|
| 47 |
+
"step": "request_received",
|
| 48 |
+
"tenant_id": req.tenant_id,
|
| 49 |
+
"user_id": req.user_id,
|
| 50 |
+
"message_preview": req.message[:120]
|
| 51 |
+
})
|
| 52 |
+
|
| 53 |
# 1) Red-flag check (async)
|
| 54 |
matches: List[RedFlagMatch] = await self.redflag.check(req.tenant_id, req.message)
|
| 55 |
+
reasoning_trace.append({
|
| 56 |
+
"step": "redflag_check",
|
| 57 |
+
"match_count": len(matches),
|
| 58 |
+
"matches": [m.__dict__ for m in matches]
|
| 59 |
+
})
|
| 60 |
|
| 61 |
if matches:
|
| 62 |
# Notify admin asynchronously (do not await blocking the response path if you prefer)
|
|
|
|
| 75 |
return AgentResponse(
|
| 76 |
text="Your request has been blocked due to policy.",
|
| 77 |
decision=decision,
|
| 78 |
+
tool_traces=[{"redflags": [m.__dict__ for m in matches]}],
|
| 79 |
+
reasoning_trace=reasoning_trace
|
| 80 |
)
|
| 81 |
|
| 82 |
# 2) Intent classification
|
| 83 |
intent = await self.intent.classify(req.message)
|
| 84 |
+
reasoning_trace.append({
|
| 85 |
+
"step": "intent_detection",
|
| 86 |
+
"intent": intent
|
| 87 |
+
})
|
| 88 |
|
| 89 |
# 2.5) Pre-fetch RAG results if available (for tool selector context)
|
| 90 |
rag_prefetch = None
|
|
|
|
| 94 |
rag_prefetch = await self.mcp.call_rag(req.tenant_id, req.message)
|
| 95 |
if isinstance(rag_prefetch, dict):
|
| 96 |
rag_results = rag_prefetch.get("results") or rag_prefetch.get("hits") or []
|
| 97 |
+
reasoning_trace.append({
|
| 98 |
+
"step": "rag_prefetch",
|
| 99 |
+
"status": "ok",
|
| 100 |
+
"hit_count": len(rag_results)
|
| 101 |
+
})
|
| 102 |
+
except Exception as pref_err:
|
| 103 |
# If RAG fails, continue without it
|
| 104 |
+
reasoning_trace.append({
|
| 105 |
+
"step": "rag_prefetch",
|
| 106 |
+
"status": "error",
|
| 107 |
+
"error": str(pref_err)
|
| 108 |
+
})
|
| 109 |
+
rag_prefetch = None
|
| 110 |
+
|
| 111 |
+
tool_scores = self.tool_scorer.score(req.message, intent, rag_results)
|
| 112 |
+
reasoning_trace.append({
|
| 113 |
+
"step": "tool_scoring",
|
| 114 |
+
"scores": tool_scores
|
| 115 |
+
})
|
| 116 |
|
| 117 |
# 3) Tool selection (hybrid) - pass RAG results in context
|
| 118 |
ctx = {
|
| 119 |
"tenant_id": req.tenant_id,
|
| 120 |
+
"rag_results": rag_results,
|
| 121 |
+
"tool_scores": tool_scores
|
| 122 |
}
|
| 123 |
decision = await self.selector.select(intent, req.message, ctx)
|
| 124 |
+
reasoning_trace.append({
|
| 125 |
+
"step": "tool_selection",
|
| 126 |
+
"decision": decision.dict(),
|
| 127 |
+
"context_scores": tool_scores
|
| 128 |
+
})
|
| 129 |
|
| 130 |
tool_traces: List[Dict[str, Any]] = []
|
| 131 |
|
|
|
|
| 133 |
if decision.action == "multi_step" and decision.tool_input:
|
| 134 |
steps = decision.tool_input.get("steps", [])
|
| 135 |
if steps:
|
| 136 |
+
return await self._execute_multi_step(
|
| 137 |
+
req,
|
| 138 |
+
steps,
|
| 139 |
+
decision,
|
| 140 |
+
tool_traces,
|
| 141 |
+
reasoning_trace,
|
| 142 |
+
rag_prefetch
|
| 143 |
+
)
|
| 144 |
|
| 145 |
# 5) Execute single tool
|
| 146 |
if decision.action == "call_tool" and decision.tool:
|
|
|
|
| 148 |
if decision.tool == "rag":
|
| 149 |
rag_resp = await self.mcp.call_rag(req.tenant_id, decision.tool_input.get("query") if decision.tool_input else req.message)
|
| 150 |
tool_traces.append({"tool": "rag", "response": rag_resp})
|
| 151 |
+
reasoning_trace.append({
|
| 152 |
+
"step": "tool_execution",
|
| 153 |
+
"tool": "rag",
|
| 154 |
+
"hit_count": len(self._extract_hits(rag_resp)),
|
| 155 |
+
"summary": self._summarize_hits(rag_resp, limit=2)
|
| 156 |
+
})
|
| 157 |
prompt = self._build_prompt_with_rag(req, rag_resp)
|
| 158 |
llm_out = await self.llm.simple_call(prompt, temperature=req.temperature)
|
| 159 |
+
reasoning_trace.append({
|
| 160 |
+
"step": "llm_response",
|
| 161 |
+
"mode": "rag_synthesis"
|
| 162 |
+
})
|
| 163 |
+
return AgentResponse(text=llm_out, decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace)
|
| 164 |
|
| 165 |
if decision.tool == "web":
|
| 166 |
web_resp = await self.mcp.call_web(req.tenant_id, decision.tool_input.get("query") if decision.tool_input else req.message)
|
| 167 |
tool_traces.append({"tool": "web", "response": web_resp})
|
| 168 |
+
reasoning_trace.append({
|
| 169 |
+
"step": "tool_execution",
|
| 170 |
+
"tool": "web",
|
| 171 |
+
"hit_count": len(self._extract_hits(web_resp)),
|
| 172 |
+
"summary": self._summarize_hits(web_resp, limit=2)
|
| 173 |
+
})
|
| 174 |
prompt = self._build_prompt_with_web(req, web_resp)
|
| 175 |
llm_out = await self.llm.simple_call(prompt, temperature=req.temperature)
|
| 176 |
+
reasoning_trace.append({
|
| 177 |
+
"step": "llm_response",
|
| 178 |
+
"mode": "web_synthesis"
|
| 179 |
+
})
|
| 180 |
+
return AgentResponse(text=llm_out, decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace)
|
| 181 |
|
| 182 |
if decision.tool == "admin":
|
| 183 |
admin_resp = await self.mcp.call_admin(req.tenant_id, decision.tool_input.get("query") if decision.tool_input else req.message)
|
| 184 |
tool_traces.append({"tool": "admin", "response": admin_resp})
|
| 185 |
+
reasoning_trace.append({
|
| 186 |
+
"step": "tool_execution",
|
| 187 |
+
"tool": "admin",
|
| 188 |
+
"status": "completed"
|
| 189 |
+
})
|
| 190 |
+
return AgentResponse(text=json.dumps(admin_resp), decision=decision, tool_traces=tool_traces, reasoning_trace=reasoning_trace)
|
| 191 |
|
| 192 |
if decision.tool == "llm":
|
| 193 |
llm_out = await self.llm.simple_call(req.message, temperature=req.temperature)
|
| 194 |
+
reasoning_trace.append({
|
| 195 |
+
"step": "llm_response",
|
| 196 |
+
"mode": "direct"
|
| 197 |
+
})
|
| 198 |
+
return AgentResponse(text=llm_out, decision=decision, reasoning_trace=reasoning_trace)
|
| 199 |
|
| 200 |
except Exception as e:
|
| 201 |
tool_traces.append({"tool": decision.tool, "error": str(e)})
|
|
|
|
| 206 |
return AgentResponse(
|
| 207 |
text=fallback,
|
| 208 |
decision=AgentDecision(action="respond", tool=None, tool_input=None, reason=f"tool_error_fallback: {e}"),
|
| 209 |
+
tool_traces=tool_traces,
|
| 210 |
+
reasoning_trace=reasoning_trace + [{
|
| 211 |
+
"step": "error",
|
| 212 |
+
"tool": decision.tool,
|
| 213 |
+
"error": str(e)
|
| 214 |
+
}]
|
| 215 |
)
|
| 216 |
|
| 217 |
# Default: direct LLM response
|
|
|
|
| 220 |
except Exception as e:
|
| 221 |
# If LLM fails, return a helpful error message
|
| 222 |
llm_out = f"I apologize, but I'm unable to process your request right now. The AI service is unavailable: {str(e)}"
|
| 223 |
+
reasoning_trace.append({
|
| 224 |
+
"step": "error",
|
| 225 |
+
"tool": "llm",
|
| 226 |
+
"error": str(e)
|
| 227 |
+
})
|
| 228 |
|
| 229 |
return AgentResponse(
|
| 230 |
text=llm_out,
|
| 231 |
+
decision=AgentDecision(action="respond", tool=None, tool_input=None, reason="default_llm"),
|
| 232 |
+
reasoning_trace=reasoning_trace
|
| 233 |
)
|
| 234 |
|
| 235 |
def _build_prompt_with_rag(self, req: AgentRequest, rag_resp: Dict[str, Any]) -> str:
|
|
|
|
| 250 |
|
| 251 |
async def _execute_multi_step(self, req: AgentRequest, steps: List[Dict[str, Any]],
|
| 252 |
decision: AgentDecision, tool_traces: List[Dict[str, Any]],
|
| 253 |
+
reasoning_trace: List[Dict[str, Any]],
|
| 254 |
pre_fetched_rag: Optional[Dict[str, Any]] = None) -> AgentResponse:
|
| 255 |
"""
|
| 256 |
Execute multiple tools in sequence and synthesize results with LLM.
|
|
|
|
| 260 |
admin_data = None
|
| 261 |
collected_data = []
|
| 262 |
|
| 263 |
+
parallel_tasks = {}
|
| 264 |
+
rag_parallel_query = self._first_query_for_tool(steps, "rag", req.message)
|
| 265 |
+
web_parallel_query = self._first_query_for_tool(steps, "web", req.message)
|
| 266 |
+
if rag_parallel_query and web_parallel_query and rag_parallel_query == web_parallel_query:
|
| 267 |
+
if not pre_fetched_rag:
|
| 268 |
+
parallel_tasks["rag"] = asyncio.create_task(self.mcp.call_rag(req.tenant_id, rag_parallel_query))
|
| 269 |
+
parallel_tasks["web"] = asyncio.create_task(self.mcp.call_web(req.tenant_id, web_parallel_query))
|
| 270 |
+
|
| 271 |
# Execute each step in sequence
|
| 272 |
for step_info in steps:
|
| 273 |
tool_name = step_info.get("tool")
|
|
|
|
| 277 |
try:
|
| 278 |
if tool_name == "rag":
|
| 279 |
# Reuse pre-fetched RAG if available, otherwise fetch
|
| 280 |
+
if pre_fetched_rag and query == rag_parallel_query:
|
| 281 |
rag_resp = pre_fetched_rag
|
| 282 |
tool_traces.append({"tool": "rag", "response": rag_resp, "note": "used_pre_fetched"})
|
| 283 |
+
elif parallel_tasks.get("rag") and query == rag_parallel_query:
|
| 284 |
+
rag_resp = await parallel_tasks["rag"]
|
| 285 |
+
tool_traces.append({"tool": "rag", "response": rag_resp, "note": "parallel"})
|
| 286 |
else:
|
| 287 |
rag_resp = await self.mcp.call_rag(req.tenant_id, query)
|
| 288 |
tool_traces.append({"tool": "rag", "response": rag_resp})
|
| 289 |
rag_data = rag_resp
|
| 290 |
+
reasoning_trace.append({
|
| 291 |
+
"step": "tool_execution",
|
| 292 |
+
"tool": "rag",
|
| 293 |
+
"hit_count": len(self._extract_hits(rag_resp)),
|
| 294 |
+
"summary": self._summarize_hits(rag_resp, limit=2)
|
| 295 |
+
})
|
| 296 |
# Extract snippets for prompt
|
| 297 |
if isinstance(rag_resp, dict):
|
| 298 |
hits = rag_resp.get("results") or rag_resp.get("hits") or []
|
|
|
|
| 301 |
collected_data.append(f"[RAG] {txt}")
|
| 302 |
|
| 303 |
elif tool_name == "web":
|
| 304 |
+
if parallel_tasks.get("web") and query == web_parallel_query:
|
| 305 |
+
web_resp = await parallel_tasks["web"]
|
| 306 |
+
tool_traces.append({"tool": "web", "response": web_resp, "note": "parallel"})
|
| 307 |
+
else:
|
| 308 |
+
web_resp = await self.mcp.call_web(req.tenant_id, query)
|
| 309 |
+
tool_traces.append({"tool": "web", "response": web_resp})
|
| 310 |
web_data = web_resp
|
| 311 |
+
reasoning_trace.append({
|
| 312 |
+
"step": "tool_execution",
|
| 313 |
+
"tool": "web",
|
| 314 |
+
"hit_count": len(self._extract_hits(web_resp)),
|
| 315 |
+
"summary": self._summarize_hits(web_resp, limit=2)
|
| 316 |
+
})
|
| 317 |
# Extract snippets for prompt
|
| 318 |
if isinstance(web_resp, dict):
|
| 319 |
hits = web_resp.get("results") or web_resp.get("items") or []
|
|
|
|
| 328 |
tool_traces.append({"tool": "admin", "response": admin_resp})
|
| 329 |
admin_data = admin_resp
|
| 330 |
collected_data.append(f"[ADMIN] {json.dumps(admin_resp)}")
|
| 331 |
+
reasoning_trace.append({
|
| 332 |
+
"step": "tool_execution",
|
| 333 |
+
"tool": "admin",
|
| 334 |
+
"status": "completed"
|
| 335 |
+
})
|
| 336 |
|
| 337 |
elif tool_name == "llm":
|
| 338 |
# LLM is always last - synthesize all collected data
|
|
|
|
| 341 |
except Exception as e:
|
| 342 |
tool_traces.append({"tool": tool_name, "error": str(e)})
|
| 343 |
# Continue with other tools even if one fails
|
| 344 |
+
reasoning_trace.append({
|
| 345 |
+
"step": "error",
|
| 346 |
+
"tool": tool_name,
|
| 347 |
+
"error": str(e)
|
| 348 |
+
})
|
| 349 |
|
| 350 |
# Build comprehensive prompt with all collected data
|
| 351 |
data_section = "\n---\n".join(collected_data) if collected_data else ""
|
|
|
|
| 369 |
return AgentResponse(
|
| 370 |
text=llm_out,
|
| 371 |
decision=decision,
|
| 372 |
+
tool_traces=tool_traces,
|
| 373 |
+
reasoning_trace=reasoning_trace + [{
|
| 374 |
+
"step": "llm_response",
|
| 375 |
+
"mode": "multi_step"
|
| 376 |
+
}]
|
| 377 |
)
|
| 378 |
except Exception as e:
|
| 379 |
tool_traces.append({"tool": "llm", "error": str(e)})
|
|
|
|
| 386 |
tool_input=None,
|
| 387 |
reason=f"multi_step_llm_error: {e}"
|
| 388 |
),
|
| 389 |
+
tool_traces=tool_traces,
|
| 390 |
+
reasoning_trace=reasoning_trace + [{
|
| 391 |
+
"step": "error",
|
| 392 |
+
"tool": "llm",
|
| 393 |
+
"error": str(e)
|
| 394 |
+
}]
|
| 395 |
)
|
| 396 |
|
| 397 |
def _build_prompt_with_web(self, req: AgentRequest, web_resp: Dict[str, Any]) -> str:
|
|
|
|
| 410 |
f"User question: {req.message}\nAnswer succinctly and indicate which results you used."
|
| 411 |
)
|
| 412 |
return prompt
|
| 413 |
+
|
| 414 |
+
@staticmethod
|
| 415 |
+
def _extract_hits(resp: Optional[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 416 |
+
if not isinstance(resp, dict):
|
| 417 |
+
return []
|
| 418 |
+
return resp.get("results") or resp.get("hits") or resp.get("items") or []
|
| 419 |
+
|
| 420 |
+
def _summarize_hits(self, resp: Optional[Dict[str, Any]], limit: int = 3) -> List[str]:
|
| 421 |
+
hits = self._extract_hits(resp)
|
| 422 |
+
summaries = []
|
| 423 |
+
for hit in hits[:limit]:
|
| 424 |
+
if isinstance(hit, dict):
|
| 425 |
+
snippet = hit.get("text") or hit.get("content") or hit.get("snippet") or ""
|
| 426 |
+
else:
|
| 427 |
+
snippet = str(hit)
|
| 428 |
+
summaries.append(snippet[:160])
|
| 429 |
+
return summaries
|
| 430 |
+
|
| 431 |
+
@staticmethod
|
| 432 |
+
def _first_query_for_tool(steps: List[Dict[str, Any]], tool_name: str, default_query: str) -> Optional[str]:
|
| 433 |
+
for step in steps:
|
| 434 |
+
if step.get("tool") == tool_name:
|
| 435 |
+
input_data = step.get("input") or {}
|
| 436 |
+
return input_data.get("query") or default_query
|
| 437 |
+
return None
|
backend/api/services/redflag_detector.py
CHANGED
|
@@ -14,11 +14,12 @@ Enterprise RedFlagDetector
|
|
| 14 |
import os
|
| 15 |
import re
|
| 16 |
import time
|
| 17 |
-
from dataclasses import dataclass
|
| 18 |
from typing import List, Dict, Any, Optional
|
|
|
|
| 19 |
import httpx
|
| 20 |
|
| 21 |
from ..models.redflag import RedFlagRule, RedFlagMatch
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
class RedFlagDetector:
|
|
@@ -29,6 +30,7 @@ class RedFlagDetector:
|
|
| 29 |
self.admin_mcp_url = admin_mcp_url or os.getenv("ADMIN_MCP_URL")
|
| 30 |
self.cache_ttl = cache_ttl
|
| 31 |
self._rules_cache: Dict[str, Dict[str, Any]] = {} # tenant_id -> {"fetched_at":ts, "rules":[...]}
|
|
|
|
| 32 |
self._client = httpx.AsyncClient(timeout=15)
|
| 33 |
|
| 34 |
async def _fetch_rules_from_supabase(self, tenant_id: str) -> List[RedFlagRule]:
|
|
@@ -84,6 +86,17 @@ class RedFlagDetector:
|
|
| 84 |
|
| 85 |
rules = await self._fetch_rules_from_supabase(tenant_id)
|
| 86 |
self._rules_cache[tenant_id] = {"fetched_at": now, "rules": rules}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
return rules
|
| 88 |
|
| 89 |
async def check(self, tenant_id: str, text: str) -> List[RedFlagMatch]:
|
|
@@ -95,6 +108,7 @@ class RedFlagDetector:
|
|
| 95 |
matches: List[RedFlagMatch] = []
|
| 96 |
|
| 97 |
text_lower = text.lower()
|
|
|
|
| 98 |
|
| 99 |
for rule in rules:
|
| 100 |
if not rule.enabled:
|
|
@@ -102,12 +116,17 @@ class RedFlagDetector:
|
|
| 102 |
|
| 103 |
matched = False
|
| 104 |
matched_text = ""
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
# 1) Keyword quick-check (cheap)
|
| 107 |
for kw in (rule.keywords or []):
|
| 108 |
if kw and kw.lower() in text_lower:
|
| 109 |
matched = True
|
| 110 |
matched_text = kw
|
|
|
|
|
|
|
| 111 |
break
|
| 112 |
|
| 113 |
# 2) Regex check (more precise)
|
|
@@ -118,10 +137,15 @@ class RedFlagDetector:
|
|
| 118 |
if m:
|
| 119 |
matched = True
|
| 120 |
matched_text = m.group(0)
|
|
|
|
|
|
|
| 121 |
except re.error:
|
| 122 |
# invalid regex; skip this rule
|
| 123 |
continue
|
| 124 |
|
|
|
|
|
|
|
|
|
|
| 125 |
if matched:
|
| 126 |
matches.append(
|
| 127 |
RedFlagMatch(
|
|
@@ -130,6 +154,20 @@ class RedFlagDetector:
|
|
| 130 |
severity=rule.severity,
|
| 131 |
description=rule.description,
|
| 132 |
matched_text=matched_text,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
)
|
| 134 |
)
|
| 135 |
|
|
@@ -161,3 +199,17 @@ class RedFlagDetector:
|
|
| 161 |
|
| 162 |
async def close(self):
|
| 163 |
await self._client.aclose()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
import os
|
| 15 |
import re
|
| 16 |
import time
|
|
|
|
| 17 |
from typing import List, Dict, Any, Optional
|
| 18 |
+
|
| 19 |
import httpx
|
| 20 |
|
| 21 |
from ..models.redflag import RedFlagRule, RedFlagMatch
|
| 22 |
+
from .semantic_encoder import embed_text, cosine_similarity
|
| 23 |
|
| 24 |
|
| 25 |
class RedFlagDetector:
|
|
|
|
| 30 |
self.admin_mcp_url = admin_mcp_url or os.getenv("ADMIN_MCP_URL")
|
| 31 |
self.cache_ttl = cache_ttl
|
| 32 |
self._rules_cache: Dict[str, Dict[str, Any]] = {} # tenant_id -> {"fetched_at":ts, "rules":[...]}
|
| 33 |
+
self._rule_embeddings: Dict[str, Dict[str, List[float]]] = {}
|
| 34 |
self._client = httpx.AsyncClient(timeout=15)
|
| 35 |
|
| 36 |
async def _fetch_rules_from_supabase(self, tenant_id: str) -> List[RedFlagRule]:
|
|
|
|
| 86 |
|
| 87 |
rules = await self._fetch_rules_from_supabase(tenant_id)
|
| 88 |
self._rules_cache[tenant_id] = {"fetched_at": now, "rules": rules}
|
| 89 |
+
# Pre-compute embeddings for semantic scoring
|
| 90 |
+
embed_map: Dict[str, List[float]] = {}
|
| 91 |
+
for rule in rules:
|
| 92 |
+
try:
|
| 93 |
+
text_for_embedding = " ".join(
|
| 94 |
+
[piece for piece in [rule.description, rule.pattern] if piece]
|
| 95 |
+
).strip() or rule.id
|
| 96 |
+
embed_map[rule.id] = embed_text(text_for_embedding)
|
| 97 |
+
except Exception:
|
| 98 |
+
embed_map[rule.id] = []
|
| 99 |
+
self._rule_embeddings[tenant_id] = embed_map
|
| 100 |
return rules
|
| 101 |
|
| 102 |
async def check(self, tenant_id: str, text: str) -> List[RedFlagMatch]:
|
|
|
|
| 108 |
matches: List[RedFlagMatch] = []
|
| 109 |
|
| 110 |
text_lower = text.lower()
|
| 111 |
+
text_vector = embed_text(text)
|
| 112 |
|
| 113 |
for rule in rules:
|
| 114 |
if not rule.enabled:
|
|
|
|
| 116 |
|
| 117 |
matched = False
|
| 118 |
matched_text = ""
|
| 119 |
+
match_source = ""
|
| 120 |
+
keyword_score = 0.0
|
| 121 |
+
regex_score = 0.0
|
| 122 |
|
| 123 |
# 1) Keyword quick-check (cheap)
|
| 124 |
for kw in (rule.keywords or []):
|
| 125 |
if kw and kw.lower() in text_lower:
|
| 126 |
matched = True
|
| 127 |
matched_text = kw
|
| 128 |
+
keyword_score = 0.8
|
| 129 |
+
match_source = "keyword"
|
| 130 |
break
|
| 131 |
|
| 132 |
# 2) Regex check (more precise)
|
|
|
|
| 137 |
if m:
|
| 138 |
matched = True
|
| 139 |
matched_text = m.group(0)
|
| 140 |
+
regex_score = 1.0
|
| 141 |
+
match_source = "regex"
|
| 142 |
except re.error:
|
| 143 |
# invalid regex; skip this rule
|
| 144 |
continue
|
| 145 |
|
| 146 |
+
semantic_score = self._semantic_score(tenant_id, rule.id, text_vector)
|
| 147 |
+
confidence = max(semantic_score, keyword_score, regex_score)
|
| 148 |
+
|
| 149 |
if matched:
|
| 150 |
matches.append(
|
| 151 |
RedFlagMatch(
|
|
|
|
| 154 |
severity=rule.severity,
|
| 155 |
description=rule.description,
|
| 156 |
matched_text=matched_text,
|
| 157 |
+
confidence=round(confidence, 2),
|
| 158 |
+
explanation=self._build_explanation(rule, match_source, matched_text, confidence),
|
| 159 |
+
)
|
| 160 |
+
)
|
| 161 |
+
elif semantic_score >= 0.82:
|
| 162 |
+
matches.append(
|
| 163 |
+
RedFlagMatch(
|
| 164 |
+
rule_id=rule.id,
|
| 165 |
+
pattern=rule.pattern,
|
| 166 |
+
severity=rule.severity,
|
| 167 |
+
description=rule.description,
|
| 168 |
+
matched_text=matched_text or "",
|
| 169 |
+
confidence=round(semantic_score, 2),
|
| 170 |
+
explanation=self._build_explanation(rule, "semantic", matched_text, semantic_score),
|
| 171 |
)
|
| 172 |
)
|
| 173 |
|
|
|
|
| 199 |
|
| 200 |
async def close(self):
|
| 201 |
await self._client.aclose()
|
| 202 |
+
|
| 203 |
+
def _semantic_score(self, tenant_id: str, rule_id: str, text_vector: List[float]) -> float:
|
| 204 |
+
rule_vectors = self._rule_embeddings.get(tenant_id, {})
|
| 205 |
+
rule_vector = rule_vectors.get(rule_id)
|
| 206 |
+
if not rule_vector:
|
| 207 |
+
return 0.0
|
| 208 |
+
return cosine_similarity(rule_vector, text_vector)
|
| 209 |
+
|
| 210 |
+
@staticmethod
|
| 211 |
+
def _build_explanation(rule: RedFlagRule, source: str, matched_text: str, confidence: float) -> str:
|
| 212 |
+
base = f"Matched rule '{rule.description or rule.id}' via {source or 'heuristics'}"
|
| 213 |
+
if matched_text:
|
| 214 |
+
base += f" on span \"{matched_text}\""
|
| 215 |
+
return f"{base}. confidence={round(confidence, 2)}"
|
backend/api/services/semantic_encoder.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Shared semantic encoding utilities for backend services.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
from functools import lru_cache
|
| 8 |
+
from typing import Iterable, List, Optional
|
| 9 |
+
import hashlib
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
from sentence_transformers import SentenceTransformer
|
| 15 |
+
except ImportError: # pragma: no cover - optional dependency
|
| 16 |
+
SentenceTransformer = None
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@lru_cache(maxsize=1)
def _get_model() -> Optional[SentenceTransformer]:
    """Build (at most once per process) the shared MiniLM sentence encoder.

    Returns None when the optional sentence-transformers dependency is not
    installed, signalling callers to use the hashing fallback instead.
    """
    if SentenceTransformer is None:
        return None
    return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def embed_text(text: str) -> List[float]:
    """Encode *text* into a float vector using MiniLM when available.

    Falls back to the deterministic hashing embedding when the optional
    sentence-transformers dependency is not installed.
    """
    normalized = text or ""
    encoder = _get_model()
    if encoder is None:
        return _fallback_embed(normalized)
    return encoder.encode(normalized).tolist()
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def cosine_similarity(vec_a: Iterable[float], vec_b: Iterable[float]) -> float:
    """Cosine similarity of two vectors; 0.0 when either has zero norm."""
    left = np.asarray(list(vec_a), dtype=float)
    right = np.asarray(list(vec_b), dtype=float)
    norm_product = np.linalg.norm(left) * np.linalg.norm(right)
    if norm_product == 0:
        # Covers empty vectors and all-zero vectors alike.
        return 0.0
    return float(left.dot(right) / norm_product)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _fallback_embed(text: str, dim: int = 64) -> List[float]:
|
| 52 |
+
"""
|
| 53 |
+
Deterministic hashing-based embedding used when sentence-transformers
|
| 54 |
+
is not available (e.g., during slim CI environments).
|
| 55 |
+
"""
|
| 56 |
+
vector = [0.0] * dim
|
| 57 |
+
for token in text.lower().split():
|
| 58 |
+
digest = hashlib.sha256(token.encode("utf-8")).hexdigest()
|
| 59 |
+
idx = int(digest, 16) % dim
|
| 60 |
+
vector[idx] += 1.0
|
| 61 |
+
return vector
|
| 62 |
+
|
backend/api/services/tool_scoring.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from typing import Dict, List
|
| 5 |
+
|
| 6 |
+
from .semantic_encoder import embed_text, cosine_similarity
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _normalize(score: float) -> float:
|
| 10 |
+
return max(0.0, min(1.0, score))
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
class ToolScoringService:
    """
    Heuristic + semantic tool fitness scoring.

    Blends cosine similarity against per-domain anchor prompts with intent
    flags and retrieval/freshness signals, producing one fitness score per
    tool in [0, 1].
    """

    # Anchor prompts describing each tool's "home" domain.
    _domain_prompts: Dict[str, str] = field(default_factory=lambda: {
        "rag": "internal company policy, handbook, corporate procedure, proprietary",
        "web": "latest external news, public web search, trending topics, live data",
        "llm": "casual chit chat, brainstorming, creative writing, general knowledge"
    })
    # Pre-computed embeddings of the anchor prompts (filled in __post_init__).
    _domain_vectors: Dict[str, List[float]] = field(init=False)

    def __post_init__(self):
        # Encode the anchor prompts once so per-message scoring stays cheap.
        self._domain_vectors = {}
        for domain, prompt in self._domain_prompts.items():
            self._domain_vectors[domain] = embed_text(prompt)

    def score(self, message: str, intent: str, rag_results: List[Dict]) -> Dict[str, float]:
        """Return rag/web/llm fitness scores for *message* given *intent*."""
        message_vec = embed_text(message)
        sims = {
            domain: cosine_similarity(message_vec, anchor)
            for domain, anchor in self._domain_vectors.items()
        }

        has_rag_hits = 1 if rag_results else 0
        rag_signal = 0.4 * sims["rag"] + 0.4 * has_rag_hits + 0.2 * (1 if intent == "rag" else 0)
        web_signal = 0.5 * sims["web"] + 0.3 * (1 if intent == "web" else 0) + 0.2 * self._freshness_signal(message)
        llm_signal = 0.6 * sims["llm"] + 0.4 * (1 if intent == "general" else 0)

        return {
            "rag_fitness": round(_normalize(rag_signal), 3),
            "web_fitness": round(_normalize(web_signal), 3),
            "llm_only": round(_normalize(llm_signal), 3)
        }

    @staticmethod
    def _freshness_signal(message: str) -> float:
        """Fraction (capped at 1.0) of recency keywords found in *message*."""
        recency_terms = ("news", "today", "latest", "current", "breaking", "update", "recent", "now")
        lowered = message.lower()
        found = sum(1 for term in recency_terms if term in lowered)
        return min(1.0, found / 3.0)
|
| 54 |
+
|
backend/api/services/tool_selector.py
CHANGED
|
@@ -10,6 +10,10 @@ class ToolSelector:
|
|
| 10 |
|
| 11 |
async def select(self, intent: str, text: str, ctx):
|
| 12 |
msg = text.lower().strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
# ---------------------------------
|
| 15 |
# 1. Detect ADMIN RULES FIRST
|
|
@@ -35,9 +39,9 @@ class ToolSelector:
|
|
| 35 |
r"company", r"internal", r"documentation", r"our ", r"your ",
|
| 36 |
r"knowledge base", r"private", r"internal docs", r"corporate"
|
| 37 |
]
|
| 38 |
-
if rag_has_data or any(re.search(p, msg) for p in rag_patterns):
|
| 39 |
needs_rag = True
|
| 40 |
-
if
|
| 41 |
steps.append(step("rag", {"query": text}))
|
| 42 |
|
| 43 |
# ---------------------------------
|
|
@@ -48,7 +52,7 @@ class ToolSelector:
|
|
| 48 |
r"tell me about ", r"define ", r"explain ",
|
| 49 |
r"history of ", r"information about", r"details about"
|
| 50 |
]
|
| 51 |
-
if any(re.search(p, msg) for p in fact_patterns):
|
| 52 |
needs_web = True
|
| 53 |
steps.append(step("web", {"query": text}))
|
| 54 |
|
|
@@ -88,6 +92,7 @@ TOOLS:
|
|
| 88 |
Current context:
|
| 89 |
- RAG available: {rag_has_data}
|
| 90 |
- User message: "{text}"
|
|
|
|
| 91 |
|
| 92 |
Determine which tools are needed. You can select:
|
| 93 |
- Just LLM (simple questions)
|
|
@@ -140,7 +145,7 @@ Only return the JSON array. Do not include markdown formatting.
|
|
| 140 |
|
| 141 |
# Build reason string showing the tool sequence
|
| 142 |
tool_names = [s["tool"] for s in steps]
|
| 143 |
-
reason = f"multi-tool plan: {' → '.join(tool_names)}"
|
| 144 |
|
| 145 |
return _multi_step(steps, reason)
|
| 146 |
|
|
|
|
| 10 |
|
| 11 |
async def select(self, intent: str, text: str, ctx):
|
| 12 |
msg = text.lower().strip()
|
| 13 |
+
tool_scores = ctx.get("tool_scores", {})
|
| 14 |
+
rag_score = tool_scores.get("rag_fitness", 0.0)
|
| 15 |
+
web_score = tool_scores.get("web_fitness", 0.0)
|
| 16 |
+
llm_score = tool_scores.get("llm_only", 0.0)
|
| 17 |
|
| 18 |
# ---------------------------------
|
| 19 |
# 1. Detect ADMIN RULES FIRST
|
|
|
|
| 39 |
r"company", r"internal", r"documentation", r"our ", r"your ",
|
| 40 |
r"knowledge base", r"private", r"internal docs", r"corporate"
|
| 41 |
]
|
| 42 |
+
if rag_has_data or rag_score >= 0.55 or any(re.search(p, msg) for p in rag_patterns):
|
| 43 |
needs_rag = True
|
| 44 |
+
if not any(s["tool"] == "rag" for s in steps):
|
| 45 |
steps.append(step("rag", {"query": text}))
|
| 46 |
|
| 47 |
# ---------------------------------
|
|
|
|
| 52 |
r"tell me about ", r"define ", r"explain ",
|
| 53 |
r"history of ", r"information about", r"details about"
|
| 54 |
]
|
| 55 |
+
if web_score >= 0.55 or any(re.search(p, msg) for p in fact_patterns):
|
| 56 |
needs_web = True
|
| 57 |
steps.append(step("web", {"query": text}))
|
| 58 |
|
|
|
|
| 92 |
Current context:
|
| 93 |
- RAG available: {rag_has_data}
|
| 94 |
- User message: "{text}"
|
| 95 |
+
- Tool scores: {json.dumps(tool_scores)}
|
| 96 |
|
| 97 |
Determine which tools are needed. You can select:
|
| 98 |
- Just LLM (simple questions)
|
|
|
|
| 145 |
|
| 146 |
# Build reason string showing the tool sequence
|
| 147 |
tool_names = [s["tool"] for s in steps]
|
| 148 |
+
reason = f"multi-tool plan: {' → '.join(tool_names)} | scores={tool_scores}"
|
| 149 |
|
| 150 |
return _multi_step(steps, reason)
|
| 151 |
|
backend/mcp_servers/database.py
CHANGED
|
@@ -132,7 +132,7 @@ def insert_document_chunks(tenant_id: str, text: str, embedding: list):
|
|
| 132 |
raise
|
| 133 |
|
| 134 |
|
| 135 |
-
def search_vectors(tenant_id: str, vector: list, limit: int = 5) -> List[str]:
|
| 136 |
"""
|
| 137 |
Perform semantic vector search using pgvector.
|
| 138 |
"""
|
|
@@ -158,7 +158,15 @@ def search_vectors(tenant_id: str, vector: list, limit: int = 5) -> List[str]:
|
|
| 158 |
cur.close()
|
| 159 |
conn.close()
|
| 160 |
|
| 161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
|
| 163 |
except Exception as e:
|
| 164 |
print("DB SEARCH ERROR:", e)
|
|
|
|
| 132 |
raise
|
| 133 |
|
| 134 |
|
| 135 |
+
def search_vectors(tenant_id: str, vector: list, limit: int = 5) -> List[Dict[str, Any]]:
|
| 136 |
"""
|
| 137 |
Perform semantic vector search using pgvector.
|
| 138 |
"""
|
|
|
|
| 158 |
cur.close()
|
| 159 |
conn.close()
|
| 160 |
|
| 161 |
+
results: List[Dict[str, Any]] = []
|
| 162 |
+
for row in rows:
|
| 163 |
+
results.append(
|
| 164 |
+
{
|
| 165 |
+
"text": row["chunk_text"],
|
| 166 |
+
"similarity": float(row.get("similarity", 0.0)),
|
| 167 |
+
}
|
| 168 |
+
)
|
| 169 |
+
return results
|
| 170 |
|
| 171 |
except Exception as e:
|
| 172 |
print("DB SEARCH ERROR:", e)
|
backend/mcp_servers/rag_server.py
CHANGED
|
@@ -11,6 +11,8 @@ import os
|
|
| 11 |
current_dir = os.path.dirname(__file__)
|
| 12 |
sys.path.insert(0, current_dir)
|
| 13 |
|
|
|
|
|
|
|
| 14 |
from embeddings import embed_text
|
| 15 |
from database import insert_document_chunks, search_vectors
|
| 16 |
from models.rag import IngestRequest, SearchRequest
|
|
@@ -47,11 +49,42 @@ async def ingest(req: IngestRequest):
|
|
| 47 |
return {"status": "ok"}
|
| 48 |
|
| 49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
@rag_app.post("/search")
|
| 51 |
async def search(req: SearchRequest):
|
| 52 |
vector = embed_text(req.query)
|
| 53 |
results = db_search(req.tenant_id, vector)
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
|
| 57 |
if __name__ == "__main__":
|
|
|
|
| 11 |
current_dir = os.path.dirname(__file__)
|
| 12 |
sys.path.insert(0, current_dir)
|
| 13 |
|
| 14 |
+
from typing import Any, Dict, List
|
| 15 |
+
|
| 16 |
from embeddings import embed_text
|
| 17 |
from database import insert_document_chunks, search_vectors
|
| 18 |
from models.rag import IngestRequest, SearchRequest
|
|
|
|
| 49 |
return {"status": "ok"}
|
| 50 |
|
| 51 |
|
| 52 |
+
def cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float:
    """Pure-python cosine similarity; returns 0.0 for empty or zero-norm input."""
    import math

    if not vec_a or not vec_b:
        return 0.0
    dot_product = sum(x * y for x, y in zip(vec_a, vec_b))
    magnitude = math.sqrt(sum(x * x for x in vec_a)) * math.sqrt(sum(y * y for y in vec_b))
    return dot_product / magnitude if magnitude else 0.0
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def rank_chunks(chunks: List[Dict[str, Any]], query_embedding: List[float]):
    """Annotate each chunk with a 'relevance' score and return them best-first.

    NOTE(review): mutates the chunk dicts in place by adding the
    'relevance' key before sorting by descending relevance.
    """
    scored = []
    for entry in chunks:
        chunk_embedding = embed_text(entry.get("text", ""))
        entry["relevance"] = cosine_similarity(chunk_embedding, query_embedding)
        scored.append(entry)
    return sorted(scored, key=lambda item: item["relevance"], reverse=True)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
@rag_app.post("/search")
|
| 75 |
async def search(req: SearchRequest):
|
| 76 |
vector = embed_text(req.query)
|
| 77 |
results = db_search(req.tenant_id, vector)
|
| 78 |
+
ranked = rank_chunks(results, vector)
|
| 79 |
+
filtered = [chunk for chunk in ranked if chunk["relevance"] >= 0.55][:3]
|
| 80 |
+
return {
|
| 81 |
+
"results": filtered,
|
| 82 |
+
"metadata": {
|
| 83 |
+
"total_retrieved": len(results),
|
| 84 |
+
"returned": len(filtered),
|
| 85 |
+
"threshold": 0.55
|
| 86 |
+
}
|
| 87 |
+
}
|
| 88 |
|
| 89 |
|
| 90 |
if __name__ == "__main__":
|
backend/tests/test_agent_orchestrator.py
CHANGED
|
@@ -63,6 +63,11 @@ class FakeMCP:
|
|
| 63 |
return {"action": "allow"}
|
| 64 |
|
| 65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
# ---------------------------
|
| 67 |
# Patch orchestrator to use fake MCP + fake redflag
|
| 68 |
# ---------------------------
|
|
@@ -135,6 +140,7 @@ async def test_block_on_redflag(orchestrator):
|
|
| 135 |
assert resp.decision.action == "block"
|
| 136 |
assert resp.decision.tool == "admin"
|
| 137 |
assert "salary" in resp.tool_traces[0]["redflags"][0]["matched_text"]
|
|
|
|
| 138 |
|
| 139 |
|
| 140 |
@pytest.mark.asyncio
|
|
@@ -158,9 +164,10 @@ async def test_rag_tool_path(orchestrator, monkeypatch):
|
|
| 158 |
|
| 159 |
resp = await orchestrator.handle(req)
|
| 160 |
|
| 161 |
-
assert resp.decision.
|
| 162 |
-
assert "
|
| 163 |
assert resp.text == "MOCK_ANSWER"
|
|
|
|
| 164 |
|
| 165 |
|
| 166 |
@pytest.mark.asyncio
|
|
@@ -184,8 +191,10 @@ async def test_web_tool_path(orchestrator, monkeypatch):
|
|
| 184 |
|
| 185 |
resp = await orchestrator.handle(req)
|
| 186 |
|
| 187 |
-
assert resp.decision.
|
|
|
|
| 188 |
assert resp.text == "MOCK_ANSWER"
|
|
|
|
| 189 |
|
| 190 |
|
| 191 |
@pytest.mark.asyncio
|
|
@@ -218,3 +227,4 @@ async def test_default_llm_path(orchestrator, monkeypatch):
|
|
| 218 |
assert resp.decision.action == "respond"
|
| 219 |
assert resp.decision.tool is None
|
| 220 |
assert resp.text == "MOCK_ANSWER"
|
|
|
|
|
|
| 63 |
return {"action": "allow"}
|
| 64 |
|
| 65 |
|
| 66 |
+
def assert_trace_has_step(resp, step_name):
|
| 67 |
+
assert resp.reasoning_trace, "reasoning trace missing"
|
| 68 |
+
assert any(entry.get("step") == step_name for entry in resp.reasoning_trace), f"{step_name} missing"
|
| 69 |
+
|
| 70 |
+
|
| 71 |
# ---------------------------
|
| 72 |
# Patch orchestrator to use fake MCP + fake redflag
|
| 73 |
# ---------------------------
|
|
|
|
| 140 |
assert resp.decision.action == "block"
|
| 141 |
assert resp.decision.tool == "admin"
|
| 142 |
assert "salary" in resp.tool_traces[0]["redflags"][0]["matched_text"]
|
| 143 |
+
assert_trace_has_step(resp, "redflag_check")
|
| 144 |
|
| 145 |
|
| 146 |
@pytest.mark.asyncio
|
|
|
|
| 164 |
|
| 165 |
resp = await orchestrator.handle(req)
|
| 166 |
|
| 167 |
+
assert resp.decision.action == "multi_step"
|
| 168 |
+
assert any(trace["tool"] == "rag" for trace in resp.tool_traces if trace.get("tool") == "rag")
|
| 169 |
assert resp.text == "MOCK_ANSWER"
|
| 170 |
+
assert_trace_has_step(resp, "tool_selection")
|
| 171 |
|
| 172 |
|
| 173 |
@pytest.mark.asyncio
|
|
|
|
| 191 |
|
| 192 |
resp = await orchestrator.handle(req)
|
| 193 |
|
| 194 |
+
assert resp.decision.action == "multi_step"
|
| 195 |
+
assert any(trace["tool"] == "web" for trace in resp.tool_traces if trace.get("tool") == "web")
|
| 196 |
assert resp.text == "MOCK_ANSWER"
|
| 197 |
+
assert_trace_has_step(resp, "tool_selection")
|
| 198 |
|
| 199 |
|
| 200 |
@pytest.mark.asyncio
|
|
|
|
| 227 |
assert resp.decision.action == "respond"
|
| 228 |
assert resp.decision.tool is None
|
| 229 |
assert resp.text == "MOCK_ANSWER"
|
| 230 |
+
assert_trace_has_step(resp, "intent_detection")
|
frontend/.gitignore
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
|
| 2 |
+
|
| 3 |
+
# dependencies
|
| 4 |
+
/node_modules
|
| 5 |
+
/.pnp
|
| 6 |
+
.pnp.*
|
| 7 |
+
.yarn/*
|
| 8 |
+
!.yarn/patches
|
| 9 |
+
!.yarn/plugins
|
| 10 |
+
!.yarn/releases
|
| 11 |
+
!.yarn/versions
|
| 12 |
+
|
| 13 |
+
# testing
|
| 14 |
+
/coverage
|
| 15 |
+
|
| 16 |
+
# next.js
|
| 17 |
+
/.next/
|
| 18 |
+
/out/
|
| 19 |
+
|
| 20 |
+
# production
|
| 21 |
+
/build
|
| 22 |
+
|
| 23 |
+
# misc
|
| 24 |
+
.DS_Store
|
| 25 |
+
*.pem
|
| 26 |
+
|
| 27 |
+
# debug
|
| 28 |
+
npm-debug.log*
|
| 29 |
+
yarn-debug.log*
|
| 30 |
+
yarn-error.log*
|
| 31 |
+
.pnpm-debug.log*
|
| 32 |
+
|
| 33 |
+
# env files (can opt-in for committing if needed)
|
| 34 |
+
.env*
|
| 35 |
+
|
| 36 |
+
# vercel
|
| 37 |
+
.vercel
|
| 38 |
+
|
| 39 |
+
# typescript
|
| 40 |
+
*.tsbuildinfo
|
| 41 |
+
next-env.d.ts
|
frontend/README.md
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
This is a [Next.js](https://nextjs.org) project bootstrapped with [`create-next-app`](https://nextjs.org/docs/app/api-reference/cli/create-next-app).
|
| 2 |
+
|
| 3 |
+
## Getting Started
|
| 4 |
+
|
| 5 |
+
First, run the development server:
|
| 6 |
+
|
| 7 |
+
```bash
|
| 8 |
+
npm run dev
|
| 9 |
+
# or
|
| 10 |
+
yarn dev
|
| 11 |
+
# or
|
| 12 |
+
pnpm dev
|
| 13 |
+
# or
|
| 14 |
+
bun dev
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
Open [http://localhost:3000](http://localhost:3000) with your browser to see the result.
|
| 18 |
+
|
| 19 |
+
You can start editing the page by modifying `app/page.tsx`. The page auto-updates as you edit the file.
|
| 20 |
+
|
| 21 |
+
This project uses [`next/font`](https://nextjs.org/docs/app/building-your-application/optimizing/fonts) to automatically optimize and load [Geist](https://vercel.com/font), a new font family for Vercel.
|
| 22 |
+
|
| 23 |
+
## Learn More
|
| 24 |
+
|
| 25 |
+
To learn more about Next.js, take a look at the following resources:
|
| 26 |
+
|
| 27 |
+
- [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js features and API.
|
| 28 |
+
- [Learn Next.js](https://nextjs.org/learn) - an interactive Next.js tutorial.
|
| 29 |
+
|
| 30 |
+
You can check out [the Next.js GitHub repository](https://github.com/vercel/next.js) - your feedback and contributions are welcome!
|
| 31 |
+
|
| 32 |
+
## Deploy on Vercel
|
| 33 |
+
|
| 34 |
+
The easiest way to deploy your Next.js app is to use the [Vercel Platform](https://vercel.com/new?utm_medium=default-template&filter=next.js&utm_source=create-next-app&utm_campaign=create-next-app-readme) from the creators of Next.js.
|
| 35 |
+
|
| 36 |
+
Check out our [Next.js deployment documentation](https://nextjs.org/docs/app/building-your-application/deploying) for more details.
|
frontend/app/favicon.ico
ADDED
|
|
frontend/app/globals.css
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@import "tailwindcss";
|
| 2 |
+
|
| 3 |
+
:root {
|
| 4 |
+
--background: #ffffff;
|
| 5 |
+
--foreground: #171717;
|
| 6 |
+
}
|
| 7 |
+
|
| 8 |
+
@theme inline {
|
| 9 |
+
--color-background: var(--background);
|
| 10 |
+
--color-foreground: var(--foreground);
|
| 11 |
+
--font-sans: var(--font-geist-sans);
|
| 12 |
+
--font-mono: var(--font-geist-mono);
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
@media (prefers-color-scheme: dark) {
|
| 16 |
+
:root {
|
| 17 |
+
--background: #0a0a0a;
|
| 18 |
+
--foreground: #ededed;
|
| 19 |
+
}
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
body {
|
| 23 |
+
background: var(--background);
|
| 24 |
+
color: var(--foreground);
|
| 25 |
+
font-family: Arial, Helvetica, sans-serif;
|
| 26 |
+
}
|
frontend/app/layout.tsx
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import type { Metadata } from "next";
|
| 2 |
+
import { Geist, Geist_Mono } from "next/font/google";
|
| 3 |
+
import "./globals.css";
|
| 4 |
+
|
| 5 |
+
const geistSans = Geist({
|
| 6 |
+
variable: "--font-geist-sans",
|
| 7 |
+
subsets: ["latin"],
|
| 8 |
+
});
|
| 9 |
+
|
| 10 |
+
const geistMono = Geist_Mono({
|
| 11 |
+
variable: "--font-geist-mono",
|
| 12 |
+
subsets: ["latin"],
|
| 13 |
+
});
|
| 14 |
+
|
| 15 |
+
export const metadata: Metadata = {
|
| 16 |
+
title: "Create Next App",
|
| 17 |
+
description: "Generated by create next app",
|
| 18 |
+
};
|
| 19 |
+
|
| 20 |
+
// Root layout for the app router tree: exposes the Geist font CSS variables
// on <body> and applies global styles to every page.
export default function RootLayout({
  children,
}: Readonly<{
  children: React.ReactNode;
}>) {
  return (
    <html lang="en">
      <body
        className={`${geistSans.variable} ${geistMono.variable} antialiased`}
      >
        {children}
      </body>
    </html>
  );
}
|
frontend/app/page.tsx
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import Image from "next/image";
|
| 2 |
+
|
| 3 |
+
export default function Home() {
|
| 4 |
+
return (
|
| 5 |
+
<div className="flex min-h-screen items-center justify-center bg-zinc-50 font-sans dark:bg-black">
|
| 6 |
+
<main className="flex min-h-screen w-full max-w-3xl flex-col items-center justify-between py-32 px-16 bg-white dark:bg-black sm:items-start">
|
| 7 |
+
<Image
|
| 8 |
+
className="dark:invert"
|
| 9 |
+
src="/next.svg"
|
| 10 |
+
alt="Next.js logo"
|
| 11 |
+
width={100}
|
| 12 |
+
height={20}
|
| 13 |
+
priority
|
| 14 |
+
/>
|
| 15 |
+
<div className="flex flex-col items-center gap-6 text-center sm:items-start sm:text-left">
|
| 16 |
+
<h1 className="max-w-xs text-3xl font-semibold leading-10 tracking-tight text-black dark:text-zinc-50">
|
| 17 |
+
To get started, edit the page.tsx file.
|
| 18 |
+
</h1>
|
| 19 |
+
<p className="max-w-md text-lg leading-8 text-zinc-600 dark:text-zinc-400">
|
| 20 |
+
Looking for a starting point or more instructions? Head over to{" "}
|
| 21 |
+
<a
|
| 22 |
+
href="https://vercel.com/templates?framework=next.js&utm_source=create-next-app&utm_medium=appdir-template-tw&utm_campaign=create-next-app"
|
| 23 |
+
className="font-medium text-zinc-950 dark:text-zinc-50"
|
| 24 |
+
>
|
| 25 |
+
Templates
|
| 26 |
+
</a>{" "}
|
| 27 |
+
or the{" "}
|
| 28 |
+
<a
|
| 29 |
+
href="https://nextjs.org/learn?utm_source=create-next-app&utm_medium=appdir-template-tw&utm_campaign=create-next-app"
|
| 30 |
+
className="font-medium text-zinc-950 dark:text-zinc-50"
|
| 31 |
+
>
|
| 32 |
+
Learning
|
| 33 |
+
</a>{" "}
|
| 34 |
+
center.
|
| 35 |
+
</p>
|
| 36 |
+
</div>
|
| 37 |
+
<div className="flex flex-col gap-4 text-base font-medium sm:flex-row">
|
| 38 |
+
<a
|
| 39 |
+
className="flex h-12 w-full items-center justify-center gap-2 rounded-full bg-foreground px-5 text-background transition-colors hover:bg-[#383838] dark:hover:bg-[#ccc] md:w-[158px]"
|
| 40 |
+
href="https://vercel.com/new?utm_source=create-next-app&utm_medium=appdir-template-tw&utm_campaign=create-next-app"
|
| 41 |
+
target="_blank"
|
| 42 |
+
rel="noopener noreferrer"
|
| 43 |
+
>
|
| 44 |
+
<Image
|
| 45 |
+
className="dark:invert"
|
| 46 |
+
src="/vercel.svg"
|
| 47 |
+
alt="Vercel logomark"
|
| 48 |
+
width={16}
|
| 49 |
+
height={16}
|
| 50 |
+
/>
|
| 51 |
+
Deploy Now
|
| 52 |
+
</a>
|
| 53 |
+
<a
|
| 54 |
+
className="flex h-12 w-full items-center justify-center rounded-full border border-solid border-black/[.08] px-5 transition-colors hover:border-transparent hover:bg-black/[.04] dark:border-white/[.145] dark:hover:bg-[#1a1a1a] md:w-[158px]"
|
| 55 |
+
href="https://nextjs.org/docs?utm_source=create-next-app&utm_medium=appdir-template-tw&utm_campaign=create-next-app"
|
| 56 |
+
target="_blank"
|
| 57 |
+
rel="noopener noreferrer"
|
| 58 |
+
>
|
| 59 |
+
Documentation
|
| 60 |
+
</a>
|
| 61 |
+
</div>
|
| 62 |
+
</main>
|
| 63 |
+
</div>
|
| 64 |
+
);
|
| 65 |
+
}
|
frontend/eslint.config.mjs
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { defineConfig, globalIgnores } from "eslint/config";
|
| 2 |
+
import nextVitals from "eslint-config-next/core-web-vitals";
|
| 3 |
+
import nextTs from "eslint-config-next/typescript";
|
| 4 |
+
|
| 5 |
+
const eslintConfig = defineConfig([
|
| 6 |
+
...nextVitals,
|
| 7 |
+
...nextTs,
|
| 8 |
+
// Override default ignores of eslint-config-next.
|
| 9 |
+
globalIgnores([
|
| 10 |
+
// Default ignores of eslint-config-next:
|
| 11 |
+
".next/**",
|
| 12 |
+
"out/**",
|
| 13 |
+
"build/**",
|
| 14 |
+
"next-env.d.ts",
|
| 15 |
+
]),
|
| 16 |
+
]);
|
| 17 |
+
|
| 18 |
+
export default eslintConfig;
|
frontend/next.config.ts
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import type { NextConfig } from "next";
|
| 2 |
+
|
| 3 |
+
const nextConfig: NextConfig = {
|
| 4 |
+
/* config options here */
|
| 5 |
+
};
|
| 6 |
+
|
| 7 |
+
export default nextConfig;
|
frontend/package-lock.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
frontend/package.json
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"name": "frontend",
|
| 3 |
+
"version": "0.1.0",
|
| 4 |
+
"private": true,
|
| 5 |
+
"scripts": {
|
| 6 |
+
"dev": "next dev",
|
| 7 |
+
"build": "next build",
|
| 8 |
+
"start": "next start",
|
| 9 |
+
"lint": "eslint"
|
| 10 |
+
},
|
| 11 |
+
"dependencies": {
|
| 12 |
+
"next": "16.0.3",
|
| 13 |
+
"react": "19.2.0",
|
| 14 |
+
"react-dom": "19.2.0"
|
| 15 |
+
},
|
| 16 |
+
"devDependencies": {
|
| 17 |
+
"@tailwindcss/postcss": "^4",
|
| 18 |
+
"@types/node": "^20",
|
| 19 |
+
"@types/react": "^19",
|
| 20 |
+
"@types/react-dom": "^19",
|
| 21 |
+
"eslint": "^9",
|
| 22 |
+
"eslint-config-next": "16.0.3",
|
| 23 |
+
"tailwindcss": "^4",
|
| 24 |
+
"typescript": "^5"
|
| 25 |
+
}
|
| 26 |
+
}
|
frontend/postcss.config.mjs
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
const config = {
|
| 2 |
+
plugins: {
|
| 3 |
+
"@tailwindcss/postcss": {},
|
| 4 |
+
},
|
| 5 |
+
};
|
| 6 |
+
|
| 7 |
+
export default config;
|
frontend/public/file.svg
ADDED
|
|
frontend/public/globe.svg
ADDED
|
|
frontend/public/next.svg
ADDED
|
|
frontend/public/vercel.svg
ADDED
|
|
frontend/public/window.svg
ADDED
|
|
frontend/tsconfig.json
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"compilerOptions": {
|
| 3 |
+
"target": "ES2017",
|
| 4 |
+
"lib": ["dom", "dom.iterable", "esnext"],
|
| 5 |
+
"allowJs": true,
|
| 6 |
+
"skipLibCheck": true,
|
| 7 |
+
"strict": true,
|
| 8 |
+
"noEmit": true,
|
| 9 |
+
"esModuleInterop": true,
|
| 10 |
+
"module": "esnext",
|
| 11 |
+
"moduleResolution": "bundler",
|
| 12 |
+
"resolveJsonModule": true,
|
| 13 |
+
"isolatedModules": true,
|
| 14 |
+
"jsx": "react-jsx",
|
| 15 |
+
"incremental": true,
|
| 16 |
+
"plugins": [
|
| 17 |
+
{
|
| 18 |
+
"name": "next"
|
| 19 |
+
}
|
| 20 |
+
],
|
| 21 |
+
"paths": {
|
| 22 |
+
"@/*": ["./*"]
|
| 23 |
+
}
|
| 24 |
+
},
|
| 25 |
+
"include": [
|
| 26 |
+
"next-env.d.ts",
|
| 27 |
+
"**/*.ts",
|
| 28 |
+
"**/*.tsx",
|
| 29 |
+
".next/types/**/*.ts",
|
| 30 |
+
".next/dev/types/**/*.ts",
|
| 31 |
+
"**/*.mts"
|
| 32 |
+
],
|
| 33 |
+
"exclude": ["node_modules"]
|
| 34 |
+
}
|
test_all.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Single-file test suite for IntegraChat backend (unit + integration + simulation).
|
| 3 |
+
This version aligns with the current backend API surface.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import List, Dict
|
| 12 |
+
|
| 13 |
+
import pytest
|
| 14 |
+
from fastapi.testclient import TestClient
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# ---------------------------------------------------------------------------
|
| 18 |
+
# Ensure backend package is importable
|
| 19 |
+
# ---------------------------------------------------------------------------
|
| 20 |
+
PROJECT_ROOT = Path(__file__).resolve().parent
|
| 21 |
+
if str(PROJECT_ROOT) not in sys.path:
|
| 22 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 23 |
+
backend_path = PROJECT_ROOT / "backend"
|
| 24 |
+
if str(backend_path) not in sys.path:
|
| 25 |
+
sys.path.insert(0, str(backend_path))
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
# Shared fixtures
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
|
| 32 |
+
@pytest.fixture(autouse=True, scope="session")
def set_test_env():
    """Seed service-endpoint env vars with mock defaults for the whole session.

    Uses setdefault so values explicitly exported by the developer/CI
    environment are never overridden.
    """
    os.environ.setdefault("RAG_MCP_URL", "http://mock-rag")
    os.environ.setdefault("WEB_MCP_URL", "http://mock-web")
    os.environ.setdefault("ADMIN_MCP_URL", "http://mock-admin")
    os.environ.setdefault("OLLAMA_URL", "http://localhost:11434")
    os.environ.setdefault("OLLAMA_MODEL", "llama3")
    os.environ.setdefault("LLM_BACKEND", "ollama")
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@pytest.fixture
def mock_backend_dependencies(monkeypatch):
    """Patch MCP client calls and red-flag detector for deterministic tests.

    Replaces the RAG/web/admin MCP calls, the red-flag detector, and the
    tool-scoring service with canned fakes so tests need no network, LLM,
    or database.  The fakes are also rebound onto the already-instantiated
    orchestrator, which captured its collaborators before ``monkeypatch``
    could rewrite the classes.

    Fix: the original placed this docstring *after* a ``print`` statement,
    turning it into a no-op string expression; the debug print is removed.
    """
    import types

    from backend.api.models.redflag import RedFlagMatch
    from backend.api.services.tool_scoring import ToolScoringService

    async def fake_call_rag(self, tenant_id: str, query: str) -> Dict:
        # Two results straddling the 0.55 relevance threshold.
        return {
            "results": [
                {"text": "HR policy includes onboarding, leave rules.", "relevance": 0.92},
                {"text": "General company announcement", "relevance": 0.42}
            ],
            "metadata": {"total_retrieved": 2, "returned": 2, "threshold": 0.55}
        }

    async def fake_call_web(self, tenant_id: str, query: str) -> Dict:
        return {
            "results": [
                {"title": "Latest inflation update", "snippet": "Inflation is 3.2%", "url": "https://example.com"},
                {"title": "Global news", "snippet": "Market highlights", "url": "https://news.example.com"}
            ]
        }

    async def fake_call_admin(self, tenant_id: str, query: str) -> Dict:
        return {"status": "ok", "tenant_id": tenant_id, "query": query}

    monkeypatch.setattr("backend.api.mcp_clients.mcp_client.MCPClient.call_rag", fake_call_rag)
    monkeypatch.setattr("backend.api.mcp_clients.mcp_client.MCPClient.call_web", fake_call_web)
    monkeypatch.setattr("backend.api.mcp_clients.mcp_client.MCPClient.call_admin", fake_call_admin)

    async def fake_redflag_check(self, tenant_id: str, text: str) -> List[RedFlagMatch]:
        # Only the word "delete" trips a violation; everything else is clean.
        if "delete" in text.lower():
            return [
                RedFlagMatch(
                    rule_id="1",
                    pattern="delete",
                    severity="high",
                    description="Deletion request",
                    matched_text="delete",
                    confidence=0.9,
                    explanation="Matched on keyword 'delete'"
                )
            ]
        return []

    async def fake_notify(self, tenant_id, violations, source_payload=None):
        # Swallow admin notifications so tests never hit a real channel.
        return None

    monkeypatch.setattr("backend.api.services.redflag_detector.RedFlagDetector.check", fake_redflag_check)
    monkeypatch.setattr("backend.api.services.redflag_detector.RedFlagDetector.notify_admin", fake_notify)

    def fake_score(self, message: str, intent: str, rag_results: List[Dict]) -> Dict[str, float]:
        return {"rag_fitness": 0.82, "web_fitness": 0.78, "llm_only": 0.25}

    monkeypatch.setattr(ToolScoringService, "score", fake_score)

    # Ensure already-instantiated orchestrator uses the same patches
    from backend.api.routes import agent as agent_routes

    agent_routes.orchestrator.mcp.call_rag = types.MethodType(fake_call_rag, agent_routes.orchestrator.mcp)
    agent_routes.orchestrator.mcp.call_web = types.MethodType(fake_call_web, agent_routes.orchestrator.mcp)
    agent_routes.orchestrator.mcp.call_admin = types.MethodType(fake_call_admin, agent_routes.orchestrator.mcp)
    agent_routes.orchestrator.redflag.check = types.MethodType(fake_redflag_check, agent_routes.orchestrator.redflag)
    agent_routes.orchestrator.redflag.notify_admin = types.MethodType(fake_notify, agent_routes.orchestrator.redflag)
| 108 |
+
|
| 109 |
+
|
| 110 |
+
@pytest.fixture
def api_client(mock_backend_dependencies):
    """HTTP test client for the FastAPI app, with all backend fakes installed."""
    from backend.api.main import app

    client = TestClient(app)
    return client
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# ---------------------------------------------------------------------------
|
| 117 |
+
# Unit tests
|
| 118 |
+
# ---------------------------------------------------------------------------
|
| 119 |
+
|
| 120 |
+
@pytest.mark.asyncio
async def test_redflag_detector():
    """A cached 'salary' rule should match a salary-related query."""
    import time

    from backend.api.models.redflag import RedFlagRule
    from backend.api.services.redflag_detector import RedFlagDetector
    from backend.api.services.semantic_encoder import embed_text

    detector = RedFlagDetector(supabase_url="http://fake", supabase_key="fake")
    salary_rule = RedFlagRule(
        id="rule-salary",
        pattern="salary",
        description="Salary access",
        severity="high",
        source="test",
        enabled=True,
        keywords=["salary"]
    )
    # Pre-seed the rule and embedding caches so no Supabase round-trip happens.
    detector._rules_cache["tenant-x"] = {"fetched_at": int(time.time()), "rules": [salary_rule]}
    detector._rule_embeddings["tenant-x"] = {salary_rule.id: embed_text("salary access")}

    hits = await detector.check("tenant-x", "Show me employee salary details")

    assert hits
    first = hits[0]
    assert first.matched_text.lower() == "salary"
    assert first.confidence is not None
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def test_tool_scoring():
    """Web-flavored questions should out-score RAG when no docs were retrieved."""
    from backend.api.services.tool_scoring import ToolScoringService

    service = ToolScoringService()
    result = service.score("What is inflation today?", intent="web", rag_results=[])

    expected_keys = {"rag_fitness", "web_fitness", "llm_only"}
    assert set(result.keys()) == expected_keys
    assert result["web_fitness"] > result["rag_fitness"]
| 155 |
+
|
| 156 |
+
|
| 157 |
+
@pytest.mark.asyncio
async def test_tool_selector():
    """A RAG intent that also mentions external news should plan rag -> web -> llm."""
    from backend.api.services.tool_selector import ToolSelector

    selector = ToolSelector()
    plan = await selector.select(
        intent="rag",
        text="Tell me HR policy and compare with external news",
        ctx={"rag_results": [{"text": "Policy"}], "tool_scores": {"rag_fitness": 0.9, "web_fitness": 0.8}}
    )

    planned = plan.tool_input["steps"]
    assert planned[0]["tool"] == "rag"
    assert "web" in [entry["tool"] for entry in planned]
    assert planned[-1]["tool"] == "llm"
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def test_reasoning_trace_via_response(api_client):
    """The agent response must expose a reasoning trace including intent detection.

    Fix: assert the HTTP status before parsing the body, so a transport or
    validation failure fails loudly instead of surfacing as a KeyError.
    """
    payload = {"tenant_id": "tenant1", "message": "Summarize our HR policies"}
    res = api_client.post("/agent/message", json=payload)

    assert res.status_code == 200, res.text
    data = res.json()

    assert data["reasoning_trace"]
    step_names = [entry["step"] for entry in data["reasoning_trace"]]
    assert "intent_detection" in step_names
| 182 |
+
|
| 183 |
+
|
| 184 |
+
# ---------------------------------------------------------------------------
|
| 185 |
+
# Integration tests
|
| 186 |
+
# ---------------------------------------------------------------------------
|
| 187 |
+
|
| 188 |
+
def test_full_agent_pipeline(api_client):
    """End-to-end: an HR + updates question yields text, a multi-step trace, and a RAG call.

    Fix: assert the HTTP status before parsing the body, so a transport or
    validation failure fails loudly instead of surfacing as a KeyError.
    """
    payload = {"tenant_id": "tenant123", "message": "What are our HR policies and latest updates?"}
    response = api_client.post("/agent/message", json=payload)

    assert response.status_code == 200, response.text
    data = response.json()

    assert data["text"]
    assert len(data["reasoning_trace"]) >= 3

    rag_steps = [step for step in data["reasoning_trace"] if step.get("tool") == "rag"]
    assert rag_steps, "expected rag tool execution in reasoning trace"
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def test_parallel_execution_detected(api_client):
    """A query needing both internal docs and fresh news should exercise rag AND web."""
    body = {"tenant_id": "t1", "message": "Summarize HR policies and latest news updates"}
    data = api_client.post("/agent/message", json=body).json()

    tools_used = {trace.get("tool") for trace in data["tool_traces"] if trace.get("tool")}
    assert {"rag", "web"} <= tools_used
| 207 |
+
|
| 208 |
+
|
| 209 |
+
# ---------------------------------------------------------------------------
|
| 210 |
+
# Simulation tests
|
| 211 |
+
# ---------------------------------------------------------------------------
|
| 212 |
+
|
| 213 |
+
# Representative queries for the simulation sweep: fresh-data lookups (web),
# internal policy questions (RAG), and one destructive request
# ("Delete all records") that must trip the red-flag path.
SIM_QUERIES = [
    "What is the inflation rate today?",
    "Summarize our HR policies",
    "Delete all records",
    "Explain our refund policy",
    "How many employees are in the company?"
]
| 220 |
+
|
| 221 |
+
|
| 222 |
+
@pytest.mark.parametrize("message", SIM_QUERIES)
def test_agent_simulation(api_client, message):
    """Every simulated query must yield text and a trace; deletions must be escalated."""
    res = api_client.post("/agent/message", json={"tenant_id": "demo", "message": message})
    data = res.json()

    assert data["text"]
    assert data["reasoning_trace"]

    if "delete" in message.lower():
        decision = data["decision"]
        assert decision["action"] in {"block", "multi_step"}
        reason = (decision["reason"] or "").lower()
        assert ("admin" in reason) or ("redflag" in reason)