# vapt-agent / vapt_mcp_client.py
# Author: humanizetech
# Commit e671617 — feat: Add VAPT agent client for MCP security tests
import asyncio
import ast
import os
from pathlib import Path
from typing import Dict, Any
from dotenv import load_dotenv
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
# Load variables from a local .env file into os.environ so the settings read
# below (and in main()) can be supplied without exporting them in the shell.
load_dotenv()

# Remote MCP endpoint exposing the VAPT agent tool. Can be overridden with the
# VAPT_MCP_SERVER_URL environment variable (or .env entry); an unset or empty
# value falls back to the public Hugging Face Space.
MCP_SERVER_URL = (
    os.getenv("VAPT_MCP_SERVER_URL")
    or "https://mcp-1st-birthday-vapt-agent.hf.space/gradio_api/mcp/"
)
# Name of the tool that is invoked on that server.
TOOL_NAME = "vapt_agent_run_security_test"
# Directory containing this script; downloaded reports are saved here.
BASE_DIR = Path(__file__).parent
# Fallback file name used when the server supplies no usable report name.
_DEFAULT_REPORT_NAME = "vapt_report.md"


def _parse_tool_payload(text: str):
    """Decode a tool text block shaped like ``[progress, report_md, file_info]``.

    Returns a ``(progress_str, md_str, file_info_dict)`` tuple when ``text`` is
    a Python-literal list whose first three items have those types, otherwise
    ``None``.
    """
    try:
        # literal_eval only accepts literals (no names/calls), so it is safe on
        # server-supplied text, unlike eval().
        parsed = ast.literal_eval(text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return None
    if (
        isinstance(parsed, list)
        and len(parsed) >= 3
        and isinstance(parsed[0], str)
        and isinstance(parsed[1], str)
        and isinstance(parsed[2], dict)
    ):
        return parsed[0], parsed[1], parsed[2]
    return None


def _safe_report_name(raw: Any) -> str:
    """Reduce an untrusted, server-supplied file name to a safe local ``.md`` name.

    Only the final path component is kept, so absolute paths and ``..``
    segments cannot redirect the write outside the target directory. A
    ``.md`` suffix is enforced so the report cannot overwrite e.g. this
    script or a ``.env`` file. Non-string or empty names fall back to
    ``vapt_report.md``.
    """
    if not isinstance(raw, str):
        return _DEFAULT_REPORT_NAME
    # Treat backslashes as separators too, so "..\\x" cannot escape on Windows.
    name = Path(raw.replace("\\", "/")).name
    if name in ("", ".", ".."):
        return _DEFAULT_REPORT_NAME
    if Path(name).suffix.lower() != ".md":
        name += ".md"
    return name


async def call_vapt_security_test(
    api_endpoint: str,
    http_method: str = "GET",
    api_key: str = "",
) -> Dict[str, Any]:
    """Run the remote VAPT security-test tool against ``api_endpoint``.

    Connects to ``MCP_SERVER_URL`` over streamable HTTP and verifies that
    ``TOOL_NAME`` is advertised. It then calls the tool with the given
    endpoint, HTTP method and API key, and prints the progress log plus the
    first 20 lines of the Markdown report. The full report is saved under
    ``BASE_DIR`` with a sanitized file name.

    Args:
        api_endpoint: URL of the API to be tested.
        http_method: HTTP method the tool should use against the endpoint.
        api_key: API key forwarded to the tool (may be empty).

    Returns:
        A dict containing:

        - ``is_error``: the tool result's ``isError`` flag.
        - ``content``: the raw result content blocks.
        - ``report_path``: the absolute path of the saved report as a
          string, or ``None`` when no report was received.

    Raises:
        RuntimeError: If the server does not expose ``TOOL_NAME``.
    """
    report_md_text = None
    report_file_name = None

    def handle_text_chunk(text: str) -> None:
        # Print one text block and remember the report it carries, if any.
        nonlocal report_md_text, report_file_name
        payload = _parse_tool_payload(text)
        if payload is None:
            # Not the structured [progress, report, file_info] shape; show as-is.
            print(text)
            return
        progress_str, md_str, file_info = payload
        print("\n=== Progress ===")
        for line in progress_str.splitlines():
            print(line)
        print("\n=== Report (first 20 lines) ===")
        md_lines = md_str.splitlines()
        for line in md_lines[:20]:
            print(line)
        if len(md_lines) > 20:
            print("... [truncated, full report saved to .md file]")
        report_md_text = md_str
        # orig_name is chosen by the remote server: never trust it as a path.
        report_file_name = _safe_report_name(file_info.get("orig_name"))

    async with streamablehttp_client(url=MCP_SERVER_URL, headers={}) as (
        read_stream,
        write_stream,
        _,
    ):
        async with ClientSession(read_stream, write_stream) as session:
            await session.initialize()
            tools_resp = await session.list_tools()
            tool_names = [t.name for t in tools_resp.tools]
            if TOOL_NAME not in tool_names:
                raise RuntimeError(
                    f"Tool {TOOL_NAME!r} not found on server. Available: {tool_names}"
                )
            args = {
                "api_endpoint": api_endpoint,
                "http_method": http_method,
                "api_key": api_key,
            }
            result = await session.call_tool(TOOL_NAME, args)

            # Future compatibility: streaming async iterator
            if hasattr(result, "__aiter__"):
                async for event in result:
                    for block in event.content:
                        if getattr(block, "type", None) == "text":
                            handle_text_chunk(block.text)
            else:
                print("\n=== RAW TOOL RESULT METADATA ===")
                print(f"isError: {result.isError}")
                for block in result.content:
                    if getattr(block, "type", None) == "text":
                        handle_text_chunk(block.text)

    # Save markdown
    report_path = None
    if report_md_text:
        out_path = (BASE_DIR / (report_file_name or _DEFAULT_REPORT_NAME)).resolve()
        out_path.write_text(report_md_text, encoding="utf-8")
        print(f"\n✅ Markdown report saved to: {out_path}")
        report_path = str(out_path)

    return {
        "is_error": getattr(result, "isError", None),
        "content": getattr(result, "content", None),
        "report_path": report_path,
    }
async def main() -> None:
    """Run one security test using settings taken from the environment.

    The following variables are read, including any loaded from ``.env``:

    - ``TEST_API_ENDPOINT``: required.
    - ``TEST_API_METHOD``: optional, defaults to ``"GET"`` and is upper-cased.
    - ``TEST_API_KEY``: optional, defaults to an empty string.

    The API key is forwarded to the tool but never printed.

    Raises:
        SystemExit: If ``TEST_API_ENDPOINT`` is unset or blank.
    """
    api_endpoint = (os.getenv("TEST_API_ENDPOINT") or "").strip()
    http_method = os.getenv("TEST_API_METHOD", "GET").strip().upper() or "GET"
    api_key = os.getenv("TEST_API_KEY", "")

    if not api_endpoint:
        # Without this check None would be sent to the remote tool as the target.
        raise SystemExit(
            "TEST_API_ENDPOINT is not set; add it to the environment or .env file."
        )

    print("\nUsing environment settings:")
    print(f" TEST_API_ENDPOINT = {api_endpoint}")
    print(f" TEST_API_METHOD = {http_method}")
    # Do NOT print API key
    await call_vapt_security_test(
        api_endpoint=api_endpoint,
        http_method=http_method,
        api_key=api_key,
    )
if __name__ == "__main__":
    # Script entry point: drive the async main() to completion on a fresh
    # event loop created and closed by asyncio.run().
    asyncio.run(main())