Spaces:
Running
Running
Pawan Mane commited on
Commit Β·
ceb563c
1
Parent(s): d80f659
Code optimization
Browse files- app/frontend/gradio_app.py +10 -16
- app/frontend/gradio_app_hf.py +31 -42
- app/nodes/llm_node.py +17 -9
- app/nodes/output.py +34 -7
- app/nodes/safety.py +37 -37
app/frontend/gradio_app.py
CHANGED
|
@@ -14,25 +14,26 @@ from app.nodes.hitl import HITLPauseException
|
|
| 14 |
|
| 15 |
_graph = build_graph()
|
| 16 |
_thread_config = {"configurable": {"thread_id": "gradio-session-001"}}
|
| 17 |
-
|
|
|
|
|
|
|
| 18 |
_pending_hitl_state: AgentState | None = None
|
| 19 |
|
| 20 |
|
| 21 |
def run_graph(query: str) -> AgentState:
|
| 22 |
-
|
| 23 |
-
_conversation_history.append(HumanMessage(content=query))
|
| 24 |
initial_state: AgentState = {
|
| 25 |
-
"messages":
|
|
|
|
| 26 |
"route": "", "rag_context": "", "tool_calls": [], "tool_results": [],
|
| 27 |
"response": "", "retry_count": 0, "hitl_approved": False,
|
| 28 |
-
"evaluation_score": 0.0, "guardrail_passed": True,
|
| 29 |
"memory_summary": "", "node_log": [],
|
| 30 |
}
|
| 31 |
return _graph.invoke(initial_state, config=_thread_config)
|
| 32 |
|
| 33 |
|
| 34 |
def resume_graph_after_hitl(state: AgentState, approved: bool) -> AgentState:
|
| 35 |
-
global _conversation_history
|
| 36 |
from app.nodes.evaluation import evaluation_node, eval_route
|
| 37 |
from app.nodes.guardrails import guardrails_node
|
| 38 |
from app.nodes.output import output_node
|
|
@@ -44,7 +45,6 @@ def resume_graph_after_hitl(state: AgentState, approved: bool) -> AgentState:
|
|
| 44 |
s = llm_node(s)
|
| 45 |
s = guardrails_node(s)
|
| 46 |
s = output_node(s)
|
| 47 |
-
_conversation_history = s["messages"]
|
| 48 |
return s
|
| 49 |
|
| 50 |
|
|
@@ -81,13 +81,7 @@ def handle_submit(user_message, chat_history):
|
|
| 81 |
score = fs.get("evaluation_score", 0.0)
|
| 82 |
g_ok = fs.get("guardrail_passed", True)
|
| 83 |
|
| 84 |
-
#
|
| 85 |
-
# doesn't poison the memory summary for future innocent queries
|
| 86 |
-
if not g_ok:
|
| 87 |
-
global _conversation_history
|
| 88 |
-
if _conversation_history:
|
| 89 |
-
_conversation_history.pop()
|
| 90 |
-
|
| 91 |
chat_history = chat_history + [bot_msg(fs.get("response", ""))]
|
| 92 |
meta = f"**Route:** {route.upper() or 'β'} Β· **Eval:** {score:.2f} Β· **Guardrail:** {'β
Passed' if g_ok else 'π« Blocked'}"
|
| 93 |
return (chat_history, "", format_trace(fs.get("node_log", [])),
|
|
@@ -130,8 +124,8 @@ def handle_reject(chat_history):
|
|
| 130 |
|
| 131 |
|
| 132 |
def handle_clear():
|
| 133 |
-
global
|
| 134 |
-
|
| 135 |
return [], "", "*Waiting for a query...*", "", gr.update(visible=False)
|
| 136 |
|
| 137 |
|
|
|
|
| 14 |
|
| 15 |
_graph = build_graph()
|
| 16 |
_thread_config = {"configurable": {"thread_id": "gradio-session-001"}}
|
| 17 |
+
# Frontend holds NO conversation history.
|
| 18 |
+
# All message history is managed inside the graph via output_node.
|
| 19 |
+
# LangGraph MemorySaver persists state across invocations automatically.
|
| 20 |
_pending_hitl_state: AgentState | None = None
|
| 21 |
|
| 22 |
|
| 23 |
def run_graph(query: str) -> AgentState:
|
| 24 |
+
# Just pass the query β graph manages its own message history via state
|
|
|
|
| 25 |
initial_state: AgentState = {
|
| 26 |
+
"messages": [], # MemorySaver restores history; safety_node adds HumanMessage
|
| 27 |
+
"query": query,
|
| 28 |
"route": "", "rag_context": "", "tool_calls": [], "tool_results": [],
|
| 29 |
"response": "", "retry_count": 0, "hitl_approved": False,
|
| 30 |
+
"evaluation_score": 0.0, "guardrail_passed": True, "is_harmful": False,
|
| 31 |
"memory_summary": "", "node_log": [],
|
| 32 |
}
|
| 33 |
return _graph.invoke(initial_state, config=_thread_config)
|
| 34 |
|
| 35 |
|
| 36 |
def resume_graph_after_hitl(state: AgentState, approved: bool) -> AgentState:
|
|
|
|
| 37 |
from app.nodes.evaluation import evaluation_node, eval_route
|
| 38 |
from app.nodes.guardrails import guardrails_node
|
| 39 |
from app.nodes.output import output_node
|
|
|
|
| 45 |
s = llm_node(s)
|
| 46 |
s = guardrails_node(s)
|
| 47 |
s = output_node(s)
|
|
|
|
| 48 |
return s
|
| 49 |
|
| 50 |
|
|
|
|
| 81 |
score = fs.get("evaluation_score", 0.0)
|
| 82 |
g_ok = fs.get("guardrail_passed", True)
|
| 83 |
|
| 84 |
+
# History is managed entirely by output_node inside the graph
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
chat_history = chat_history + [bot_msg(fs.get("response", ""))]
|
| 86 |
meta = f"**Route:** {route.upper() or 'β'} Β· **Eval:** {score:.2f} Β· **Guardrail:** {'β
Passed' if g_ok else 'π« Blocked'}"
|
| 87 |
return (chat_history, "", format_trace(fs.get("node_log", [])),
|
|
|
|
| 124 |
|
| 125 |
|
| 126 |
def handle_clear():
|
| 127 |
+
global _pending_hitl_state
|
| 128 |
+
_pending_hitl_state = None
|
| 129 |
return [], "", "*Waiting for a query...*", "", gr.update(visible=False)
|
| 130 |
|
| 131 |
|
app/frontend/gradio_app_hf.py
CHANGED
|
@@ -1,34 +1,29 @@
|
|
| 1 |
"""
|
| 2 |
app/frontend/gradio_app_hf.py
|
| 3 |
ββββββββββββββββββββββββββββββ
|
| 4 |
-
HuggingFace Spaces entry point.
|
| 5 |
|
| 6 |
-
Key differences from
|
| 7 |
- Reads all config from environment variables (HF injects secrets as env vars)
|
| 8 |
-
- No .env file
|
| 9 |
-
-
|
| 10 |
-
- PYTHONPATH=/app
|
|
|
|
|
|
|
|
|
|
| 11 |
"""
|
| 12 |
|
| 13 |
import os
|
| 14 |
|
| 15 |
-
|
| 16 |
-
os.environ["
|
| 17 |
-
os.environ["PYTHONPATH"] = "/app"
|
| 18 |
-
|
| 19 |
-
# HITL defaults to false on public spaces β override via HF Space Variables
|
| 20 |
-
# All other secrets (GROQ_API_KEY, WEATHER_API_KEY, LLM_MODEL etc.)
|
| 21 |
-
# are set in HuggingFace Space β Settings β Variables and Secrets
|
| 22 |
|
| 23 |
-
#
|
| 24 |
-
# app/config.py calls load_dotenv() which would print a warning if .env
|
| 25 |
-
# is missing. We patch it to a no-op before config is imported.
|
| 26 |
import sys
|
| 27 |
from unittest.mock import MagicMock
|
| 28 |
if "dotenv" not in sys.modules:
|
| 29 |
sys.modules["dotenv"] = MagicMock()
|
| 30 |
|
| 31 |
-
# ββ Import the full app (config, graph, nodes all load here) βββββββββββββββ
|
| 32 |
import gradio as gr
|
| 33 |
from langchain_core.messages import HumanMessage
|
| 34 |
|
|
@@ -39,19 +34,18 @@ from app.frontend.css import CSS
|
|
| 39 |
|
| 40 |
|
| 41 |
# ββ Graph singleton ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 42 |
-
_graph
|
| 43 |
_thread_config = {"configurable": {"thread_id": "hf-session-001"}}
|
| 44 |
-
_conversation_history
|
| 45 |
_pending_hitl_state: AgentState | None = None
|
| 46 |
|
| 47 |
|
| 48 |
# ββ Core runner ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 49 |
|
| 50 |
def run_graph(query: str) -> AgentState:
|
| 51 |
-
|
| 52 |
-
_conversation_history.append(HumanMessage(content=query))
|
| 53 |
initial_state: AgentState = {
|
| 54 |
-
"messages":
|
| 55 |
"query": query,
|
| 56 |
"route": "",
|
| 57 |
"rag_context": "",
|
|
@@ -62,15 +56,14 @@ def run_graph(query: str) -> AgentState:
|
|
| 62 |
"hitl_approved": False,
|
| 63 |
"evaluation_score": 0.0,
|
| 64 |
"guardrail_passed": True,
|
|
|
|
| 65 |
"memory_summary": "",
|
| 66 |
"node_log": [],
|
| 67 |
-
"is_harmful": False,
|
| 68 |
}
|
| 69 |
return _graph.invoke(initial_state, config=_thread_config)
|
| 70 |
|
| 71 |
|
| 72 |
def resume_graph_after_hitl(state: AgentState, approved: bool) -> AgentState:
|
| 73 |
-
global _conversation_history
|
| 74 |
from app.nodes.evaluation import evaluation_node, eval_route
|
| 75 |
from app.nodes.guardrails import guardrails_node
|
| 76 |
from app.nodes.output import output_node
|
|
@@ -82,7 +75,6 @@ def resume_graph_after_hitl(state: AgentState, approved: bool) -> AgentState:
|
|
| 82 |
s = llm_node(s)
|
| 83 |
s = guardrails_node(s)
|
| 84 |
s = output_node(s)
|
| 85 |
-
_conversation_history = s["messages"]
|
| 86 |
return s
|
| 87 |
|
| 88 |
|
|
@@ -112,7 +104,7 @@ def bot_msg(t): return {"role": "assistant", "content": t}
|
|
| 112 |
# ββ Event handlers βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 113 |
|
| 114 |
def handle_submit(user_message, chat_history):
|
| 115 |
-
global _pending_hitl_state
|
| 116 |
if not user_message.strip():
|
| 117 |
return chat_history, "", "*Waiting for a query...*", "", gr.update(visible=False), gr.update(value="")
|
| 118 |
|
|
@@ -123,10 +115,7 @@ def handle_submit(user_message, chat_history):
|
|
| 123 |
score = fs.get("evaluation_score", 0.0)
|
| 124 |
g_ok = fs.get("guardrail_passed", True)
|
| 125 |
|
| 126 |
-
#
|
| 127 |
-
if not g_ok and _conversation_history:
|
| 128 |
-
_conversation_history.pop()
|
| 129 |
-
|
| 130 |
chat_history = chat_history + [bot_msg(fs.get("response", ""))]
|
| 131 |
meta = f"**Route:** {route.upper() or 'β'} Β· **Eval:** {score:.2f} Β· **Guardrail:** {'β
Passed' if g_ok else 'π« Blocked'}"
|
| 132 |
return (chat_history, "", format_trace(fs.get("node_log", [])),
|
|
@@ -169,8 +158,8 @@ def handle_reject(chat_history):
|
|
| 169 |
|
| 170 |
|
| 171 |
def handle_clear():
|
| 172 |
-
global
|
| 173 |
-
|
| 174 |
return [], "", "*Waiting for a query...*", "", gr.update(visible=False)
|
| 175 |
|
| 176 |
|
|
@@ -183,7 +172,6 @@ def build_ui():
|
|
| 183 |
|
| 184 |
with gr.Row(equal_height=True):
|
| 185 |
|
| 186 |
-
# ββ Main chat βββββββββββββββββββββββββββββββββββββββββββββ
|
| 187 |
with gr.Column(scale=4):
|
| 188 |
|
| 189 |
with gr.Group(elem_classes="section-box"):
|
|
@@ -222,7 +210,6 @@ def build_ui():
|
|
| 222 |
label="Examples",
|
| 223 |
)
|
| 224 |
|
| 225 |
-
# ββ Right sidebar ββββββββββββββββββββββββββββββββββββββββββ
|
| 226 |
with gr.Column(scale=1):
|
| 227 |
|
| 228 |
with gr.Group(elem_classes="section-box"):
|
|
@@ -232,15 +219,17 @@ def build_ui():
|
|
| 232 |
with gr.Group(elem_classes="section-box"):
|
| 233 |
gr.Markdown("""**πΊ Graph Topology**
|
| 234 |
```
|
| 235 |
-
START β
|
| 236 |
-
ββ
|
| 237 |
-
ββ
|
| 238 |
-
ββ
|
| 239 |
-
ββ
|
| 240 |
-
ββ
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
|
|
|
|
|
|
| 244 |
```""")
|
| 245 |
|
| 246 |
submit_outs = [chatbot, user_input, trace_display, meta_display, hitl_panel, hitl_content]
|
|
|
|
| 1 |
"""
|
| 2 |
app/frontend/gradio_app_hf.py
|
| 3 |
ββββββββββββββββββββββββββββββ
|
| 4 |
+
HuggingFace Spaces entry point β fully synced with gradio_app.py.
|
| 5 |
|
| 6 |
+
Key differences from gradio_app.py:
|
| 7 |
- Reads all config from environment variables (HF injects secrets as env vars)
|
| 8 |
+
- No .env file β dotenv silenced gracefully
|
| 9 |
+
- Port 7860 (HF Spaces requirement)
|
| 10 |
+
- PYTHONPATH=/app set in Dockerfile
|
| 11 |
+
|
| 12 |
+
History management: entirely inside the graph (output_node + MemorySaver).
|
| 13 |
+
Frontend is stateless β no _conversation_history here.
|
| 14 |
"""
|
| 15 |
|
| 16 |
import os
|
| 17 |
|
| 18 |
+
os.environ["GRADIO_MODE"] = "true"
|
| 19 |
+
os.environ["PYTHONPATH"] = "/app"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
+
# Silence dotenv β no .env on HF Spaces
|
|
|
|
|
|
|
| 22 |
import sys
|
| 23 |
from unittest.mock import MagicMock
|
| 24 |
if "dotenv" not in sys.modules:
|
| 25 |
sys.modules["dotenv"] = MagicMock()
|
| 26 |
|
|
|
|
| 27 |
import gradio as gr
|
| 28 |
from langchain_core.messages import HumanMessage
|
| 29 |
|
|
|
|
| 34 |
|
| 35 |
|
| 36 |
# ββ Graph singleton ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 37 |
+
_graph = build_graph()
|
| 38 |
_thread_config = {"configurable": {"thread_id": "hf-session-001"}}
|
| 39 |
+
# No _conversation_history β graph manages all history via output_node + MemorySaver
|
| 40 |
_pending_hitl_state: AgentState | None = None
|
| 41 |
|
| 42 |
|
| 43 |
# ββ Core runner ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 44 |
|
| 45 |
def run_graph(query: str) -> AgentState:
|
| 46 |
+
# messages=[] β MemorySaver restores prior history; safety_node adds HumanMessage
|
|
|
|
| 47 |
initial_state: AgentState = {
|
| 48 |
+
"messages": [],
|
| 49 |
"query": query,
|
| 50 |
"route": "",
|
| 51 |
"rag_context": "",
|
|
|
|
| 56 |
"hitl_approved": False,
|
| 57 |
"evaluation_score": 0.0,
|
| 58 |
"guardrail_passed": True,
|
| 59 |
+
"is_harmful": False,
|
| 60 |
"memory_summary": "",
|
| 61 |
"node_log": [],
|
|
|
|
| 62 |
}
|
| 63 |
return _graph.invoke(initial_state, config=_thread_config)
|
| 64 |
|
| 65 |
|
| 66 |
def resume_graph_after_hitl(state: AgentState, approved: bool) -> AgentState:
|
|
|
|
| 67 |
from app.nodes.evaluation import evaluation_node, eval_route
|
| 68 |
from app.nodes.guardrails import guardrails_node
|
| 69 |
from app.nodes.output import output_node
|
|
|
|
| 75 |
s = llm_node(s)
|
| 76 |
s = guardrails_node(s)
|
| 77 |
s = output_node(s)
|
|
|
|
| 78 |
return s
|
| 79 |
|
| 80 |
|
|
|
|
| 104 |
# ββ Event handlers βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 105 |
|
| 106 |
def handle_submit(user_message, chat_history):
|
| 107 |
+
global _pending_hitl_state
|
| 108 |
if not user_message.strip():
|
| 109 |
return chat_history, "", "*Waiting for a query...*", "", gr.update(visible=False), gr.update(value="")
|
| 110 |
|
|
|
|
| 115 |
score = fs.get("evaluation_score", 0.0)
|
| 116 |
g_ok = fs.get("guardrail_passed", True)
|
| 117 |
|
| 118 |
+
# History managed entirely by output_node inside the graph
|
|
|
|
|
|
|
|
|
|
| 119 |
chat_history = chat_history + [bot_msg(fs.get("response", ""))]
|
| 120 |
meta = f"**Route:** {route.upper() or 'β'} Β· **Eval:** {score:.2f} Β· **Guardrail:** {'β
Passed' if g_ok else 'π« Blocked'}"
|
| 121 |
return (chat_history, "", format_trace(fs.get("node_log", [])),
|
|
|
|
| 158 |
|
| 159 |
|
| 160 |
def handle_clear():
|
| 161 |
+
global _pending_hitl_state
|
| 162 |
+
_pending_hitl_state = None
|
| 163 |
return [], "", "*Waiting for a query...*", "", gr.update(visible=False)
|
| 164 |
|
| 165 |
|
|
|
|
| 172 |
|
| 173 |
with gr.Row(equal_height=True):
|
| 174 |
|
|
|
|
| 175 |
with gr.Column(scale=4):
|
| 176 |
|
| 177 |
with gr.Group(elem_classes="section-box"):
|
|
|
|
| 210 |
label="Examples",
|
| 211 |
)
|
| 212 |
|
|
|
|
| 213 |
with gr.Column(scale=1):
|
| 214 |
|
| 215 |
with gr.Group(elem_classes="section-box"):
|
|
|
|
| 219 |
with gr.Group(elem_classes="section-box"):
|
| 220 |
gr.Markdown("""**πΊ Graph Topology**
|
| 221 |
```
|
| 222 |
+
START β safety
|
| 223 |
+
ββ blocked β output β END
|
| 224 |
+
ββ continue β router
|
| 225 |
+
ββ rag β llm
|
| 226 |
+
ββ tool/general β llm
|
| 227 |
+
ββ tool_executor
|
| 228 |
+
ββ memory β hitl
|
| 229 |
+
ββ evaluation
|
| 230 |
+
β ββ retry β llm
|
| 231 |
+
β ββ guardrails β output
|
| 232 |
+
ββ END
|
| 233 |
```""")
|
| 234 |
|
| 235 |
submit_outs = [chatbot, user_input, trace_display, meta_display, hitl_panel, hitl_content]
|
app/nodes/llm_node.py
CHANGED
|
@@ -36,25 +36,33 @@ def llm_node(state: AgentState) -> AgentState:
|
|
| 36 |
try:
|
| 37 |
# Build system prompt
|
| 38 |
system_parts = [
|
| 39 |
-
"You are a helpful AI assistant.
|
| 40 |
-
"
|
| 41 |
-
"
|
| 42 |
]
|
| 43 |
if state.get("rag_context"):
|
| 44 |
system_parts.append(f"\nUse the following context to answer:\n{state['rag_context']}")
|
| 45 |
if state.get("memory_summary"):
|
| 46 |
-
system_parts.append(f"\
|
| 47 |
|
| 48 |
system_msg = SystemMessage(content="\n".join(system_parts))
|
| 49 |
|
| 50 |
-
#
|
| 51 |
-
#
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
if state["route"] == "tool":
|
| 55 |
-
|
|
|
|
| 56 |
else:
|
| 57 |
-
|
|
|
|
| 58 |
|
| 59 |
tool_calls = getattr(ai_msg, "tool_calls", []) or []
|
| 60 |
response_text = ai_msg.content or ""
|
|
|
|
| 36 |
try:
|
| 37 |
# Build system prompt
|
| 38 |
system_parts = [
|
| 39 |
+
"You are a helpful AI assistant.",
|
| 40 |
+
"Answer the current query using the conversation history for context.",
|
| 41 |
+
"Keep responses concise and relevant.",
|
| 42 |
]
|
| 43 |
if state.get("rag_context"):
|
| 44 |
system_parts.append(f"\nUse the following context to answer:\n{state['rag_context']}")
|
| 45 |
if state.get("memory_summary"):
|
| 46 |
+
system_parts.append(f"\nConversation summary so far:\n{state['memory_summary']}")
|
| 47 |
|
| 48 |
system_msg = SystemMessage(content="\n".join(system_parts))
|
| 49 |
|
| 50 |
+
# state["messages"] = prior safe history (from MemorySaver) + current HumanMessage
|
| 51 |
+
# Scrub tool noise, then build: [system, h1, a1, h2, a2, ..., current_query]
|
| 52 |
+
from langchain_core.messages import ToolMessage, AIMessage as AI
|
| 53 |
+
clean = [
|
| 54 |
+
m for m in state["messages"]
|
| 55 |
+
if not isinstance(m, ToolMessage)
|
| 56 |
+
and not (isinstance(m, AI) and getattr(m, "tool_calls", []))
|
| 57 |
+
]
|
| 58 |
+
messages = [system_msg] + clean
|
| 59 |
|
| 60 |
if state["route"] == "tool":
|
| 61 |
+
# Tool route: only current query to avoid re-firing old tool calls
|
| 62 |
+
ai_msg = _llm_with_tools.invoke([system_msg, HumanMessage(content=state["query"])])
|
| 63 |
else:
|
| 64 |
+
# RAG / general: full clean history for context
|
| 65 |
+
ai_msg = llm.invoke(messages)
|
| 66 |
|
| 67 |
tool_calls = getattr(ai_msg, "tool_calls", []) or []
|
| 68 |
response_text = ai_msg.content or ""
|
app/nodes/output.py
CHANGED
|
@@ -1,11 +1,38 @@
|
|
| 1 |
-
"""
|
| 2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from app.state import AgentState
|
| 4 |
|
| 5 |
|
| 6 |
def output_node(state: AgentState) -> AgentState:
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
app/nodes/output.py
|
| 3 |
+
ββββββββββββββββββββ
|
| 4 |
+
Single source of truth for message history.
|
| 5 |
+
|
| 6 |
+
Flow per turn:
|
| 7 |
+
gradio sends: messages=[] (empty β MemorySaver restores checkpoint history)
|
| 8 |
+
safety adds: HumanMessage(query) to messages
|
| 9 |
+
output_node:
|
| 10 |
+
- harmful/blocked β drop the HumanMessage, keep prior history clean
|
| 11 |
+
- safe β keep HumanMessage + append AIMessage(response)
|
| 12 |
+
|
| 13 |
+
MemorySaver then persists the updated messages for next turn.
|
| 14 |
+
"""
|
| 15 |
+
from langchain_core.messages import AIMessage, HumanMessage
|
| 16 |
from app.state import AgentState
|
| 17 |
|
| 18 |
|
| 19 |
def output_node(state: AgentState) -> AgentState:
|
| 20 |
+
log = state.get("node_log", []) + ["output"]
|
| 21 |
+
response = state["response"]
|
| 22 |
+
messages = list(state["messages"])
|
| 23 |
+
is_harmful = state.get("is_harmful", False)
|
| 24 |
+
guardrail_ok = state.get("guardrail_passed", True)
|
| 25 |
+
|
| 26 |
+
if is_harmful or not guardrail_ok:
|
| 27 |
+
# Drop the HumanMessage for this turn β never pollute history
|
| 28 |
+
messages = [m for m in messages
|
| 29 |
+
if not (isinstance(m, HumanMessage) and m.content == state["query"])]
|
| 30 |
+
print(f"\nπ€ {response}\n")
|
| 31 |
+
print("[OUTPUT] Harmful turn scrubbed from history.")
|
| 32 |
+
else:
|
| 33 |
+
# Safe β HumanMessage already in messages (added by safety_node)
|
| 34 |
+
# Just append the assistant response
|
| 35 |
+
messages = messages + [AIMessage(content=response)]
|
| 36 |
+
print(f"\nπ€ {response}\n")
|
| 37 |
+
|
| 38 |
+
return {**state, "messages": messages, "node_log": log}
|
app/nodes/safety.py
CHANGED
|
@@ -1,20 +1,12 @@
|
|
| 1 |
"""
|
| 2 |
app/nodes/safety.py
|
| 3 |
ββββββββββββββββββββ
|
| 4 |
-
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
Pattern used by: Anthropic, Cohere, Mistral AI
|
| 10 |
-
β System prompt + structured classification prompt
|
| 11 |
-
β Confidence threshold to avoid false positives
|
| 12 |
-
β Fail-open on errors (don't block legitimate users)
|
| 13 |
-
|
| 14 |
-
Two checks:
|
| 15 |
-
1. Query intent β is the user asking for something harmful?
|
| 16 |
-
2. Response check β did the LLM accidentally produce harmful output?
|
| 17 |
-
(guardrails.py handles this after response is generated)
|
| 18 |
"""
|
| 19 |
import json
|
| 20 |
from langchain_core.messages import SystemMessage, HumanMessage
|
|
@@ -23,7 +15,7 @@ from app.utils.llm import llm
|
|
| 23 |
|
| 24 |
SAFE_FALLBACK = "I'm sorry, I can't help with that request."
|
| 25 |
|
| 26 |
-
SYSTEM_PROMPT = """You are a content safety classifier.
|
| 27 |
Your only job is to analyze user queries and determine if they request harmful content.
|
| 28 |
Be precise β do not over-block. Only flag clearly harmful requests."""
|
| 29 |
|
|
@@ -32,7 +24,7 @@ CLASSIFICATION_PROMPT = """Analyze this query and classify it.
|
|
| 32 |
HARMFUL categories (flag these):
|
| 33 |
- Violence: instructions to physically harm, kill, or assault people
|
| 34 |
- Unauthorized access: hacking, stealing credentials, bypassing security
|
| 35 |
-
- Weapons: making bombs, explosives, illegal weapons
|
| 36 |
- Illegal substances: synthesizing or manufacturing drugs
|
| 37 |
- Harassment: stalking, doxxing, non-consensual tracking
|
| 38 |
- Self-harm: methods to hurt oneself or suicide instructions
|
|
@@ -41,27 +33,35 @@ HARMFUL categories (flag these):
|
|
| 41 |
NOT harmful (do not flag these):
|
| 42 |
- Competitive contexts: "beat my friend at chess", "win an argument"
|
| 43 |
- Security education: conceptual explanations of how attacks work
|
| 44 |
-
- Fiction
|
| 45 |
-
- Legitimate anger: "I'm so frustrated I could scream"
|
| 46 |
- Medical: drug interactions, symptoms, treatments
|
| 47 |
- History/news: discussing past violent events
|
| 48 |
|
| 49 |
Query: "{query}"
|
| 50 |
|
| 51 |
-
|
| 52 |
-
{{"harmful": true/false, "category": "violence|hacking|weapons|drugs|harassment|self_harm|hate|safe", "confidence": 0.0-1.0, "reason": "one sentence"}}"""
|
| 53 |
|
| 54 |
|
| 55 |
def safety_node(state: AgentState) -> AgentState:
|
| 56 |
-
query
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
try:
|
| 60 |
response = llm.invoke([
|
| 61 |
SystemMessage(content=SYSTEM_PROMPT),
|
| 62 |
HumanMessage(content=CLASSIFICATION_PROMPT.format(query=query)),
|
| 63 |
])
|
| 64 |
-
|
| 65 |
raw = response.content.strip().removeprefix("```json").removesuffix("```").strip()
|
| 66 |
result = json.loads(raw)
|
| 67 |
|
|
@@ -70,18 +70,11 @@ def safety_node(state: AgentState) -> AgentState:
|
|
| 70 |
category = result.get("category", "safe")
|
| 71 |
reason = result.get("reason", "")
|
| 72 |
|
| 73 |
-
# IST timestamp for every query
|
| 74 |
-
from datetime import datetime, timezone, timedelta
|
| 75 |
-
IST = timezone(timedelta(hours=5, minutes=30))
|
| 76 |
-
ts = datetime.now(IST).strftime("%d %b %Y %I:%M:%S %p IST")
|
| 77 |
-
print(f"[{ts}] [User Query] β {query}")
|
| 78 |
-
|
| 79 |
-
# Require high confidence to avoid false positives on edge cases
|
| 80 |
-
# e.g. "how to beat someone at chess" should NOT be blocked
|
| 81 |
if harmful and confidence >= 0.85:
|
| 82 |
print(f"[SAFETY] π« Blocked β {category} ({confidence:.0%}): {reason}")
|
| 83 |
return {
|
| 84 |
**state,
|
|
|
|
| 85 |
"is_harmful": True,
|
| 86 |
"guardrail_passed": False,
|
| 87 |
"response": SAFE_FALLBACK,
|
|
@@ -89,17 +82,24 @@ def safety_node(state: AgentState) -> AgentState:
|
|
| 89 |
}
|
| 90 |
|
| 91 |
if harmful and confidence < 0.85:
|
| 92 |
-
|
| 93 |
-
print(f"[SAFETY] β οΈ Ambiguous ({category} {confidence:.0%}) β allowing through: {reason}")
|
| 94 |
|
| 95 |
print(f"[SAFETY] β
Safe β {reason}")
|
| 96 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
except Exception as e:
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
|
| 105 |
def safety_route(state: AgentState) -> str:
|
|
|
|
| 1 |
"""
|
| 2 |
app/nodes/safety.py
|
| 3 |
ββββββββββββββββββββ
|
| 4 |
+
FIRST node in graph. Does two things:
|
| 5 |
+
1. Adds HumanMessage(query) to messages β so history is always updated here
|
| 6 |
+
2. Classifies query intent β harmful queries short-circuit to output_node
|
| 7 |
|
| 8 |
+
output_node will scrub the HumanMessage if harmful.
|
| 9 |
+
output_node will keep it + add AIMessage if safe.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
"""
|
| 11 |
import json
|
| 12 |
from langchain_core.messages import SystemMessage, HumanMessage
|
|
|
|
| 15 |
|
| 16 |
SAFE_FALLBACK = "I'm sorry, I can't help with that request."
|
| 17 |
|
| 18 |
+
SYSTEM_PROMPT = """You are a content safety classifier.
|
| 19 |
Your only job is to analyze user queries and determine if they request harmful content.
|
| 20 |
Be precise β do not over-block. Only flag clearly harmful requests."""
|
| 21 |
|
|
|
|
| 24 |
HARMFUL categories (flag these):
|
| 25 |
- Violence: instructions to physically harm, kill, or assault people
|
| 26 |
- Unauthorized access: hacking, stealing credentials, bypassing security
|
| 27 |
+
- Weapons: making bombs, explosives, illegal weapons
|
| 28 |
- Illegal substances: synthesizing or manufacturing drugs
|
| 29 |
- Harassment: stalking, doxxing, non-consensual tracking
|
| 30 |
- Self-harm: methods to hurt oneself or suicide instructions
|
|
|
|
| 33 |
NOT harmful (do not flag these):
|
| 34 |
- Competitive contexts: "beat my friend at chess", "win an argument"
|
| 35 |
- Security education: conceptual explanations of how attacks work
|
| 36 |
+
- Fiction/creative writing with dark themes
|
|
|
|
| 37 |
- Medical: drug interactions, symptoms, treatments
|
| 38 |
- History/news: discussing past violent events
|
| 39 |
|
| 40 |
Query: "{query}"
|
| 41 |
|
| 42 |
+
JSON only: {{"harmful": true/false, "category": "violence|hacking|weapons|drugs|harassment|self_harm|hate|safe", "confidence": 0.0-1.0, "reason": "one sentence"}}"""
|
|
|
|
| 43 |
|
| 44 |
|
| 45 |
def safety_node(state: AgentState) -> AgentState:
|
| 46 |
+
query = state.get("query", "")
|
| 47 |
+
messages = list(state.get("messages", []))
|
| 48 |
+
log = state.get("node_log", [])
|
| 49 |
+
|
| 50 |
+
# ββ Add HumanMessage to history first ββββββββββββββββββββββββββββββββ
|
| 51 |
+
# output_node will scrub it if harmful, keep it if safe
|
| 52 |
+
messages = messages + [HumanMessage(content=query)]
|
| 53 |
+
|
| 54 |
+
# ββ IST timestamp βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 55 |
+
from datetime import datetime, timezone, timedelta
|
| 56 |
+
IST = timezone(timedelta(hours=5, minutes=30))
|
| 57 |
+
ts = datetime.now(IST).strftime("%d %b %Y %I:%M:%S %p IST")
|
| 58 |
+
print(f"[{ts}] [User Query] β {query}")
|
| 59 |
|
| 60 |
try:
|
| 61 |
response = llm.invoke([
|
| 62 |
SystemMessage(content=SYSTEM_PROMPT),
|
| 63 |
HumanMessage(content=CLASSIFICATION_PROMPT.format(query=query)),
|
| 64 |
])
|
|
|
|
| 65 |
raw = response.content.strip().removeprefix("```json").removesuffix("```").strip()
|
| 66 |
result = json.loads(raw)
|
| 67 |
|
|
|
|
| 70 |
category = result.get("category", "safe")
|
| 71 |
reason = result.get("reason", "")
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
if harmful and confidence >= 0.85:
|
| 74 |
print(f"[SAFETY] π« Blocked β {category} ({confidence:.0%}): {reason}")
|
| 75 |
return {
|
| 76 |
**state,
|
| 77 |
+
"messages": messages, # HumanMessage included β output_node will scrub
|
| 78 |
"is_harmful": True,
|
| 79 |
"guardrail_passed": False,
|
| 80 |
"response": SAFE_FALLBACK,
|
|
|
|
| 82 |
}
|
| 83 |
|
| 84 |
if harmful and confidence < 0.85:
|
| 85 |
+
print(f"[SAFETY] β οΈ Ambiguous ({category} {confidence:.0%}) β allowing: {reason}")
|
|
|
|
| 86 |
|
| 87 |
print(f"[SAFETY] β
Safe β {reason}")
|
| 88 |
+
return {
|
| 89 |
+
**state,
|
| 90 |
+
"messages": messages,
|
| 91 |
+
"is_harmful": False,
|
| 92 |
+
"node_log": log + ["safety β
"],
|
| 93 |
+
}
|
| 94 |
|
| 95 |
except Exception as e:
|
| 96 |
+
print(f"[SAFETY] Classifier error ({e}) β fail-open")
|
| 97 |
+
return {
|
| 98 |
+
**state,
|
| 99 |
+
"messages": messages,
|
| 100 |
+
"is_harmful": False,
|
| 101 |
+
"node_log": log + ["safety (errorβallowed)"],
|
| 102 |
+
}
|
| 103 |
|
| 104 |
|
| 105 |
def safety_route(state: AgentState) -> str:
|