Spaces:
Running
Running
File size: 4,360 Bytes
e671617 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import asyncio
import ast
import os
from pathlib import Path
from typing import Dict, Any
from dotenv import load_dotenv
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
# Load variables from a local .env file (if one is found) so the settings read
# below and in main() can be supplied there instead of the shell environment.
load_dotenv()

# MCP endpoint of the hosted VAPT agent. Like the TEST_API_* settings, it can be
# overridden from the environment (VAPT_MCP_SERVER_URL), e.g. to point at a
# self-hosted copy; an unset or empty value falls back to the public Space.
MCP_SERVER_URL = (
    os.getenv("VAPT_MCP_SERVER_URL")
    or "https://mcp-1st-birthday-vapt-agent.hf.space/gradio_api/mcp/"
)
# Name of the MCP tool invoked by call_vapt_security_test().
TOOL_NAME = "vapt_agent_run_security_test"
# Directory containing this script; the Markdown report is written here.
BASE_DIR = Path(__file__).parent
def _parse_report_payload(text: str) -> "tuple[str, str, Dict[str, Any]] | None":
    """Decode a tool text chunk into ``(progress, report_md, file_info)``.

    The chunk is parsed as a Python literal. It counts as a report payload
    only if it is a list of at least three items whose first two are strings
    and whose third is a dict, i.e. ['progress...', 'report_md', file_info_dict].
    Any extra items are ignored.

    Returns:
        The first three items as a tuple, or ``None`` if ``text`` is not a
        Python literal or does not have that shape.
    """
    try:
        # literal_eval evaluates literals only (no names or calls), so it is safe
        # on remote text; these are the exceptions it documents for bad input.
        parsed = ast.literal_eval(text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return None
    if (
        isinstance(parsed, list)
        and len(parsed) >= 3
        and isinstance(parsed[0], str)
        and isinstance(parsed[1], str)
        and isinstance(parsed[2], dict)
    ):
        return parsed[0], parsed[1], parsed[2]
    return None


def _safe_report_name(raw: Any) -> str:
    """Reduce a server-supplied file name to a bare file name.

    The name comes from the remote tool's response. Any directory or drive
    prefix ('/', '\\' or ':') is therefore dropped, so joining the result onto
    BASE_DIR cannot escape it. Non-strings, empty names, '.', '..' and names
    containing NUL fall back to "vapt_report.md".
    """
    if not isinstance(raw, str):
        return "vapt_report.md"
    name = raw.replace("\\", "/").replace(":", "/").rsplit("/", 1)[-1]
    if name in ("", ".", "..") or "\x00" in name:
        return "vapt_report.md"
    return name


def _print_report_preview(progress_str: str, md_str: str) -> None:
    """Print the progress log and the first 20 lines of the Markdown report."""
    print("\n=== Progress ===")
    for line in progress_str.splitlines():
        print(line)
    print("\n=== Report (first 20 lines) ===")
    md_lines = md_str.splitlines()
    for line in md_lines[:20]:
        print(line)
    if len(md_lines) > 20:
        print("... [truncated, full report saved to .md file]")


async def call_vapt_security_test(
    api_endpoint: str,
    http_method: str = "GET",
    api_key: str = "",
) -> Dict[str, Any]:
    """Run the remote VAPT security-test MCP tool and save its Markdown report.

    Connects to MCP_SERVER_URL over streamable HTTP and checks that TOOL_NAME
    is advertised. It then calls the tool and prints every text content block.

    A text block that decodes to a report payload is handled differently: its
    progress log and a 20-line report preview are printed instead of the raw
    text. If several blocks carry a report, the last one is kept. A non-empty
    report is written to BASE_DIR under a sanitized version of the
    server-supplied file name.

    Args:
        api_endpoint: Endpoint under test; forwarded to the tool as
            ``api_endpoint``.
        http_method: Forwarded to the tool as ``http_method``.
        api_key: Forwarded to the tool as ``api_key``; never printed here.

    Returns:
        A dict with ``"is_error"`` (the result's ``isError`` attribute, or
        ``None`` if absent) and ``"content"`` (its ``content``, or ``None``).

    Raises:
        RuntimeError: If the server does not expose TOOL_NAME.
    """
    report_md_text = None
    report_file_name = None

    def handle_text_chunk(text: str) -> None:
        # Record (and preview) a report payload, otherwise echo the raw text.
        nonlocal report_md_text, report_file_name
        payload = _parse_report_payload(text)
        if payload is None:
            print(text)
            return
        progress_str, md_str, file_info = payload
        _print_report_preview(progress_str, md_str)
        report_md_text = md_str
        report_file_name = _safe_report_name(file_info.get("orig_name"))

    async with streamablehttp_client(url=MCP_SERVER_URL, headers={}) as (
        read_stream,
        write_stream,
        _,
    ):
        async with ClientSession(read_stream, write_stream) as session:
            await session.initialize()
            tools_resp = await session.list_tools()
            tool_names = [t.name for t in tools_resp.tools]
            if TOOL_NAME not in tool_names:
                raise RuntimeError(
                    f"Tool {TOOL_NAME!r} not found on server. Available: {tool_names}"
                )
            args = {
                "api_endpoint": api_endpoint,
                "http_method": http_method,
                "api_key": api_key,
            }
            result = await session.call_tool(TOOL_NAME, args)
            # Future compatibility: a streaming result exposed as an async iterator.
            if hasattr(result, "__aiter__"):
                async for event in result:
                    for block in event.content:
                        if getattr(block, "type", None) == "text":
                            handle_text_chunk(block.text)
            else:
                print("\n=== RAW TOOL RESULT METADATA ===")
                print(f"isError: {result.isError}")
                for block in result.content:
                    if getattr(block, "type", None) == "text":
                        handle_text_chunk(block.text)

    # Save the full Markdown report next to this script.
    if report_md_text:
        out_path = BASE_DIR / (report_file_name or "vapt_report.md")
        out_path.write_text(report_md_text, encoding="utf-8")
        print(f"\n✅ Markdown report saved to: {out_path.resolve()}")
    return {
        "is_error": getattr(result, "isError", None),
        "content": getattr(result, "content", None),
    }
async def main() -> None:
    """Run the security test with settings taken from environment variables.

    Reads TEST_API_ENDPOINT (required), TEST_API_METHOD (default "GET") and
    TEST_API_KEY (default ""). Echoes the non-secret settings, then awaits
    call_vapt_security_test with them.

    Raises:
        SystemExit: If TEST_API_ENDPOINT is unset or empty. This is checked
            before any connection to the MCP server is attempted.
    """
    api_endpoint = os.getenv("TEST_API_ENDPOINT")
    http_method = os.getenv("TEST_API_METHOD", "GET")
    api_key = os.getenv("TEST_API_KEY", "")
    if not api_endpoint:
        # Without this guard the remote tool would be called with a None endpoint.
        raise SystemExit(
            "TEST_API_ENDPOINT is not set; define it in the environment or .env file."
        )
    print("\nUsing environment settings:")
    print(f" TEST_API_ENDPOINT = {api_endpoint}")
    print(f" TEST_API_METHOD = {http_method}")
    # Do NOT print API key
    await call_vapt_security_test(
        api_endpoint=api_endpoint,
        http_method=http_method,
        api_key=api_key,
    )


if __name__ == "__main__":
    asyncio.run(main())
|