# AgentMask / graph.py
# Commit b2230765034 (febd4c2): Switch from Anthropic to HuggingFace Qwen model (free tier)
"""
LangGraph State Machine for Secure Reasoning MCP Server
Implements the Chain-of-Checks workflow with cryptographic logging.
"""
import json
import os
import re
from datetime import datetime, timezone
from typing import Literal, Optional

from huggingface_hub import InferenceClient
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langgraph.graph import StateGraph, END

from state import AgentState
from schemas import (
    ExecutionPlan, StepPlan, SafetyCheckResult, ExecutionResult,
    Justification, CryptoLogEntry, HashRequest, MerkleUpdateRequest,
    WORMWriteRequest
)
from prompts import (
    format_planner_prompt, format_safety_prompt, format_executor_prompt,
    format_justification_prompt, format_synthesis_prompt
)
from mock_tools import MockCryptoTools
# ============================================================================
# GLOBAL CONFIGURATION
# ============================================================================
# HuggingFace Inference Client - Ücretsiz tier kullanıyor
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_ID = "Qwen/Qwen2.5-72B-Instruct" # Güçlü ve ücretsiz
hf_client = InferenceClient(model=MODEL_ID, token=HF_TOKEN)
# Initialize crypto tools
crypto_tools = MockCryptoTools()
# ============================================================================
# HuggingFace LLM Wrapper - LangChain-like interface
# ============================================================================
class HuggingFaceLLM:
"""
HuggingFace Inference API wrapper that mimics LangChain LLM interface.
"""
def __init__(self, client: InferenceClient):
self.client = client
def invoke(self, messages: list) -> "LLMResponse":
"""
Call HuggingFace model with messages.
Returns object with .content attribute like LangChain.
"""
# Convert LangChain messages to HF format
hf_messages = []
for msg in messages:
if isinstance(msg, SystemMessage):
hf_messages.append({"role": "system", "content": msg.content})
elif isinstance(msg, HumanMessage):
hf_messages.append({"role": "user", "content": msg.content})
elif isinstance(msg, AIMessage):
hf_messages.append({"role": "assistant", "content": msg.content})
try:
response = self.client.chat_completion(
messages=hf_messages,
max_tokens=2048,
temperature=0.1, # Low for deterministic reasoning
stream=False
)
content = response.choices[0].message.content
return LLMResponse(content)
except Exception as e:
print(f"HF API Error: {e}")
return LLMResponse(f'{{"error": "{str(e)}"}}')
class LLMResponse:
"""Simple response object with content attribute."""
def __init__(self, content: str):
self.content = content
# Initialize our HF-based LLM
llm = HuggingFaceLLM(hf_client)
# ============================================================================
# NODE 1: PLANNER
# ============================================================================
def planner_node(state: AgentState) -> AgentState:
"""
Generate a step-by-step execution plan for the task.
Args:
state: Current agent state with the task
Returns:
Updated state with the execution plan
"""
print(f"\n{'='*60}")
print(f"🧠 PLANNER NODE - Generating execution plan")
print(f"{'='*60}")
# Format the prompt
prompts = format_planner_prompt(state["task"])
# Create messages
messages = [
SystemMessage(content=prompts["system"]),
HumanMessage(content=prompts["user"])
]
# Call LLM
response = llm.invoke(messages)
# Parse JSON response
try:
plan_data = json.loads(response.content)
# Convert to ExecutionPlan model
steps = [StepPlan(**step) for step in plan_data["steps"]]
plan = ExecutionPlan(
steps=steps,
total_steps=plan_data["total_steps"]
)
print(f"✅ Generated plan with {plan.total_steps} steps:")
for step in steps:
print(f" Step {step.step_number}: {step.action}")
# Update state
state["plan"] = plan
state["current_step_index"] = 0
state["status"] = "executing"
state["messages"].extend([
HumanMessage(content=prompts["user"]),
AIMessage(content=response.content)
])
return state
except json.JSONDecodeError as e:
print(f"❌ Failed to parse planner response: {e}")
state["error"] = f"Planner failed to generate valid JSON: {str(e)}"
state["status"] = "failed"
return state
# ============================================================================
# NODE 2: SAFETY CHECKER
# ============================================================================
def safety_node(state: AgentState) -> AgentState:
"""
Validate that the current step is safe to execute.
Args:
state: Current agent state with the plan
Returns:
Updated state with safety validation result
"""
print(f"\n{'='*60}")
print(f"🛡️ SAFETY NODE - Validating step {state['current_step_index'] + 1}")
print(f"{'='*60}")
# Get current step
current_step = state["plan"].steps[state["current_step_index"]]
# Format previous steps for context
previous_steps = "None"
if state["current_step_index"] > 0:
prev_steps_list = [
f"Step {i+1}: {state['plan'].steps[i].action}"
for i in range(state["current_step_index"])
]
previous_steps = "\n".join(prev_steps_list)
# Format the prompt
prompts = format_safety_prompt(
step_description=current_step.action,
task=state["task"],
step_number=state["current_step_index"] + 1,
total_steps=state["plan"].total_steps,
previous_steps=previous_steps,
additional_context="This is a secure reasoning system with cryptographic logging."
)
# Create messages
messages = [
SystemMessage(content=prompts["system"]),
HumanMessage(content=prompts["user"])
]
# Call LLM
response = llm.invoke(messages)
# Parse JSON response
try:
safety_data = json.loads(response.content)
safety_result = SafetyCheckResult(**safety_data)
print(f"🔍 Safety Check Result:")
print(f" Is Safe: {safety_result.is_safe}")
print(f" Risk Level: {safety_result.risk_level}")
print(f" Reasoning: {safety_result.reasoning[:100]}...")
# Update state
state["safety_status"] = safety_result
state["messages"].extend([
HumanMessage(content=prompts["user"]),
AIMessage(content=response.content)
])
# Mark if blocked
if not safety_result.is_safe:
state["safety_blocked"] = True
print(f"🚫 Step BLOCKED due to safety concerns")
else:
print(f"✅ Step approved for execution")
return state
except json.JSONDecodeError as e:
print(f"❌ Failed to parse safety response: {e}")
# Default to blocking if parsing fails (fail-safe)
state["safety_status"] = SafetyCheckResult(
is_safe=False,
risk_level="critical",
reasoning=f"Safety check failed due to parsing error: {str(e)}",
blocked_reasons=["parsing_error"]
)
state["safety_blocked"] = True
return state
# ============================================================================
# NODE 3: EXECUTOR
# ============================================================================
def executor_node(state: AgentState) -> AgentState:
"""
Execute the current step (call tools if needed).
Args:
state: Current agent state with approved step
Returns:
Updated state with execution result
"""
print(f"\n{'='*60}")
print(f"⚡ EXECUTOR NODE - Executing step {state['current_step_index'] + 1}")
print(f"{'='*60}")
# Get current step
current_step = state["plan"].steps[state["current_step_index"]]
# Format previous results for context
previous_results = "None"
if state["justifications"]:
prev_results_list = [
f"Step {j.step_number}: {j.reasoning[:100]}..."
for j in state["justifications"][-3:] # Last 3 steps
]
previous_results = "\n".join(prev_results_list)
# Format the prompt
prompts = format_executor_prompt(
step_description=current_step.action,
task=state["task"],
expected_outcome=current_step.expected_outcome,
requires_tools=current_step.requires_tools,
previous_results=previous_results
)
# Create messages
messages = [
SystemMessage(content=prompts["system"]),
HumanMessage(content=prompts["user"])
]
# Call LLM
response = llm.invoke(messages)
# Parse JSON response
try:
executor_data = json.loads(response.content)
tool_needed = executor_data.get("tool_needed", "internal_reasoning")
tool_params = executor_data.get("tool_params")
direct_result = executor_data.get("direct_result")
print(f"🔧 Tool Selection: {tool_needed}")
# Execute based on tool selection
if tool_needed == "internal_reasoning":
result = ExecutionResult(
success=True,
output=direct_result or "Analysis completed through reasoning",
tool_calls=["internal_reasoning"]
)
else:
# Simulate tool execution (in real system, dispatch to actual tools)
result = ExecutionResult(
success=True,
output=f"Simulated result from {tool_needed} with params: {tool_params}",
tool_calls=[tool_needed]
)
print(f"✅ Execution successful")
print(f" Output: {str(result.output)[:100]}...")
# Update state
state["execution_result"] = result
state["messages"].extend([
HumanMessage(content=prompts["user"]),
AIMessage(content=response.content)
])
return state
except json.JSONDecodeError as e:
print(f"❌ Execution failed: {e}")
state["execution_result"] = ExecutionResult(
success=False,
output=None,
error=f"Failed to parse executor response: {str(e)}",
tool_calls=[]
)
return state
except Exception as e:
print(f"❌ Execution error: {e}")
state["execution_result"] = ExecutionResult(
success=False,
output=None,
error=str(e),
tool_calls=[]
)
return state
# ============================================================================
# NODE 4: LOGGER (Cryptographic Logging)
# ============================================================================
def logger_node(state: AgentState) -> AgentState:
"""
Hash the execution result and log it to Merkle Tree + WORM storage.
Args:
state: Current agent state with execution result
Returns:
Updated state with cryptographic log entry
"""
print(f"\n{'='*60}")
print(f"📝 LOGGER NODE - Creating cryptographic proof")
print(f"{'='*60}")
current_step = state["plan"].steps[state["current_step_index"]]
execution_result = state["execution_result"]
try:
# 1. Prepare the data to log
log_data = {
"task_id": state["task_id"],
"step_number": state["current_step_index"] + 1,
"action": current_step.action,
"result": execution_result.output if execution_result.success else execution_result.error,
"timestamp": datetime.utcnow().isoformat(),
"safety_approved": state["safety_status"].is_safe if state["safety_status"] else False
}
# 2. Hash the action data
hash_request = HashRequest(
data=json.dumps(log_data, sort_keys=True),
algorithm="sha256"
)
hash_response = crypto_tools.hash_tool(hash_request)
action_hash = hash_response.hash
print(f"🔐 Action Hash: {action_hash[:16]}...")
# 3. Update Merkle Tree
merkle_request = MerkleUpdateRequest(
leaf_hash=action_hash,
metadata={"step": state["current_step_index"] + 1}
)
merkle_response = crypto_tools.merkle_update_tool(merkle_request)
merkle_root = merkle_response.merkle_root
print(f"🌳 Merkle Root: {merkle_root[:16]}...")
# 4. Write to WORM storage
entry_id = f"{state['task_id']}_step_{state['current_step_index'] + 1}"
worm_request = WORMWriteRequest(
entry_id=entry_id,
data=log_data,
merkle_root=merkle_root
)
worm_response = crypto_tools.worm_write_tool(worm_request)
print(f"💾 WORM Path: {worm_response.storage_path}")
# 5. Create log entry
log_entry = CryptoLogEntry(
step_number=state["current_step_index"] + 1,
action_hash=action_hash,
merkle_root=merkle_root,
worm_path=worm_response.storage_path
)
# Update state
state["logs"].append(log_entry)
print(f"✅ Cryptographic logging complete")
return state
except Exception as e:
print(f"❌ Logging failed: {e}")
state["error"] = f"Cryptographic logging failed: {str(e)}"
return state
# ============================================================================
# NODE 5: JUSTIFICATION
# ============================================================================
def justification_node(state: AgentState) -> AgentState:
"""
Generate an explanation for why the action was taken.
Args:
state: Current agent state with execution result
Returns:
Updated state with justification
"""
print(f"\n{'='*60}")
print(f"💭 JUSTIFICATION NODE - Explaining the action")
print(f"{'='*60}")
current_step = state["plan"].steps[state["current_step_index"]]
execution_result = state["execution_result"]
# Determine tool used
tool_used = ", ".join(execution_result.tool_calls) if execution_result.tool_calls else "none"
# Format the prompt
prompts = format_justification_prompt(
step_description=current_step.action,
tool_used=tool_used,
execution_result=str(execution_result.output)[:500] if execution_result.success else execution_result.error,
task=state["task"],
step_number=state["current_step_index"] + 1,
total_steps=state["plan"].total_steps,
expected_outcome=current_step.expected_outcome
)
# Create messages
messages = [
SystemMessage(content=prompts["system"]),
HumanMessage(content=prompts["user"])
]
# Call LLM
response = llm.invoke(messages)
# Parse JSON response
try:
justification_data = json.loads(response.content)
justification = Justification(**justification_data)
print(f"📋 Justification generated:")
print(f" {justification.reasoning[:150]}...")
# Update state
state["justifications"].append(justification)
state["messages"].extend([
HumanMessage(content=prompts["user"]),
AIMessage(content=response.content)
])
return state
except json.JSONDecodeError as e:
print(f"⚠️ Failed to parse justification, using fallback: {e}")
# Create fallback justification
fallback = Justification(
step_number=state["current_step_index"] + 1,
reasoning=f"Executed {current_step.action} as planned. Result: {execution_result.success}",
evidence=None,
alternatives_considered=None
)
state["justifications"].append(fallback)
return state
# ============================================================================
# NODE 6: STEP ITERATOR
# ============================================================================
def step_iterator_node(state: AgentState) -> AgentState:
"""
Move to the next step or complete the task.
Args:
state: Current agent state
Returns:
Updated state with incremented step index
"""
print(f"\n{'='*60}")
print(f"➡️ STEP ITERATOR - Moving to next step")
print(f"{'='*60}")
# Increment step index
state["current_step_index"] += 1
# Check if we're done
if state["current_step_index"] >= state["plan"].total_steps:
print(f"🎉 All steps completed!")
state["status"] = "completed"
else:
print(f"📍 Moving to step {state['current_step_index'] + 1}/{state['plan'].total_steps}")
return state
# ============================================================================
# NODE 7: REFINER (for unsafe steps)
# ============================================================================
def refiner_node(state: AgentState) -> AgentState:
"""
Handle unsafe steps by modifying or skipping them.
Args:
state: Current agent state with blocked step
Returns:
Updated state with refinement decision
"""
print(f"\n{'='*60}")
print(f"🔧 REFINER NODE - Handling unsafe step")
print(f"{'='*60}")
current_step = state["plan"].steps[state["current_step_index"]]
safety_status = state["safety_status"]
# Log the blocked action
print(f"🚫 Step blocked: {current_step.action}")
print(f" Reason: {safety_status.reasoning}")
# Create a null execution result
state["execution_result"] = ExecutionResult(
success=False,
output=None,
error=f"Step blocked by safety guardrails: {safety_status.reasoning}",
tool_calls=[]
)
# Create justification for blocking
justification = Justification(
step_number=state["current_step_index"] + 1,
reasoning=f"Step was blocked by safety guardrails. Risk level: {safety_status.risk_level}. Reason: {safety_status.reasoning}",
evidence=safety_status.blocked_reasons or [],
alternatives_considered=["Skip this step", "Abort entire task"]
)
state["justifications"].append(justification)
# Mark status
state["status"] = "blocked"
print(f"⚠️ Task blocked due to safety concerns")
return state
# ============================================================================
# CONDITIONAL EDGES
# ============================================================================
def should_execute_or_refine(state: AgentState) -> Literal["execute", "refine"]:
"""
Decide whether to execute or refine based on safety check.
Args:
state: Current agent state
Returns:
"execute" if safe, "refine" if unsafe
"""
if state["safety_status"] and state["safety_status"].is_safe:
return "execute"
else:
return "refine"
def should_continue_or_end(state: AgentState) -> Literal["continue", "end"]:
"""
Decide whether to continue to next step or end the workflow.
Args:
state: Current agent state
Returns:
"continue" if more steps remain, "end" if done or blocked
"""
# End if blocked
if state["safety_blocked"] and state["status"] == "blocked":
return "end"
# End if error occurred
if state["error"]:
return "end"
# End if all steps completed
if state["current_step_index"] >= state["plan"].total_steps:
return "end"
# Continue to next step
return "continue"
# ============================================================================
# GRAPH CONSTRUCTION
# ============================================================================
def create_reasoning_graph() -> StateGraph:
"""
Construct the full LangGraph state machine.
Returns:
Compiled StateGraph ready for execution
"""
# Create the graph
workflow = StateGraph(AgentState)
# Add nodes
workflow.add_node("planner", planner_node)
workflow.add_node("safety", safety_node)
workflow.add_node("executor", executor_node)
workflow.add_node("logger", logger_node)
workflow.add_node("justification", justification_node)
workflow.add_node("iterator", step_iterator_node)
workflow.add_node("refiner", refiner_node)
# Set entry point
workflow.set_entry_point("planner")
# Add edges
workflow.add_edge("planner", "safety")
# Conditional: safe -> execute, unsafe -> refine
workflow.add_conditional_edges(
"safety",
should_execute_or_refine,
{
"execute": "executor",
"refine": "refiner"
}
)
# After execution: log -> justify -> iterate
workflow.add_edge("executor", "logger")
workflow.add_edge("logger", "justification")
workflow.add_edge("justification", "iterator")
# After refining: go to iterator (to mark as done)
workflow.add_edge("refiner", "iterator")
# Conditional: continue to next step or end
workflow.add_conditional_edges(
"iterator",
should_continue_or_end,
{
"continue": "safety", # Loop back to safety check for next step
"end": END
}
)
# Compile the graph
return workflow.compile()
# ============================================================================
# CONVENIENCE FUNCTION
# ============================================================================
def run_reasoning_task(task: str, task_id: str, user_id: str = None) -> AgentState:
"""
Execute a reasoning task through the full pipeline.
Args:
task: The task to solve
task_id: Unique identifier for this execution
user_id: Optional user identifier
Returns:
Final agent state with results and logs
"""
from state import create_initial_state
# Create initial state
initial_state = create_initial_state(task, task_id, user_id)
# Create and run the graph
graph = create_reasoning_graph()
print(f"\n{'#'*60}")
print(f"🚀 STARTING REASONING PIPELINE")
print(f" Task: {task}")
print(f" Task ID: {task_id}")
print(f"{'#'*60}")
# Execute
final_state = graph.invoke(initial_state)
print(f"\n{'#'*60}")
print(f"🏁 REASONING PIPELINE COMPLETE")
print(f" Status: {final_state['status']}")
print(f" Steps Executed: {len(final_state['justifications'])}/{final_state['plan'].total_steps if final_state['plan'] else 0}")
print(f" Cryptographic Logs: {len(final_state['logs'])}")
print(f"{'#'*60}\n")
return final_state
# ============================================================================
# EXAMPLE USAGE
# ============================================================================
if __name__ == "__main__":
# Test the graph
result = run_reasoning_task(
task="Analyze the current state of AI safety research and provide 3 key findings",
task_id="test_001",
user_id="demo_user"
)
# Print results
print("\n=== FINAL RESULTS ===")
print(f"Status: {result['status']}")
print(f"\nJustifications:")
for j in result['justifications']:
print(f" Step {j.step_number}: {j.reasoning[:100]}...")
print(f"\nCryptographic Audit Trail:")
for log in result['logs']:
print(f" Step {log.step_number}: Hash {log.action_hash[:16]}... -> Root {log.merkle_root[:16]}...")