Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """FastAPI application for TemporalBenchEnv.""" | |
| import os | |
| from pathlib import Path | |
| try: | |
| from env.config import EnvConfig | |
| from env.temporal_bench_env import TemporalBenchEnvironment | |
| from models import TemporalBenchAction, TemporalBenchObservation | |
| except ImportError: | |
| from ..env.config import EnvConfig | |
| from ..env.temporal_bench_env import TemporalBenchEnvironment | |
| from ..models import TemporalBenchAction, TemporalBenchObservation | |
| try: | |
| from openenv.core.env_server import create_app | |
| except ImportError: | |
| create_app = None # type: ignore | |
def _env_factory():
    """Build a new ``TemporalBenchEnvironment``; one is created per WebSocket session.

    The question-bank location comes from the ``TEMPORALBENCH_QUESTION_BANK_DIR``
    environment variable. If that variable is unset or empty, the
    ``tests/fixtures/banks`` directory under the parent of this file's directory
    is used, provided it exists. If neither yields a path, ``EnvConfig`` is
    constructed without a ``question_bank_path`` argument.

    Returns:
        A freshly constructed ``TemporalBenchEnvironment``.
    """
    configured = os.environ.get("TEMPORALBENCH_QUESTION_BANK_DIR")
    if not configured:
        fixtures_dir = Path(__file__).resolve().parents[1].joinpath("tests", "fixtures", "banks")
        if fixtures_dir.is_dir():
            configured = str(fixtures_dir)

    # A falsy value (None or "") means: don't pass a path at all.
    if configured:
        config = EnvConfig(question_bank_path=configured)
    else:
        config = EnvConfig()
    return TemporalBenchEnvironment(config=config)
if create_app is not None:
    # Full OpenEnv server: ``_env_factory`` is handed to ``create_app`` so it can
    # build environments on demand. ``max_concurrent_envs=64`` presumably caps
    # simultaneous environments — confirm against the openenv documentation.
    app = create_app(
        _env_factory,
        TemporalBenchAction,
        TemporalBenchObservation,
        env_name="temporal-bench-env",
        max_concurrent_envs=64,
    )
else:
    # Fallback when ``openenv.core.env_server`` could not be imported: only a
    # liveness endpoint is served. Log a warning so this degraded mode is not
    # silent — otherwise a passing /health check would mask the fact that the
    # environment endpoints are missing.
    import logging

    from fastapi import FastAPI

    logging.getLogger(__name__).warning(
        "openenv.core.env_server is not importable; serving only /health "
        "(environment endpoints are unavailable)."
    )
    app = FastAPI(title="temporal-bench-env")

    @app.get("/health")
    def _health() -> dict[str, str]:
        """Report that the fallback server process is up."""
        return {"status": "ok"}
def main(host: str | None = None, port: int | None = None) -> None:
    """
    Serve the module-level ``app`` with uvicorn.

    Entry point for `uv run server` and OpenEnv multi-mode validation.
    OpenEnv's validator does a naive substring check for ``main()`` in this
    file, so the ``if __name__ == "__main__"`` block must call ``main()`` with
    no arguments; CLI flags are parsed here via ``parse_known_args``.

    Args:
        host: Address to bind. If ``None``, taken from ``--host``
            (default ``"0.0.0.0"``).
        port: Port to bind. If ``None``, taken from ``--port``
            (default ``8000``).

    Command-line arguments other than ``--host``/``--port`` are ignored.
    """
    import argparse

    import uvicorn

    if host is None or port is None:
        cli = argparse.ArgumentParser()
        cli.add_argument("--host", type=str, default="0.0.0.0")
        cli.add_argument("--port", type=int, default=8000)
        known, _unrecognised = cli.parse_known_args()
        # Explicit arguments win; only the missing ones fall back to the CLI.
        host = known.host if host is None else host
        port = known.port if port is None else port
    uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
    # Keep this a bare, argument-free ``main()`` call: OpenEnv's validator
    # checks this file for the literal substring ``main()``, and --host/--port
    # are parsed inside main() itself.
    main()