# NOTE(review): scraped file-viewer page header ("minhan6559's picture / Upload
# 126 files / 223ef32 verified / raw / history blame / 10.7 kB") removed — it is
# web-page chrome, not part of the source file, and is a syntax error in Python.
"""
React Agent for Cyber Knowledge Base
This script creates a ReAct agent using LangGraph that can use the CyberKnowledgeBase
search method as a tool to retrieve MITRE ATT&CK techniques.
"""
import os
import sys
import json
from typing import List, Dict, Any, Union, Optional
from pathlib import Path
# Add parent directory to path for imports
sys.path.append(str(Path(__file__).parent.parent))
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from langgraph.prebuilt import create_react_agent
from langchain.chat_models import init_chat_model
from langchain_core.language_models.chat_models import BaseChatModel
# Import local modules
from src.knowledge_base.cyber_knowledge_base import CyberKnowledgeBase
# Initialize the knowledge base
def init_knowledge_base(
    persist_dir: str = "./cyber_knowledge_base",
) -> CyberKnowledgeBase:
    """Load the persisted cyber knowledge base, exiting the process on failure.

    Args:
        persist_dir: Directory holding the previously built knowledge base.

    Returns:
        The loaded ``CyberKnowledgeBase`` instance. If nothing can be loaded,
        prints build instructions and terminates with exit status 1.
    """
    kb = CyberKnowledgeBase()
    # Guard clause: bail out early when no persisted KB exists.
    if not kb.load_knowledge_base(persist_dir):
        print("[WARNING] Could not load knowledge base, please build it first")
        print("Run: python src/scripts/build_cyber_database.py")
        sys.exit(1)
    print("[SUCCESS] Loaded existing knowledge base")
    return kb
def _format_results_as_json(results) -> List[Dict[str, Any]]:
"""Format search results as structured JSON"""
output = []
for doc in results:
technique_info = {
"attack_id": doc.metadata.get("attack_id", "Unknown"),
"name": doc.metadata.get("name", "Unknown"),
"tactics": [
t.strip()
for t in doc.metadata.get("tactics", "").split(",")
if t.strip()
],
"platforms": [
p.strip()
for p in doc.metadata.get("platforms", "").split(",")
if p.strip()
],
"description": (
doc.page_content.split("Description: ")[-1]
if "Description: " in doc.page_content
else doc.page_content
),
"relevance_score": doc.metadata.get(
"relevance_score", None
), # From reranking
}
output.append(technique_info)
return output
def create_agent(llm_client: BaseChatModel, kb: CyberKnowledgeBase):
    """Build a LangGraph ReAct agent whose single tool searches *kb*.

    Args:
        llm_client: Chat model driving the agent's reasoning loop.
        kb: Loaded cyber knowledge base queried by the search tool
            (captured by closure, so no global state is needed).

    Returns:
        A runnable LangGraph agent.
    """

    # The tool is defined inside create_agent so it closes over `kb`.
    # NOTE: this docstring is the tool description the LLM sees — keep it
    # accurate. Fixed: the documented top_k default now matches the signature
    # (5, not 10), and the Args are listed in signature order.
    @tool
    def search_techniques(
        queries: Union[str, List[str]],
        top_k: int = 5,
        rerank_query: Optional[str] = None,
    ) -> str:
        """
        Search for MITRE ATT&CK techniques using the knowledge base.
        This tool searches a vector database containing MITRE ATT&CK technique descriptions,
        including their tactics, platforms, and detailed behavioral information. Each technique
        in the database has its full description embedded for semantic similarity search.
        Args:
            queries: Single search query string OR list of query strings.
            top_k: Number of results to return per query (default: 5)
            rerank_query: Optional tag echoed in the output for transparency.
        Returns:
            JSON string with results grouped per query. Each group contains:
            - query: The original query string
            - techniques: List of technique objects (attack_id, name, tactics, platforms, description, relevance_score)
            - total_results: Number of techniques in this group
        """
        try:
            # Normalize to a list so single- and multi-query calls share one path.
            if isinstance(queries, str):
                queries = [queries]
            # Run one search per query, keeping results grouped by query.
            results_by_query: List[Dict[str, Any]] = []
            for i, q in enumerate(queries, 1):
                print(f"[INFO] Query {i}/{len(queries)}: '{q}'")
                techniques = _format_results_as_json(kb.search(q, top_k=top_k))
                results_by_query.append(
                    {
                        "query": q,
                        "techniques": techniques,
                        "total_results": len(techniques),
                    }
                )
            # Special-case: every query came back empty — tell the LLM explicitly.
            if all(group["total_results"] == 0 for group in results_by_query):
                return json.dumps(
                    {
                        "results_by_query": results_by_query,
                        "message": "No techniques found matching the provided queries.",
                    },
                    indent=2,
                )
            return json.dumps(
                {
                    "results_by_query": results_by_query,
                    "queries_used": queries,
                    "rerank_query": rerank_query,
                },
                indent=2,
            )
        except Exception as e:
            # Return errors as JSON so the agent can recover instead of crashing.
            return json.dumps(
                {
                    "error": str(e),
                    "techniques": [],
                    "message": "Error occurred during search",
                },
                indent=2,
            )

    tools = [search_techniques]
    # System prompt steering the agent toward tool use.
    system_prompt = """
You are a cybersecurity analyst assistant that helps answer questions about MITRE ATT&CK techniques.
You have access to a knowledge base of MITRE ATT&CK techniques that you can search.
Use the search_techniques tool to find relevant techniques based on the user's query.
"""
    return create_react_agent(llm_client, tools, prompt=system_prompt)
