Pawan Mane committed on
Commit
ceb563c
Β·
1 Parent(s): d80f659

Code optimization

Browse files
app/frontend/gradio_app.py CHANGED
@@ -14,25 +14,26 @@ from app.nodes.hitl import HITLPauseException
14
 
15
  _graph = build_graph()
16
  _thread_config = {"configurable": {"thread_id": "gradio-session-001"}}
17
- _conversation_history = []
 
 
18
  _pending_hitl_state: AgentState | None = None
19
 
20
 
21
  def run_graph(query: str) -> AgentState:
22
- global _conversation_history
23
- _conversation_history.append(HumanMessage(content=query))
24
  initial_state: AgentState = {
25
- "messages": _conversation_history.copy(), "query": query,
 
26
  "route": "", "rag_context": "", "tool_calls": [], "tool_results": [],
27
  "response": "", "retry_count": 0, "hitl_approved": False,
28
- "evaluation_score": 0.0, "guardrail_passed": True,
29
  "memory_summary": "", "node_log": [],
30
  }
31
  return _graph.invoke(initial_state, config=_thread_config)
32
 
33
 
34
  def resume_graph_after_hitl(state: AgentState, approved: bool) -> AgentState:
35
- global _conversation_history
36
  from app.nodes.evaluation import evaluation_node, eval_route
37
  from app.nodes.guardrails import guardrails_node
38
  from app.nodes.output import output_node
@@ -44,7 +45,6 @@ def resume_graph_after_hitl(state: AgentState, approved: bool) -> AgentState:
44
  s = llm_node(s)
45
  s = guardrails_node(s)
46
  s = output_node(s)
47
- _conversation_history = s["messages"]
48
  return s
49
 
50
 
@@ -81,13 +81,7 @@ def handle_submit(user_message, chat_history):
81
  score = fs.get("evaluation_score", 0.0)
82
  g_ok = fs.get("guardrail_passed", True)
83
 
84
- # Guardrail blocked β€” remove this exchange from history so it
85
- # doesn't poison the memory summary for future innocent queries
86
- if not g_ok:
87
- global _conversation_history
88
- if _conversation_history:
89
- _conversation_history.pop()
90
-
91
  chat_history = chat_history + [bot_msg(fs.get("response", ""))]
92
  meta = f"**Route:** {route.upper() or 'β€”'} Β· **Eval:** {score:.2f} Β· **Guardrail:** {'βœ… Passed' if g_ok else '🚫 Blocked'}"
