Pawan Mane commited on
Commit
1d6b948
Β·
1 Parent(s): 18deb87

Memory poisoning error fixing

Browse files
app/frontend/gradio_app_hf.py CHANGED
@@ -64,6 +64,7 @@ def run_graph(query: str) -> AgentState:
64
  "guardrail_passed": True,
65
  "memory_summary": "",
66
  "node_log": [],
 
67
  }
68
  return _graph.invoke(initial_state, config=_thread_config)
69
 
 
64
  "guardrail_passed": True,
65
  "memory_summary": "",
66
  "node_log": [],
67
+ "is_harmful": False,
68
  }
69
  return _graph.invoke(initial_state, config=_thread_config)
70
 
app/graph/builder.py CHANGED
@@ -1,39 +1,41 @@
1
  """
2
  app/graph/builder.py
3
  ─────────────────────
4
- Assembles the LangGraph StateGraph from all nodes and edges.
5
- This is the only file that knows about graph topology.
6
-
7
  Graph topology:
8
- β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
9
- β”Œβ”€β”€β”€β”€β–Ίβ”‚ rag │────┐
10
- β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚
11
- [START] ─► router β–Ό
12
- β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
13
- └────►│ llm (tool / general) β”‚
14
- β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
15
- β”‚ β”‚
16
- tool_calls? none
17
- β”‚ β”‚
18
- tool_executor β”‚
19
- β”‚ β”‚
20
- β–Ό β–Ό
21
- memory β—„β”€β”€β”€β”€β”˜
22
- β”‚
23
- hitl ──(rejected)──► END
24
- β”‚
25
- evaluation ──(retry)──► llm
26
- β”‚
27
- guardrails
28
- β”‚
29
- output
30
- β”‚
31
- END
 
 
 
 
32
  """
33
 
34
  from langgraph.graph import StateGraph, END
35
  from langgraph.checkpoint.memory import MemorySaver
36
  from app.state import AgentState
 