def run_test_queries(agent):
    """Exercise the agent against a fixed set of sample questions, tracing output."""
    test_queries = [
        "What techniques are used for credential dumping?",
        "How do attackers use process injection for defense evasion?",
        "What are common persistence techniques on Windows systems?",
    ]
    for i, query in enumerate(test_queries, 1):
        print(f"\n\n===== Test Query {i}: '{query}' =====\n")
        # Invoke the agent and dump the full message trace for inspection.
        result = agent.invoke({"messages": [HumanMessage(content=query)]})
        print("[TRACE] Conversation messages:")
        for message in result["messages"]:
            if isinstance(message, HumanMessage):
                print(f"- [Human] {message.content}")
            elif isinstance(message, AIMessage):
                speaker = getattr(message, "name", None) or "agent"
                print(f"- [Agent:{speaker}] {message.content}")
                fc = message.additional_kwargs.get("function_call")
                if fc is not None:
                    print(f" [ToolCall] {fc.get('name')}: {fc.get('arguments')}")
            elif isinstance(message, ToolMessage):
                source = getattr(message, "name", None) or "tool"
                print(f"- [Tool:{source}] {message.content}")
def interactive_mode(agent):
    """Run a REPL-style chat loop, carrying conversation history across turns.

    Fix over the original: each turn previously iterated the ENTIRE message
    history and re-printed every earlier agent/tool message; now only the
    messages appended during the current invocation are shown.
    """
    print("\n\n===== Interactive Mode =====")
    print("Type 'exit' or 'quit' to end the session\n")
    # Full conversation history passed to the agent each turn.
    messages = []
    while True:
        user_input = input("\nYou: ")
        if user_input.lower() in ["exit", "quit"]:
            print("Exiting interactive mode...")
            break
        messages.append(HumanMessage(content=user_input))
        # Everything up to this index has already been displayed.
        already_shown = len(messages)
        try:
            result = agent.invoke({"messages": messages.copy()})
            # Adopt the agent's full message list as the new history.
            messages = result["messages"]
            # Print only the messages produced during this turn.
            for message in messages[already_shown:]:
                if isinstance(message, AIMessage):
                    print("\n" + "=" * 50)
                    print(f"\nAgent: {message.content}")
                    if "function_call" in message.additional_kwargs:
                        print(
                            "Function call:",
                            message.additional_kwargs["function_call"]["name"],
                        )
                        print(
                            "Arguments:",
                            message.additional_kwargs["function_call"]["arguments"],
                        )
                    print("-" * 50)
                if isinstance(message, ToolMessage):
                    print("Tool output:", message.content)
        except Exception as e:
            # Keep the REPL alive on agent errors; report and re-prompt.
            print(f"Error: {str(e)}")
def main():
    """Entry point: parse CLI flags, wire up KB + LLM + agent, then dispatch.

    Fixes over the original: the vestigial ``global kb`` is removed (the
    search tool closes over the ``kb`` argument passed to ``create_agent``;
    no global is read anywhere), and argument parsing happens before the
    knowledge base is loaded so ``--help`` returns instantly.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Run the Cyber KB React Agent")
    parser.add_argument(
        "--interactive", "-i", action="store_true", help="Run in interactive mode"
    )
    parser.add_argument("--test", "-t", action="store_true", help="Run test queries")
    args = parser.parse_args()

    # The knowledge base directory sits three levels above this file.
    kb_path = os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
        "cyber_knowledge_base",
    )
    kb = init_knowledge_base(kb_path)

    stats = kb.get_stats()
    print(
        f"Knowledge base loaded with {stats.get('total_techniques', 'unknown')} techniques"
    )

    # Model name is fixed; credentials come from environment variables.
    llm_client = init_chat_model("google_genai:gemini-2.0-flash", temperature=0.2)
    agent = create_agent(llm_client, kb)

    # --test runs canned queries; anything else defaults to interactive mode.
    if args.interactive:
        interactive_mode(agent)
    elif args.test:
        run_test_queries(agent)
    else:
        interactive_mode(agent)
if __name__ == "__main__":
main()