93
  return (chat_history, "", format_trace(fs.get("node_log", [])),
@@ -130,8 +124,8 @@ def handle_reject(chat_history):
130
 
131
 
132
  def handle_clear():
133
- global _conversation_history, _pending_hitl_state
134
- _conversation_history, _pending_hitl_state = [], None
135
  return [], "", "*Waiting for a query...*", "", gr.update(visible=False)
136
 
137
 
 
14
 
15
  _graph = build_graph()
16
  _thread_config = {"configurable": {"thread_id": "gradio-session-001"}}
17
+ # Frontend holds NO conversation history.
18
+ # All message history is managed inside the graph via output_node.
19
+ # LangGraph MemorySaver persists state across invocations automatically.
20
  _pending_hitl_state: AgentState | None = None
21
 
22
 
23
  def run_graph(query: str) -> AgentState:
24
+ # Just pass the query β€” graph manages its own message history via state
 
25
  initial_state: AgentState = {
26
+ "messages": [], # MemorySaver restores history; safety_node adds HumanMessage
27
+ "query": query,
28
  "route": "", "rag_context": "", "tool_calls": [], "tool_results": [],
29
  "response": "", "retry_count": 0, "hitl_approved": False,
30
+ "evaluation_score": 0.0, "guardrail_passed": True, "is_harmful": False,
31
  "memory_summary": "", "node_log": [],
32
  }
33
  return _graph.invoke(initial_state, config=_thread_config)
34
 
35
 
36
  def resume_graph_after_hitl(state: AgentState, approved: bool) -> AgentState:
 
37
  from app.nodes.evaluation import evaluation_node, eval_route
38
  from app.nodes.guardrails import guardrails_node
39
  from app.nodes.output import output_node
 
45
  s = llm_node(s)
46
  s = guardrails_node(s)
47
  s = output_node(s)
 
48
  return s
49
 
50
 
 
81
  score = fs.get("evaluation_score", 0.0)
82
  g_ok = fs.get("guardrail_passed", True)
83
 
84
+ # History is managed entirely by output_node inside the graph
 
 
 
 
 
 
85
  chat_history = chat_history + [bot_msg(fs.get("response", ""))]
86
  meta = f"**Route:** {route.upper() or 'β€”'} Β· **Eval:** {score:.2f} Β· **Guardrail:** {'βœ… Passed' if g_ok else '🚫 Blocked'}"
87
  return (chat_history, "", format_trace(fs.get("node_log", [])),
 
124
 
125
 
126
  def handle_clear():
127
+ global _pending_hitl_state
128
+ _pending_hitl_state = None
129
  return [], "", "*Waiting for a query...*", "", gr.update(visible=False)
130
 
131
 
app/frontend/gradio_app_hf.py CHANGED
@@ -1,34 +1,29 @@
1
  """
2
  app/frontend/gradio_app_hf.py
3
  ──────────────────────────────
4
- HuggingFace Spaces entry point.
5
 
6
- Key differences from local gradio_app.py:
7
  - Reads all config from environment variables (HF injects secrets as env vars)
8
- - No .env file available on HF Spaces β€” dotenv is silenced gracefully
9
- - Runs on port 7860 (HF Spaces requirement)
10
- - PYTHONPATH=/app must be set in Dockerfile so `from app.*` imports resolve
 
 
 
11
  """
12
 
13
  import os
14
 
15
- # ── Set env flags before any app imports ──────────────────────────────────
16
- os.environ["GRADIO_MODE"] = "true"
17
- os.environ["PYTHONPATH"] = "/app"
18
-
19
- # HITL defaults to false on public spaces β€” override via HF Space Variables
20
- # All other secrets (GROQ_API_KEY, WEATHER_API_KEY, LLM_MODEL etc.)
21
- # are set in HuggingFace Space β†’ Settings β†’ Variables and Secrets
22
 
23
- # ── Silence dotenv β€” no .env file exists on HF Spaces ─────────────────────
24
- # app/config.py calls load_dotenv() which would print a warning if .env
25
- # is missing. We patch it to a no-op before config is imported.
26
  import sys
27
  from unittest.mock import MagicMock
28
  if "dotenv" not in sys.modules:
29
  sys.modules["dotenv"] = MagicMock()
30
 
31
- # ── Import the full app (config, graph, nodes all load here) ───────────────
32
  import gradio as gr
33
  from langchain_core.messages import HumanMessage
34
 
@@ -39,19 +34,18 @@ from app.frontend.css import CSS
39
 
40
 
41
  # ── Graph singleton ────────────────────────────────────────────────────────
42
- _graph = build_graph()
43
  _thread_config = {"configurable": {"thread_id": "hf-session-001"}}
44
- _conversation_history = []
45
  _pending_hitl_state: AgentState | None = None
46
 
47
 
48
  # ── Core runner ────────────────────────────────────────────────────────────
49
 
50
  def run_graph(query: str) -> AgentState:
51
- global _conversation_history
52
- _conversation_history.append(HumanMessage(content=query))
53
  initial_state: AgentState = {
54
- "messages": _conversation_history.copy(),
55
  "query": query,
56
  "route": "",
57
  "rag_context": "",
@@ -62,15 +56,14 @@ def run_graph(query: str) -> AgentState:
62
  "hitl_approved": False,
63
  "evaluation_score": 0.0,
64
  "guardrail_passed": True,
 
65
  "memory_summary": "",
66
  "node_log": [],
67
- "is_harmful": False,
68
  }
69
  return _graph.invoke(initial_state, config=_thread_config)
70
 
71
 
72
  def resume_graph_after_hitl(state: AgentState, approved: bool) -> AgentState:
73
- global _conversation_history
74
  from app.nodes.evaluation import evaluation_node, eval_route
75
  from app.nodes.guardrails import guardrails_node
76
  from app.nodes.output import output_node
@@ -82,7 +75,6 @@ def resume_graph_after_hitl(state: AgentState, approved: bool) -> AgentState:
82
  s = llm_node(s)
83
  s = guardrails_node(s)
84
  s = output_node(s)
85
- _conversation_history = s["messages"]
86
  return s
87
 
88
 
@@ -112,7 +104,7 @@ def bot_msg(t): return {"role": "assistant", "content": t}
112
  # ── Event handlers ─────────────────────────────────────────────────────────
113
 
114
  def handle_submit(user_message, chat_history):
115
- global _pending_hitl_state, _conversation_history
116
  if not user_message.strip():
117
  return chat_history, "", "*Waiting for a query...*", "", gr.update(visible=False), gr.update(value="")
118
 
@@ -123,10 +115,7 @@ def handle_submit(user_message, chat_history):
123
  score = fs.get("evaluation_score", 0.0)
124
  g_ok = fs.get("guardrail_passed", True)
125
 
126
- # Drop blocked exchange from history to prevent memory poisoning
127
- if not g_ok and _conversation_history:
128
- _conversation_history.pop()
129
-
130
  chat_history = chat_history + [bot_msg(fs.get("response", ""))]
131
  meta = f"**Route:** {route.upper() or 'β€”'} Β· **Eval:** {score:.2f} Β· **Guardrail:** {'βœ… Passed' if g_ok else '🚫 Blocked'}"
132
  return (chat_history, "", format_trace(fs.get("node_log", [])),
@@ -169,8 +158,8 @@ def handle_reject(chat_history):
169
 
170
 
171
  def handle_clear():
172
- global _conversation_history, _pending_hitl_state
173
- _conversation_history, _pending_hitl_state = [], None
174
  return [], "", "*Waiting for a query...*", "", gr.update(visible=False)
175
 
176
 
@@ -183,7 +172,6 @@ def build_ui():
183
 
184
  with gr.Row(equal_height=True):
185
 
186
- # ══ Main chat ═════════════════════════════════════════════
187
  with gr.Column(scale=4):
188
 
189
  with gr.Group(elem_classes="section-box"):
@@ -222,7 +210,6 @@ def build_ui():
222
  label="Examples",
223
  )
224
 
225
- # ══ Right sidebar ══════════════════════════════════════════
226
  with gr.Column(scale=1):
227
 
228
  with gr.Group(elem_classes="section-box"):
@@ -232,15 +219,17 @@ def build_ui():
232
  with gr.Group(elem_classes="section-box"):
233
  gr.Markdown("""**πŸ—Ί Graph Topology**
234
  ```
235
- START β†’ router
236
- β”œβ”€ rag β†’ llm
237
- └─ tool/general β†’ llm
238
- β”œβ”€ tool_executor
239
- └─ memory β†’ hitl
240
- β”œβ”€ evaluation
241
- β”‚ β”œβ”€ retry β†’ llm
242
- β”‚ └─ guardrails β†’ output
243
- └─ END
 
 
244
  ```""")
245
 
246
  submit_outs = [chatbot, user_input, trace_display, meta_display, hitl_panel, hitl_content]
 
1
  """
2
  app/frontend/gradio_app_hf.py
3
  ──────────────────────────────
4
+ HuggingFace Spaces entry point β€” fully synced with gradio_app.py.
5
 
6
+ Key differences from gradio_app.py:
7
  - Reads all config from environment variables (HF injects secrets as env vars)
8
+ - No .env file β€” dotenv silenced gracefully
9
+ - Port 7860 (HF Spaces requirement)
10
+ - PYTHONPATH=/app set in Dockerfile
11
+
12
+ History management: entirely inside the graph (output_node + MemorySaver).
13
+ Frontend is stateless β€” no _conversation_history here.
14
  """
15
 
16
  import os
17
 
18
+ os.environ["GRADIO_MODE"] = "true"
19
+ os.environ["PYTHONPATH"] = "/app"
 
 
 
 
 
20
 
21
+ # Silence dotenv β€” no .env on HF Spaces
 
 
22
  import sys
23
  from unittest.mock import MagicMock
24
  if "dotenv" not in sys.modules:
25
  sys.modules["dotenv"] = MagicMock()
26
 
 
27
  import gradio as gr
28
  from langchain_core.messages import HumanMessage
29
 
 
34
 
35
 
36
  # ── Graph singleton ────────────────────────────────────────────────────────
37
+ _graph = build_graph()
38
  _thread_config = {"configurable": {"thread_id": "hf-session-001"}}
39
+ # No _conversation_history β€” graph manages all history via output_node + MemorySaver
40
  _pending_hitl_state: AgentState | None = None
41
 
42
 
43
  # ── Core runner ────────────────────────────────────────────────────────────
44
 
45
  def run_graph(query: str) -> AgentState:
46
+ # messages=[] β€” MemorySaver restores prior history; safety_node adds HumanMessage
 
47
  initial_state: AgentState = {
48
+ "messages": [],
49
  "query": query,
50
  "route": "",
51
  "rag_context": "",
 
56
  "hitl_approved": False,
57
  "evaluation_score": 0.0,
58
  "guardrail_passed": True,
59
+ "is_harmful": False,
60
  "memory_summary": "",
61
  "node_log": [],
 
62
  }
63
  return _graph.invoke(initial_state, config=_thread_config)
64
 
65
 
66
  def resume_graph_after_hitl(state: AgentState, approved: bool) -> AgentState:
 
67
  from app.nodes.evaluation import evaluation_node, eval_route
68
  from app.nodes.guardrails import guardrails_node
69
  from app.nodes.output import output_node
 
75
  s = llm_node(s)
76
  s = guardrails_node(s)
77
  s = output_node(s)
 
78
  return s
79
 
80
 
 
104
  # ── Event handlers ─────────────────────────────────────────────────────────
105
 
106
  def handle_submit(user_message, chat_history):
107
+ global _pending_hitl_state
108
  if not user_message.strip():
109
  return chat_history, "", "*Waiting for a query...*", "", gr.update(visible=False), gr.update(value="")
110
 
 
115
  score = fs.get("evaluation_score", 0.0)
116
  g_ok = fs.get("guardrail_passed", True)
117
 
118
+ # History managed entirely by output_node inside the graph
 
 
 
119
  chat_history = chat_history + [bot_msg(fs.get("response", ""))]
120
  meta = f"**Route:** {route.upper() or 'β€”'} Β· **Eval:** {score:.2f} Β· **Guardrail:** {'βœ… Passed' if g_ok else '🚫 Blocked'}"
121
  return (chat_history, "", format_trace(fs.get("node_log", [])),
 
158
 
159
 
160
  def handle_clear():
161
+ global _pending_hitl_state
162
+ _pending_hitl_state = None
163
  return [], "", "*Waiting for a query...*", "", gr.update(visible=False)
164
 
165
 
 
172
 
173
  with gr.Row(equal_height=True):
174
 
 
175
  with gr.Column(scale=4):
176
 
177
  with gr.Group(elem_classes="section-box"):
 
210
  label="Examples",
211
  )
212
 
 
213
  with gr.Column(scale=1):
214
 
215
  with gr.Group(elem_classes="section-box"):
 
219
  with gr.Group(elem_classes="section-box"):
220
  gr.Markdown("""**πŸ—Ί Graph Topology**
221
  ```
222
+ START β†’ safety
223
+ β”œβ”€ blocked β†’ output β†’ END
224
+ └─ continue β†’ router
225
+ β”œβ”€ rag β†’ llm
226
+ └─ tool/general β†’ llm
227
+ β”œβ”€ tool_executor
228
+ └─ memory β†’ hitl
229
+ β”œβ”€ evaluation
230
+ β”‚ β”œβ”€ retry β†’ llm
231
+ β”‚ └─ guardrails β†’ output
232
+ └─ END
233
  ```""")
234
 
235
  submit_outs = [chatbot, user_input, trace_display, meta_display, hitl_panel, hitl_content]
app/nodes/llm_node.py CHANGED
@@ -36,25 +36,33 @@ def llm_node(state: AgentState) -> AgentState:
36
  try:
37
  # Build system prompt
38
  system_parts = [
39
- "You are a helpful AI assistant. Answer ONLY the current query concisely.",
40
- "If a query is vague or has multiple possible meanings, ask ONE short clarifying question instead of assuming.",
41
- "Do not reference or answer previous questions.",
42
  ]
43
  if state.get("rag_context"):
44
  system_parts.append(f"\nUse the following context to answer:\n{state['rag_context']}")
45
  if state.get("memory_summary"):
46
- system_parts.append(f"\nFor background context only (do NOT repeat or expand on this):\n{state['memory_summary']}")
47
 
48
  system_msg = SystemMessage(content="\n".join(system_parts))
49
 
50
- # Always send only the current query β€” never full history
51
- # Full history causes the LLM to blend previous topics
52
- current_msg = HumanMessage(content=state["query"])
 
 
 
 
 
 
53
 
54
  if state["route"] == "tool":
55
- ai_msg = _llm_with_tools.invoke([system_msg, current_msg])
 
56
  else:
57
- ai_msg = llm.invoke([system_msg, current_msg])
 
58
 
59
  tool_calls = getattr(ai_msg, "tool_calls", []) or []
60
  response_text = ai_msg.content or ""
 
36
  try:
37
  # Build system prompt
38
  system_parts = [
39
+ "You are a helpful AI assistant.",
40
+ "Answer the current query using the conversation history for context.",
41
+ "Keep responses concise and relevant.",
42
  ]
43
  if state.get("rag_context"):
44
  system_parts.append(f"\nUse the following context to answer:\n{state['rag_context']}")
45
  if state.get("memory_summary"):
46
+ system_parts.append(f"\nConversation summary so far:\n{state['memory_summary']}")
47
 
48
  system_msg = SystemMessage(content="\n".join(system_parts))
49
 
50
+ # state["messages"] = prior safe history (from MemorySaver) + current HumanMessage
51
+ # Scrub tool noise, then build: [system, h1, a1, h2, a2, ..., current_query]
52
+ from langchain_core.messages import ToolMessage, AIMessage as AI
53
+ clean = [
54
+ m for m in state["messages"]
55
+ if not isinstance(m, ToolMessage)
56
+ and not (isinstance(m, AI) and getattr(m, "tool_calls", []))
57
+ ]
58
+ messages = [system_msg] + clean
59
 
60
  if state["route"] == "tool":
61
+ # Tool route: only current query to avoid re-firing old tool calls
62
+ ai_msg = _llm_with_tools.invoke([system_msg, HumanMessage(content=state["query"])])
63
  else:
64
+ # RAG / general: full clean history for context
65
+ ai_msg = llm.invoke(messages)
66
 
67
  tool_calls = getattr(ai_msg, "tool_calls", []) or []
68
  response_text = ai_msg.content or ""
app/nodes/output.py CHANGED
@@ -1,11 +1,38 @@
1
- """app/nodes/output.py β€” Final output node"""
2
- from langchain_core.messages import AIMessage
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from app.state import AgentState
4
 
5
 
6
  def output_node(state: AgentState) -> AgentState:
7
- ai_message = AIMessage(content=state["response"])
8
- updated_messages = state["messages"] + [ai_message]
9
- log = state.get("node_log", []) + ["output"]
10
- print(f"\nπŸ€– {state['response']}\n")
11
- return {**state, "messages": updated_messages, "node_log": log}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ app/nodes/output.py
3
+ ────────────────────
4
+ Single source of truth for message history.
5
+
6
+ Flow per turn:
7
+ gradio sends: messages=[] (empty β€” MemorySaver restores checkpoint history)
8
+ safety adds: HumanMessage(query) to messages
9
+ output_node:
10
+ - harmful/blocked β†’ drop the HumanMessage, keep prior history clean
11
+ - safe β†’ keep HumanMessage + append AIMessage(response)
12
+
13
+ MemorySaver then persists the updated messages for next turn.
14
+ """
15
+ from langchain_core.messages import AIMessage, HumanMessage
16
  from app.state import AgentState
17
 
18
 
19
  def output_node(state: AgentState) -> AgentState:
20
+ log = state.get("node_log", []) + ["output"]
21
+ response = state["response"]
22
+ messages = list(state["messages"])
23
+ is_harmful = state.get("is_harmful", False)
24
+ guardrail_ok = state.get("guardrail_passed", True)
25
+
26
+ if is_harmful or not guardrail_ok:
27
+ # Drop the HumanMessage for this turn β€” never pollute history
28
+ messages = [m for m in messages
29
+ if not (isinstance(m, HumanMessage) and m.content == state["query"])]
30
+ print(f"\nπŸ€– {response}\n")
31
+ print("[OUTPUT] Harmful turn scrubbed from history.")
32
+ else:
33
+ # Safe β€” HumanMessage already in messages (added by safety_node)
34
+ # Just append the assistant response
35
+ messages = messages + [AIMessage(content=response)]
36
+ print(f"\nπŸ€– {response}\n")
37
+
38
+ return {**state, "messages": messages, "node_log": log}
app/nodes/safety.py CHANGED
@@ -1,20 +1,12 @@
1
  """
2
  app/nodes/safety.py
3
  ────────────────────
4
- Production safety pattern: Constitutional AI approach.
 
 
5
 
6
- The LLM that generates responses is also the best judge of whether
7
- a query is harmful β€” it has full language understanding and context.
8
-
9
- Pattern used by: Anthropic, Cohere, Mistral AI
10
- β†’ System prompt + structured classification prompt
11
- β†’ Confidence threshold to avoid false positives
12
- β†’ Fail-open on errors (don't block legitimate users)
13
-
14
- Two checks:
15
- 1. Query intent β€” is the user asking for something harmful?
16
- 2. Response check β€” did the LLM accidentally produce harmful output?
17
- (guardrails.py handles this after response is generated)
18
  """
19
  import json
20
  from langchain_core.messages import SystemMessage, HumanMessage
@@ -23,7 +15,7 @@ from app.utils.llm import llm
23
 
24
  SAFE_FALLBACK = "I'm sorry, I can't help with that request."
25
 
26
- SYSTEM_PROMPT = """You are a content safety classifier.
27
  Your only job is to analyze user queries and determine if they request harmful content.
28
  Be precise β€” do not over-block. Only flag clearly harmful requests."""
29
 
@@ -32,7 +24,7 @@ CLASSIFICATION_PROMPT = """Analyze this query and classify it.
32
  HARMFUL categories (flag these):
33
  - Violence: instructions to physically harm, kill, or assault people
34
  - Unauthorized access: hacking, stealing credentials, bypassing security
35
- - Weapons: making bombs, explosives, illegal weapons
36
  - Illegal substances: synthesizing or manufacturing drugs
37
  - Harassment: stalking, doxxing, non-consensual tracking
38
  - Self-harm: methods to hurt oneself or suicide instructions
@@ -41,27 +33,35 @@ HARMFUL categories (flag these):
41
  NOT harmful (do not flag these):
42
  - Competitive contexts: "beat my friend at chess", "win an argument"
43
  - Security education: conceptual explanations of how attacks work
44
- - Fiction: stories, games, hypotheticals with dark themes
45
- - Legitimate anger: "I'm so frustrated I could scream"
46
  - Medical: drug interactions, symptoms, treatments
47
  - History/news: discussing past violent events
48
 
49
  Query: "{query}"
50
 
51
- Respond ONLY in this exact JSON format:
52
- {{"harmful": true/false, "category": "violence|hacking|weapons|drugs|harassment|self_harm|hate|safe", "confidence": 0.0-1.0, "reason": "one sentence"}}"""
53
 
54
 
55
  def safety_node(state: AgentState) -> AgentState:
56
- query = state.get("query", "")
57
- log = state.get("node_log", [])
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  try:
60
  response = llm.invoke([
61
  SystemMessage(content=SYSTEM_PROMPT),
62
  HumanMessage(content=CLASSIFICATION_PROMPT.format(query=query)),
63
  ])
64
-
65
  raw = response.content.strip().removeprefix("```json").removesuffix("```").strip()
66
  result = json.loads(raw)
67
 
@@ -70,18 +70,11 @@ def safety_node(state: AgentState) -> AgentState:
70
  category = result.get("category", "safe")
71
  reason = result.get("reason", "")
72
 
73
- # IST timestamp for every query
74
- from datetime import datetime, timezone, timedelta
75
- IST = timezone(timedelta(hours=5, minutes=30))
76
- ts = datetime.now(IST).strftime("%d %b %Y %I:%M:%S %p IST")
77
- print(f"[{ts}] [User Query] β€” {query}")
78
-
79
- # Require high confidence to avoid false positives on edge cases
80
- # e.g. "how to beat someone at chess" should NOT be blocked
81
  if harmful and confidence >= 0.85:
82
  print(f"[SAFETY] 🚫 Blocked β€” {category} ({confidence:.0%}): {reason}")
83
  return {
84
  **state,
 
85
  "is_harmful": True,
86
  "guardrail_passed": False,
87
  "response": SAFE_FALLBACK,
@@ -89,17 +82,24 @@ def safety_node(state: AgentState) -> AgentState:
89
  }
90
 
91
  if harmful and confidence < 0.85:
92
- # Ambiguous β€” let it through, LLM + guardrails will handle
93
- print(f"[SAFETY] ⚠️ Ambiguous ({category} {confidence:.0%}) β€” allowing through: {reason}")
94
 
95
  print(f"[SAFETY] βœ… Safe β€” {reason}")
96
- return {**state, "is_harmful": False, "node_log": log + ["safety βœ…"]}
 
 
 
 
 
97
 
98
  except Exception as e:
99
- # Fail-open: if classifier breaks, don't block legitimate users
100
- # Guardrails backstop will still catch harmful responses
101
- print(f"[SAFETY] Classifier error ({e}) β€” fail-open, passing through")
102
- return {**state, "is_harmful": False, "node_log": log + ["safety (error→allowed)"]}
 
 
 
103
 
104
 
105
  def safety_route(state: AgentState) -> str:
 
1
  """
2
  app/nodes/safety.py
3
  ────────────────────
4
+ FIRST node in graph. Does two things:
5
+ 1. Adds HumanMessage(query) to messages β€” so history is always updated here
6
+ 2. Classifies query intent β€” harmful queries short-circuit to output_node
7
 
8
+ output_node will scrub the HumanMessage if harmful.
9
+ output_node will keep it + add AIMessage if safe.
 
 
 
 
 
 
 
 
 
 
10
  """
11
  import json
12
  from langchain_core.messages import SystemMessage, HumanMessage
 
15
 
16
  SAFE_FALLBACK = "I'm sorry, I can't help with that request."
17
 
18
+ SYSTEM_PROMPT = """You are a content safety classifier.
19
  Your only job is to analyze user queries and determine if they request harmful content.
20
  Be precise β€” do not over-block. Only flag clearly harmful requests."""
21
 
 
24
  HARMFUL categories (flag these):
25
  - Violence: instructions to physically harm, kill, or assault people
26
  - Unauthorized access: hacking, stealing credentials, bypassing security
27
+ - Weapons: making bombs, explosives, illegal weapons
28
  - Illegal substances: synthesizing or manufacturing drugs
29
  - Harassment: stalking, doxxing, non-consensual tracking
30
  - Self-harm: methods to hurt oneself or suicide instructions
 
33
  NOT harmful (do not flag these):
34
  - Competitive contexts: "beat my friend at chess", "win an argument"
35
  - Security education: conceptual explanations of how attacks work
36
+ - Fiction/creative writing with dark themes
 
37
  - Medical: drug interactions, symptoms, treatments
38
  - History/news: discussing past violent events
39
 
40
  Query: "{query}"
41
 
42
+ JSON only: {{"harmful": true/false, "category": "violence|hacking|weapons|drugs|harassment|self_harm|hate|safe", "confidence": 0.0-1.0, "reason": "one sentence"}}"""
 
43
 
44
 
45
  def safety_node(state: AgentState) -> AgentState:
46
+ query = state.get("query", "")
47
+ messages = list(state.get("messages", []))
48
+ log = state.get("node_log", [])
49
+
50
+ # ── Add HumanMessage to history first ────────────────────────────────
51
+ # output_node will scrub it if harmful, keep it if safe
52
+ messages = messages + [HumanMessage(content=query)]
53
+
54
+ # ── IST timestamp ─────────────────────────────────────────────────────
55
+ from datetime import datetime, timezone, timedelta
56
+ IST = timezone(timedelta(hours=5, minutes=30))
57
+ ts = datetime.now(IST).strftime("%d %b %Y %I:%M:%S %p IST")
58
+ print(f"[{ts}] [User Query] β€” {query}")
59
 
60
  try:
61
  response = llm.invoke([
62
  SystemMessage(content=SYSTEM_PROMPT),
63
  HumanMessage(content=CLASSIFICATION_PROMPT.format(query=query)),
64
  ])
 
65
  raw = response.content.strip().removeprefix("```json").removesuffix("```").strip()
66
  result = json.loads(raw)
67
 
 
70
  category = result.get("category", "safe")
71
  reason = result.get("reason", "")
72
 
 
 
 
 
 
 
 
 
73
  if harmful and confidence >= 0.85:
74
  print(f"[SAFETY] 🚫 Blocked β€” {category} ({confidence:.0%}): {reason}")
75
  return {
76
  **state,
77
+ "messages": messages, # HumanMessage included β€” output_node will scrub
78
  "is_harmful": True,
79
  "guardrail_passed": False,
80
  "response": SAFE_FALLBACK,
 
82
  }
83
 
84
  if harmful and confidence < 0.85:
85
+ print(f"[SAFETY] ⚠️ Ambiguous ({category} {confidence:.0%}) β€” allowing: {reason}")
 
86
 
87
  print(f"[SAFETY] βœ… Safe β€” {reason}")
88
+ return {
89
+ **state,
90
+ "messages": messages,
91
+ "is_harmful": False,
92
+ "node_log": log + ["safety βœ…"],
93
+ }
94
 
95
  except Exception as e:
96
+ print(f"[SAFETY] Classifier error ({e}) β€” fail-open")
97
+ return {
98
+ **state,
99
+ "messages": messages,
100
+ "is_harmful": False,
101
+ "node_log": log + ["safety (error→allowed)"],
102
+ }
103
 
104
 
105
  def safety_route(state: AgentState) -> str: