File size: 3,145 Bytes
4729fab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8177906
4729fab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#!/usr/bin/env python3
"""
Test script for benchmark environment concurrency.

Run the server first:
    cd benchmark && uvicorn server.app:app --reload --port 8000

Then run this script:
    python test_concurrency.py --requests 10 --wait 1.0
"""

import argparse
import asyncio
import time

import httpx


# Default server root for code that imports this module; main() replaces it
# with the --url value. Kept without a trailing slash because the request
# helpers append "/reset" and "/step" themselves.
BASE_URL = "https://burtenshaw-openenv-benchmark.hf.space"


async def reset(client: httpx.AsyncClient) -> dict:
    """Reset the environment and return the server's JSON response.

    Args:
        client: Async HTTP client used to send the POST request.

    Returns:
        The decoded JSON body of the ``/reset`` response.

    Raises:
        httpx.HTTPStatusError: If the server responds with an error status.
    """
    # Strip trailing slashes so a base URL such as "http://host/" does not
    # turn into "http://host//reset", which many routers answer with 404.
    response = await client.post(f"{BASE_URL.rstrip('/')}/reset")
    response.raise_for_status()
    return response.json()


async def step(client: httpx.AsyncClient, wait_seconds: float) -> dict:
    """Execute one environment step that asks the server to wait.

    Args:
        client: Async HTTP client used to send the POST request.
        wait_seconds: Value sent as ``action.wait_seconds`` in the JSON body.

    Returns:
        The decoded JSON body of the ``/step`` response.

    Raises:
        httpx.HTTPStatusError: If the server responds with an error status.
    """
    # Strip trailing slashes so a base URL such as "http://host/" does not
    # turn into "http://host//step", which many routers answer with 404.
    response = await client.post(
        f"{BASE_URL.rstrip('/')}/step",
        json={"action": {"wait_seconds": wait_seconds}},
    )
    response.raise_for_status()
    return response.json()


async def timed_request(client: httpx.AsyncClient, wait_seconds: float, request_id: int) -> dict:
    """Issue a single step call and report how long it took.

    The returned dict carries the request id, the requested wait, the
    duration measured with ``time.perf_counter`` (seconds), and the ``pid``
    and ``session_hash`` values read from the response's ``observation``.
    """
    began = time.perf_counter()
    payload = await step(client, wait_seconds)
    duration = time.perf_counter() - began

    observation = payload["observation"]
    return dict(
        request_id=request_id,
        wait_requested=wait_seconds,
        elapsed=duration,
        pid=observation["pid"],
        session_hash=observation["session_hash"],
    )


async def test_concurrent(num_requests: int, wait_seconds: float) -> dict:
    """Fire ``num_requests`` step calls concurrently and return timing stats.

    The environment is reset once, then all step requests are launched
    together on a shared client and awaited with ``asyncio.gather``.

    Args:
        num_requests: How many step requests to launch at once. Zero or a
            negative value launches none.
        wait_seconds: Wait time forwarded to every step request.

    Returns:
        A dict with ``num_requests``, ``wait_seconds``, ``total_time``
        (seconds from launch until every request finished) and ``avg_time``
        (mean per-request duration, ``0.0`` when no request was made).
    """
    async with httpx.AsyncClient(timeout=60.0) as client:
        # Reset first
        reset_result = await reset(client)
        obs = reset_result["observation"]
        print(f"Server: {obs['host_url']} | PID: {obs['pid']} | Session: {obs['session_hash']}")
        print(f"Running {num_requests} concurrent requests, each waiting {wait_seconds}s...")

        start = time.perf_counter()

        # Launch all requests concurrently
        tasks = [timed_request(client, wait_seconds, i) for i in range(num_requests)]
        results = await asyncio.gather(*tasks)

        total_time = time.perf_counter() - start
        # With no tasks, gather() yields an empty list; report 0.0 rather
        # than raising ZeroDivisionError.
        avg_time = sum(r["elapsed"] for r in results) / len(results) if results else 0.0

        return {
            "num_requests": num_requests,
            "wait_seconds": wait_seconds,
            "total_time": total_time,
            "avg_time": avg_time,
        }


async def main():
    """Parse command-line options, run the concurrency test, and print the timings."""
    global BASE_URL

    cli = argparse.ArgumentParser(description="Test benchmark environment concurrency")
    # Each entry is (flags, keyword settings) for one add_argument call.
    option_specs = (
        (("--requests", "-n"), {"type": int, "default": 10, "help": "Number of concurrent requests"}),
        (("--wait", "-w"), {"type": float, "default": 1.0, "help": "Wait time per request (seconds)"}),
        (("--url", "-u"), {"type": str, "default": "http://localhost:8000", "help": "Server URL"}),
    )
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)
    options = cli.parse_args()

    # The request helpers read the module-level BASE_URL, so point it at the chosen server.
    BASE_URL = options.url

    stats = await test_concurrent(options.requests, options.wait)

    print(f"\nTotal time: {stats['total_time']:.3f}s")
    print(f"Avg time:   {stats['avg_time']:.3f}s")


# Run the async entry point only when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(main())