# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""FastAPI application for TemporalBenchEnv."""

import os
from pathlib import Path

try:
    from env.config import EnvConfig
    from env.temporal_bench_env import TemporalBenchEnvironment
    from models import TemporalBenchAction, TemporalBenchObservation
except ImportError:
    from ..env.config import EnvConfig
    from ..env.temporal_bench_env import TemporalBenchEnvironment
    from ..models import TemporalBenchAction, TemporalBenchObservation

try:
    from openenv.core.env_server import create_app
except ImportError:
    create_app = None  # type: ignore


def _env_factory():
    """Create a fresh environment instance per WebSocket session."""
    bank_dir = os.environ.get("TEMPORALBENCH_QUESTION_BANK_DIR")
    if not bank_dir:
        default = Path(__file__).resolve().parents[1] / "tests" / "fixtures" / "banks"
        if default.is_dir():
            bank_dir = str(default)
    cfg = EnvConfig(question_bank_path=bank_dir) if bank_dir else EnvConfig()
    return TemporalBenchEnvironment(config=cfg)
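

# Illustrative sketch (hypothetical helper; nothing in the server calls it):
# the question-bank lookup from ``_env_factory`` factored into a pure function
# so its precedence is easy to check. It assumes the same rules as the factory:
# a non-empty TEMPORALBENCH_QUESTION_BANK_DIR wins, otherwise the bundled
# ``tests/fixtures/banks`` directory is used if it exists, otherwise ``None``
# (in which case the factory falls back to ``EnvConfig()`` defaults).
from collections.abc import Mapping
from pathlib import Path


def _resolve_bank_dir_sketch(environ: Mapping[str, str], default_dir: Path) -> str | None:
    """Return the bank directory ``_env_factory`` would choose for these inputs."""
    bank_dir = environ.get("TEMPORALBENCH_QUESTION_BANK_DIR")
    if bank_dir:
        # Like the factory, an explicit value is not checked for existence.
        return bank_dir
    if default_dir.is_dir():
        return str(default_dir)
    # An empty or unset variable with no default directory yields None.
    return None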


if create_app is not None:
    app = create_app(
        _env_factory,
        TemporalBenchAction,
        TemporalBenchObservation,
        env_name="temporal-bench-env",
        max_concurrent_envs=64,
    )
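else:
    # Fallback when ``openenv.core.env_server`` cannot be imported: a bare
    # FastAPI app exposing only a liveness probe, so the module still loads.
    from fastapi import FastAPI

    app = FastAPI(title="temporal-bench-env")

    @app.get("/health")
    def health() -> dict[str, str]:
        """Liveness probe for the fallback app."""
        return {"status": "ok"}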
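

# Illustrative sketch (hypothetical helper, not part of OpenEnv): poll the
# fallback ``/health`` route until it returns {"status": "ok"} or a deadline
# passes. Standard library only; the base URL and timing values are
# assumptions. Whether the ``create_app`` branch also serves ``/health`` is
# not shown in this file.
import json
import time
import urllib.error
import urllib.request


def _wait_for_health_sketch(
    base_url: str, timeout_s: float = 5.0, interval_s: float = 0.2
) -> bool:
    """Return True once GET <base_url>/health reports status "ok", else False."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(f"{base_url}/health", timeout=interval_s) as resp:
                payload = json.loads(resp.read().decode("utf-8"))
            if isinstance(payload, dict) and payload.get("status") == "ok":
                return True
        except (urllib.error.URLError, OSError, ValueError):
            # Not reachable yet (refused, timed out, non-2xx) or the body was not JSON.
            pass
        time.sleep(interval_s)
    return False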


def main(host: str | None = None, port: int | None = None) -> None:
    """
    Entry point for ``uv run server`` and OpenEnv multi-mode validation.

    OpenEnv's validator does a naive substring check for ``main()`` in this
    file, so the ``if __name__ == "__main__"`` block must call ``main()`` with
    no arguments; CLI flags are parsed here via ``parse_known_args``.
    """
    import argparse

    import uvicorn

    if host is None or port is None:
        parser = argparse.ArgumentParser()
        parser.add_argument("--host", type=str, default="0.0.0.0")
        parser.add_argument("--port", type=int, default=8000)
        ns, _ = parser.parse_known_args()
        if host is None:
            host = ns.host
        if port is None:
            port = ns.port

    uvicorn.run(app, host=host, port=port)
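

# Illustrative sketches (hypothetical helpers; ``main`` does not call them)
# for the two constraints described in ``main``'s docstring.
#
# 1. ``_parse_host_port_sketch`` repeats main's ``parse_known_args`` pattern
#    with an explicit argv list: unrecognised flags are returned instead of
#    raising, and absent flags fall back to 0.0.0.0:8000.
# 2. ``_naive_main_call_check_sketch`` assumes the validator's "naive
#    substring check" is a plain ``in`` test (OpenEnv's real code may differ):
#    a bare ``main()`` call matches, while a call with arguments does not.
import argparse


def _parse_host_port_sketch(argv: list[str]) -> tuple[str, int, list[str]]:
    """Return (host, port, unrecognised_args) parsed from ``argv``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    ns, extra = parser.parse_known_args(argv)
    return ns.host, ns.port, extra


def _naive_main_call_check_sketch(source: str) -> bool:
    """Return True if ``source`` contains the literal text ``main()``."""
    return "main()" in source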


if __name__ == "__main__":
    main()