37
  from app.nodes import (
38
  router_node, route_selector,
39
  rag_node,
@@ -52,6 +54,7 @@ def build_graph():
52
  builder = StateGraph(AgentState)
53
 
54
  # ── Register nodes ────────────────────────────────────────────────────
 
55
  builder.add_node("router", router_node)
56
  builder.add_node("rag", rag_node)
57
  builder.add_node("llm", llm_node)
@@ -62,10 +65,17 @@ def build_graph():
62
  builder.add_node("guardrails", guardrails_node)
63
  builder.add_node("output", output_node)
64
 
65
- # ── Entry point ───────────────────────────────────────────────────────
66
- builder.set_entry_point("router")
 
 
 
 
 
 
 
67
 
68
- # ── Conditional routing (CHECKPOINT 3) ────────────────────────────────
69
  builder.add_conditional_edges(
70
  "router",
71
  route_selector,
 
1
  """
2
  app/graph/builder.py
3
  ─────────────────────
 
 
 
4
  Graph topology:
5
+
6
+ [START] ─► safety ──(blocked)──► output ──► END
7
+ β”‚
8
+ (continue)
9
+ β”‚
10
+ router
11
+ / | \
12
+ rag | tool/general
13
+ \ | /
14
+ β–Ό β–Ό β–Ό
15
+ llm
16
+ / \
17
+ tool_calls? none
18
+ | |
19
+ tool_executor |
20
+ | |
21
+ β–Ό β–Ό
22
+ memory
23
+ |
24
+ hitl ──(rejected)──► END
25
+ |
26
+ evaluation ──(retry)──► llm
27
+ |
28
+ guardrails
29
+ |
30
+ output
31
+ |
32
+ END
33
  """
34
 
35
  from langgraph.graph import StateGraph, END
36
  from langgraph.checkpoint.memory import MemorySaver
37
  from app.state import AgentState
38
+ from app.nodes.safety import safety_node, safety_route
39
  from app.nodes import (
40
  router_node, route_selector,
41
  rag_node,
 
54
  builder = StateGraph(AgentState)
55
 
56
  # ── Register nodes ────────────────────────────────────────────────────
57
+ builder.add_node("safety", safety_node)
58
  builder.add_node("router", router_node)
59
  builder.add_node("rag", rag_node)
60
  builder.add_node("llm", llm_node)
 
65
  builder.add_node("guardrails", guardrails_node)
66
  builder.add_node("output", output_node)
67
 
68
+ # ── Entry: safety first ───────────────────────────────────────────────
69
+ builder.set_entry_point("safety")
70
+
71
+ # Safety gate β€” blocked queries skip everything and go straight to output
72
+ builder.add_conditional_edges(
73
+ "safety",
74
+ safety_route,
75
+ {"blocked": "output", "continue": "router"},
76
+ )
77
 
78
+ # ── Routing ───────────────────────────────────────────────────────────
79
  builder.add_conditional_edges(
80
  "router",
81
  route_selector,
app/nodes/evaluation.py CHANGED
@@ -1,65 +1,40 @@
1
  """
2
  app/nodes/evaluation.py β€” CHECKPOINT 7: Evaluation
3
 
4
- Auto-pass cases (no LLM scoring needed):
5
- 1. Tool responses β€” always valid
6
- 2. LLM refusals β€” intentional, let guardrails handle
7
- 3. General/casual route β€” jokes, greetings, opinions don't need quality scoring
8
- (evaluator would unfairly penalise short creative responses)
 
 
9
  """
10
  from langchain_core.messages import HumanMessage
11
  from app.state import AgentState
12
  from app.utils.llm import llm
13
  from app.config import settings
14
 
15
- # Phrases that indicate the LLM intentionally refused β€” don't retry these
16
- REFUSAL_PHRASES = [
17
- "sensitive", "harmful", "hate", "threat", "negative", "i can't help with that."
18
- "i cannot provide information on",
19
- "i can't help", "i cannot help", "i'm unable", "i am unable",
20
- "i won't", "i will not", "not able to assist", "can't assist",
21
- "cannot assist", "i'm sorry, i can't", "i'm not able",
22
- "i must decline", "i'd rather not", "i don't think i should",
23
- "i cannot provide", "i can't provide", "i'm not comfortable",
24
- "that's not something i", "not something i can",
25
- "i'm an ai", "as an ai",
26
- ]
27
-
28
-
29
- def _is_refusal(response: str) -> bool:
30
- low = response.lower()
31
- return any(phrase in low for phrase in REFUSAL_PHRASES)
32
-
33
 
34
  def evaluation_node(state: AgentState) -> AgentState:
35
  log = state.get("node_log", [])
36
  response = state.get("response", "")
37
  route = state.get("route", "")
38
 
39
- # ── 1. Tool responses β€” always valid ──────────────────────────────────
40
  if route == "tool" or state.get("tool_results"):
41
  print("[EVAL] Tool response β€” auto-passed.")
42
  return {**state, "evaluation_score": 1.0,
43
  "node_log": log + ["evaluation (tool auto-pass βœ…)"]}
44
 
45
- # ── 2. General/casual β€” auto-pass ─────────────────────────────────────
46
- # Jokes, greetings, opinions are subjective β€” LLM scorer would unfairly
47
- # give 0.2 to a perfectly good joke. Let guardrails handle safety.
48
  if route == "general":
49
  print("[EVAL] General route β€” auto-passed.")
50
  return {**state, "evaluation_score": 1.0,
51
  "node_log": log + ["evaluation (general auto-pass βœ…)"]}
52
 
53
- # ── 3. LLM refusal β€” skip retries, forward to guardrails ──────────────
54
- if _is_refusal(response):
55
- print(f"[EVAL] Refusal detected β†’ skipping retries.")
56
- return {**state, "evaluation_score": 1.0,
57
- "node_log": log + ["evaluation (refusal auto-pass βœ…)"]}
58
-
59
- # ── 4. RAG responses β€” score with LLM ─────────────────────────────────
60
- eval_prompt = f"""Rate the following AI response on a scale of 0.0 to 1.0
61
- for relevance and quality relative to the query.
62
- Return ONLY a float number between 0.0 and 1.0 β€” no other text.
63
 
64
  Query: {state['query']}
65
  Response: {response}
@@ -70,29 +45,27 @@ Score:"""
70
  raw = llm.invoke([HumanMessage(content=eval_prompt)]).content.strip()
71
  score = max(0.0, min(1.0, float(raw)))
72
  except Exception:
73
- score = 0.8
74
 
75
  current_retries = state.get("retry_count", 0)
76
- below_threshold = score < settings.EVAL_THRESHOLD
77
- new_retry_count = (current_retries + 1) if below_threshold else current_retries
78
 
79
  print(f"[EVAL] Score: {score:.2f} (threshold: {settings.EVAL_THRESHOLD}, retries: {current_retries})")
80
  return {
81
  **state,
82
  "evaluation_score": score,
83
- "retry_count": new_retry_count,
84
- "node_log": log + [f"evaluation (score={score:.2f}, retry={new_retry_count})"],
85
  }
86
 
87
 
88
  def eval_route(state: AgentState) -> str:
89
  score = state["evaluation_score"]
90
  retry_count = state.get("retry_count", 0)
91
-
92
  if score < settings.EVAL_THRESHOLD and retry_count <= settings.MAX_RETRIES:
93
  print(f"[EVAL] Score {score:.2f} below threshold β€” retry {retry_count}/{settings.MAX_RETRIES}")
94
  return "retry"
95
-
96
  if score < settings.EVAL_THRESHOLD:
97
- print(f"[EVAL] Max retries ({settings.MAX_RETRIES}) reached β€” proceeding anyway.")
98
  return "guardrails"
 
1
  """
2
  app/nodes/evaluation.py β€” CHECKPOINT 7: Evaluation
3
 
4
+ Clean, simple quality scoring with no hardcoded safety logic.
5
+ Safety is handled entirely by safety_node (before the graph runs)
6
+ and guardrails_node (after the response is generated).
7
+
8
+ Auto-pass cases:
9
+ - tool route: tool outputs are always valid, not prose
10
+ - general route: subjective responses (jokes, opinions) can't be scored fairly
11
  """
12
  from langchain_core.messages import HumanMessage
13
  from app.state import AgentState
14
  from app.utils.llm import llm
15
  from app.config import settings
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  def evaluation_node(state: AgentState) -> AgentState:
19
  log = state.get("node_log", [])
20
  response = state.get("response", "")
21
  route = state.get("route", "")
22
 
23
+ # Tool outputs are structured data, not prose β€” always pass
24
  if route == "tool" or state.get("tool_results"):
25
  print("[EVAL] Tool response β€” auto-passed.")
26
  return {**state, "evaluation_score": 1.0,
27
  "node_log": log + ["evaluation (tool auto-pass βœ…)"]}
28
 
29
+ # General/casual β€” subjective, not scoreable fairly
 
 
30
  if route == "general":
31
  print("[EVAL] General route β€” auto-passed.")
32
  return {**state, "evaluation_score": 1.0,
33
  "node_log": log + ["evaluation (general auto-pass βœ…)"]}
34
 
35
+ # RAG responses β€” score relevance and quality
36
+ eval_prompt = f"""Rate this AI response for relevance and quality on a scale of 0.0 to 1.0.
37
+ Return ONLY a float between 0.0 and 1.0 β€” nothing else.
 
 
 
 
 
 
 
38
 
39
  Query: {state['query']}
40
  Response: {response}
 
45
  raw = llm.invoke([HumanMessage(content=eval_prompt)]).content.strip()
46
  score = max(0.0, min(1.0, float(raw)))
47
  except Exception:
48
+ score = 0.8 # safe fallback
49
 
50
  current_retries = state.get("retry_count", 0)
51
+ below = score < settings.EVAL_THRESHOLD
52
+ new_retries = (current_retries + 1) if below else current_retries
53
 
54
  print(f"[EVAL] Score: {score:.2f} (threshold: {settings.EVAL_THRESHOLD}, retries: {current_retries})")
55
  return {
56
  **state,
57
  "evaluation_score": score,
58
+ "retry_count": new_retries,
59
+ "node_log": log + [f"evaluation (score={score:.2f}, retry={new_retries})"],
60
  }
61
 
62
 
63
  def eval_route(state: AgentState) -> str:
64
  score = state["evaluation_score"]
65
  retry_count = state.get("retry_count", 0)
 
66
  if score < settings.EVAL_THRESHOLD and retry_count <= settings.MAX_RETRIES:
67
  print(f"[EVAL] Score {score:.2f} below threshold β€” retry {retry_count}/{settings.MAX_RETRIES}")
68
  return "retry"
 
69
  if score < settings.EVAL_THRESHOLD:
70
+ print(f"[EVAL] Max retries reached β€” proceeding.")
71
  return "guardrails"
app/nodes/guardrails.py CHANGED
@@ -1,18 +1,44 @@
1
- """app/nodes/guardrails.py β€” CHECKPOINT 8: Guardrails"""
 
 
 
 
 
 
 
2
  from app.state import AgentState
3
- from app.config import settings
4
 
5
  SAFE_FALLBACK = "I'm sorry, I can't help with that request."
6
 
 
 
 
 
 
7
 
8
  def guardrails_node(state: AgentState) -> AgentState:
9
- response_lower = state.get("response", "").lower()
10
- triggered = [p for p in settings.BLOCKED_PHRASES if p in response_lower]
11
- log = state.get("node_log", [])
12
- if triggered:
13
- print(f"[GUARDRAILS] ⚠️ Blocked β€” matched phrases: {triggered}")
14
- log = log + [f"guardrails (BLOCKED: {triggered})"]
15
- return {**state, "guardrail_passed": False, "response": SAFE_FALLBACK, "node_log": log}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  print("[GUARDRAILS] βœ… Passed.")
17
- log = log + ["guardrails βœ…"]
18
- return {**state, "guardrail_passed": True, "node_log": log}
 
1
+ """
2
+ app/nodes/guardrails.py β€” CHECKPOINT 8: Guardrails
3
+
4
+ Simplified β€” safety_node already handled query-level intent.
5
+ This node only checks the final response for accidental harmful content
6
+ using the LLM (not keywords) as a final backstop.
7
+ """
8
+ from langchain_core.messages import HumanMessage
9
  from app.state import AgentState
10
+ from app.utils.llm import llm
11
 
12
  SAFE_FALLBACK = "I'm sorry, I can't help with that request."
13
 
14
+ GUARDRAIL_PROMPT = """Does the following AI response contain harmful, dangerous, or illegal instructions?
15
+ Answer ONLY with "yes" or "no".
16
+
17
+ Response: """
18
+
19
 
20
def guardrails_node(state: "AgentState") -> "AgentState":
    """Final response-level safety backstop (CHECKPOINT 8).

    safety_node already screened the *query*; this node asks the LLM whether
    the generated *response* itself contains harmful content. On a "yes"
    verdict the response is replaced with SAFE_FALLBACK. LLM failures fail
    open (the response passes) so the graph always reaches the output node.
    """
    response = state.get("response", "")
    log = state.get("node_log", [])

    # Already blocked upstream by safety_node — nothing to re-check, but still
    # record the visit so the node trace shown in the UI stays complete.
    # (Previously this path returned state unmodified, so "guardrails" never
    # appeared in node_log for blocked queries.)
    if not state.get("guardrail_passed", True) or state.get("is_harmful"):
        return {**state, "node_log": log + ["guardrails (pre-blocked passthrough)"]}

    try:
        check = llm.invoke([HumanMessage(content=GUARDRAIL_PROMPT + response)])
        verdict = check.content.strip().lower()

        if verdict.startswith("yes"):
            print("[GUARDRAILS] ⚠️ Response contains harmful content — blocked.")
            return {
                **state,
                "guardrail_passed": False,
                "response": SAFE_FALLBACK,
                "node_log": log + ["guardrails (RESPONSE BLOCKED ⚠️)"],
            }
    except Exception as e:
        # Fail open: a broken safety check must not take down the whole graph.
        print(f"[GUARDRAILS] Check failed ({e}) — passing through.")

    print("[GUARDRAILS] ✅ Passed.")
    return {**state, "guardrail_passed": True, "node_log": log + ["guardrails ✅"]}
 
app/nodes/llm_node.py CHANGED
@@ -43,6 +43,9 @@ def llm_node(state: AgentState) -> AgentState:
43
 
44
  system_msg = SystemMessage(content="\n".join(system_parts))
45
 
 
 
 
46
  # Always send only the current query β€” never full history
47
  # Full history causes the LLM to blend previous topics
48
  current_msg = HumanMessage(content=state["query"])
@@ -55,6 +58,10 @@ def llm_node(state: AgentState) -> AgentState:
55
  tool_calls = getattr(ai_msg, "tool_calls", []) or []
56
  response_text = ai_msg.content or ""
57
 
 
 
 
 
58
  print(f"[LLM] Attempt {attempt} succeeded. Tool calls: {len(tool_calls)}")
59
  print(f"[LLM] Generated Output for User Query ({state['query'][:60]}) : {response_text[:200]}")
60
  log = state.get("node_log", []) + [f"llm (attempt={attempt}, route={state['route']})"]
 
43
 
44
  system_msg = SystemMessage(content="\n".join(system_parts))
45
 
46
+ print("\n\n\n")
47
+ print(system_msg)
48
+ print("\n\n\n")
49
  # Always send only the current query β€” never full history
50
  # Full history causes the LLM to blend previous topics
51
  current_msg = HumanMessage(content=state["query"])
 
58
  tool_calls = getattr(ai_msg, "tool_calls", []) or []
59
  response_text = ai_msg.content or ""
60
 
61
+ print("\n\n\n")
62
+ print(response_text)
63
+ print("\n\n\n")
64
+
65
  print(f"[LLM] Attempt {attempt} succeeded. Tool calls: {len(tool_calls)}")
66
  print(f"[LLM] Generated Output for User Query ({state['query'][:60]}) : {response_text[:200]}")
67
  log = state.get("node_log", []) + [f"llm (attempt={attempt}, route={state['route']})"]
app/nodes/safety.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ app/nodes/safety.py β€” SAFETY CLASSIFIER (runs first in graph)
3
+
4
+ Single source of truth for all safety decisions.
5
+ Uses the LLM to dynamically judge intent β€” no hardcoded phrase lists.
6
+
7
+ Returns:
8
+ is_harmful: bool β€” True if query should be blocked
9
+ reason: str β€” why it was flagged
10
+ safe_response: str β€” what to tell the user if blocked
11
+
12
+ By doing this ONCE at the start, every other node stays clean:
13
+ - router β†’ just routes, no safety logic
14
+ - llm_node β†’ just generates, no safety logic
15
+ - evaluation β†’ just scores quality, no safety logic
16
+ - guardrails β†’ just checks response quality, no safety logic
17
+ """
18
+ import json
19
+ from langchain_core.messages import HumanMessage
20
+ from app.state import AgentState
21
+ from app.utils.llm import llm
22
+
23
+ SAFE_FALLBACK = "I'm sorry, I can't help with that request."
24
+
25
+ SAFETY_PROMPT = """You are a safety classifier for an AI assistant.
26
+
27
+ Analyze the user query and determine if it requests harmful, illegal, or dangerous content.
28
+
29
+ Categories to BLOCK:
30
+ - Violence: instructions to physically harm, attack, beat, assault, or kill people or animals
31
+ - Hacking: unauthorized access, stealing credentials, bypassing security systems
32
+ - Weapons: making bombs, explosives, or illegal weapons
33
+ - Illegal drugs: synthesizing or manufacturing controlled substances
34
+ - Harassment: stalking, doxxing, spying on people without consent
35
+ - Self-harm: methods to harm oneself or commit suicide
36
+ - Hate speech: content targeting people based on race, religion, gender, etc.
37
+
38
+ Do NOT block:
39
+ - Competitive contexts ("how to beat my friend at chess")
40
+ - Security education ("explain how SQL injection works conceptually")
41
+ - Medical questions ("what drugs interact with aspirin")
42
+ - Fiction/creative writing with dark themes
43
+ - News/historical discussion of violence
44
+ - Legitimate anger expressions ("I'm so mad I could scream")
45
+
46
+ Think carefully about INTENT and CONTEXT before deciding.
47
+
48
+ Respond ONLY with valid JSON:
49
+ {
50
+ "is_harmful": true or false,
51
+ "category": "violence|hacking|weapons|drugs|harassment|self_harm|hate_speech|safe",
52
+ "reason": "one sentence explanation",
53
+ "confidence": 0.0 to 1.0
54
+ }
55
+
56
+ User query: """
57
+
58
+
59
def safety_node(state: AgentState) -> AgentState:
    """Classify the incoming query for harmful intent before any other node runs.

    Calls the LLM with SAFETY_PROMPT and parses a JSON verdict. On a confident
    harmful verdict the query is blocked: state gets ``is_harmful=True``,
    ``guardrail_passed=False`` and the canned SAFE_FALLBACK response, and the
    graph short-circuits to output via safety_route. Any classifier or parse
    failure fails open (query allowed) so the graph never stalls —
    guardrails_node still backstops the final response.
    """
    query = state.get("query", "")
    log = state.get("node_log", [])
    # Only block when the classifier is reasonably sure — avoids false
    # positives on edgy-but-legitimate queries (competition, fiction, news).
    min_confidence = 0.7

    try:
        reply = llm.invoke([HumanMessage(content=SAFETY_PROMPT + query)])
        raw = reply.content.strip()
        # Strip Markdown code fences robustly: models emit "```json", a bare
        # "```", or no fence at all. The previous removeprefix("```json")
        # missed the bare-fence case, so json.loads raised and harmful
        # queries were silently allowed through.
        if raw.startswith("```"):
            raw = raw.removeprefix("```json").removeprefix("```")
            raw = raw.removesuffix("```").strip()
        result = json.loads(raw)
        if not isinstance(result, dict):
            # e.g. the model returned a bare string/list — treat as a parse error
            raise ValueError(f"classifier returned non-object JSON: {raw[:80]}")

        is_harmful = bool(result.get("is_harmful", False))
        category = result.get("category", "safe")
        reason = result.get("reason", "")
        confidence = float(result.get("confidence", 0.0))

        if is_harmful and confidence >= min_confidence:
            print(f"[SAFETY] 🚫 Blocked — {category} ({confidence:.0%}): {reason}")
            return {
                **state,
                "is_harmful": True,
                "guardrail_passed": False,
                "response": SAFE_FALLBACK,
                "node_log": log + [f"safety (BLOCKED: {category} {confidence:.0%})"],
            }

        print(f"[SAFETY] ✅ Safe — {reason}")
        return {
            **state,
            "is_harmful": False,
            "node_log": log + ["safety ✅"],
        }

    except Exception as e:
        # Fail open: on classifier/parse errors the query proceeds — the
        # response-level guardrails node remains the final backstop.
        print(f"[SAFETY] Parse error ({e}) — allowing through")
        return {**state, "is_harmful": False, "node_log": log + ["safety (parse error — allowed)"]}
94
+
95
+
96
def safety_route(state: "AgentState") -> str:
    """Conditional-edge selector that runs right after safety_node.

    Returns "blocked" when the query was flagged harmful (the graph jumps
    straight to the output node), otherwise "continue" (on to the router).
    """
    if state.get("is_harmful"):
        return "blocked"
    return "continue"
app/state.py CHANGED
@@ -21,4 +21,5 @@ class AgentState(TypedDict):
21
  evaluation_score: float # LLM self-evaluation score 0.0 – 1.0
22
  guardrail_passed: bool # Did the safety filter pass?
23
  memory_summary: str # Compressed summary of older conversation turns
24
- node_log: List[str] # Ordered list of nodes visited β€” shown in Gradio UI
 
 
21
  evaluation_score: float # LLM self-evaluation score 0.0 – 1.0
22
  guardrail_passed: bool # Did the safety filter pass?
23
  memory_summary: str # Compressed summary of older conversation turns
24
+ node_log: List[str] # Ordered list of nodes visited β€” shown in Gradio UI
25
+ is_harmful: bool # Did safety_node flag this query as harmful?