File size: 2,155 Bytes
8dc7642
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from __future__ import annotations

import socket
import subprocess
import sys
import time
from contextlib import contextmanager
from typing import Generator

import requests

from freeciv_env.client import FreecivEnv
from freeciv_env.models import FreecivAction


@contextmanager
def run_server(module_path: str, port: int) -> Generator[str, None, None]:
    process = subprocess.Popen(
        [
            sys.executable,
            "-m",
            "uvicorn",
            f"{module_path}:app",
            "--host",
            "127.0.0.1",
            "--port",
            str(port),
        ],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    base_url = f"http://127.0.0.1:{port}"
    try:
        deadline = time.time() + 10
        while time.time() < deadline:
            try:
                if requests.get(f"{base_url}/health", timeout=1).status_code == 200:
                    break
            except requests.exceptions.ConnectionError:
                time.sleep(0.1)
        else:
            stderr = process.stderr.read().decode() if process.stderr else ""
            raise TimeoutError(stderr)
        yield base_url
    finally:
        process.terminate()
        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            process.kill()
            process.wait()


def _free_port() -> int:
    """Return a TCP port on 127.0.0.1 that the OS reported free a moment ago.

    Binding to port 0 lets the OS choose. A small window remains between
    closing this socket and uvicorn binding the port, but that is far less
    likely to collide than a fixed port shared by concurrent test runs.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))
        return sock.getsockname()[1]


def test_websocket_roundtrip_with_fake_backend() -> None:
    """Exercise reset, step and state against ``tests.fake_server`` over a live client.

    Asserts that reset reports turn 1 with a ``move_unit`` legal action, that
    moving unit 201 in direction 0 yields 5 known tiles on turn 1, and that
    ``state()`` agrees with the step result's turn and score.
    """
    # A per-run free port avoids clashing with other processes (or another
    # server already answering /health) on a fixed port.
    with run_server("tests.fake_server", port=_free_port()) as base_url:
        client = FreecivEnv(base_url=base_url)
        client.connect()
        try:
            reset_result = client.reset()
            assert reset_result.observation.turn == 1
            assert any(action.action_type == "move_unit" for action in reset_result.observation.legal_actions)

            step_result = client.step(FreecivAction(action_type="move_unit", unit_id=201, direction=0))
            assert step_result.observation.turn == 1
            assert step_result.observation.known_tiles == 5

            state = client.state()
            assert state.turn == 1
            assert state.score == step_result.observation.score
        finally:
            client.close()