# Hugging Face upload metadata (not part of the module): uploaded by yashu2000
# via "Upload folder using huggingface_hub", commit d954568 (verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""FastAPI application for TemporalBenchEnv."""
import os
from pathlib import Path
try:
from env.config import EnvConfig
from env.temporal_bench_env import TemporalBenchEnvironment
from models import TemporalBenchAction, TemporalBenchObservation
except ImportError:
from ..env.config import EnvConfig
from ..env.temporal_bench_env import TemporalBenchEnvironment
from ..models import TemporalBenchAction, TemporalBenchObservation
try:
from openenv.core.env_server import create_app
except ImportError:
create_app = None # type: ignore
def _env_factory():
    """Build a new ``TemporalBenchEnvironment``; used once per WebSocket session.

    The question-bank directory is read from ``TEMPORALBENCH_QUESTION_BANK_DIR``.
    If that variable is unset or empty, ``tests/fixtures/banks`` under the
    parent of this module's directory is used when it exists. Otherwise the
    environment is built from a default ``EnvConfig()``.
    """
    # An empty string counts as "not configured", same as an absent variable.
    question_bank = os.environ.get("TEMPORALBENCH_QUESTION_BANK_DIR") or None
    if question_bank is None:
        fixtures = Path(__file__).resolve().parents[1] / "tests" / "fixtures" / "banks"
        question_bank = str(fixtures) if fixtures.is_dir() else None

    if question_bank is None:
        config = EnvConfig()
    else:
        config = EnvConfig(question_bank_path=question_bank)
    return TemporalBenchEnvironment(config=config)
if create_app is not None:
    # openenv-core is importable, so build the full OpenEnv server. It receives
    # _env_factory as the environment factory.
    app = create_app(
        _env_factory,
        TemporalBenchAction,
        TemporalBenchObservation,
        env_name="temporal-bench-env",
        max_concurrent_envs=64,
    )
else:
    import warnings

    from fastapi import FastAPI

    # openenv.core.env_server failed to import, so no environment endpoints can
    # be served. A minimal app with a liveness route is still created so the
    # process starts. Without this warning, /health would report "ok" and the
    # missing endpoints would go unnoticed.
    warnings.warn(
        "openenv.core.env_server could not be imported; serving a fallback "
        "FastAPI app that exposes only /health (no TemporalBenchEnv endpoints).",
        RuntimeWarning,
        stacklevel=1,
    )
    app = FastAPI(title="temporal-bench-env")
    app.get("/health")(lambda: {"status": "ok"})
def main(host: str | None = None, port: int | None = None) -> None:
    """Serve ``app`` with uvicorn (``uv run server`` / OpenEnv multi-mode validation).

    Args:
        host: Interface to bind. If ``None``, it comes from ``--host`` on the
            command line, which defaults to ``"0.0.0.0"``.
        port: TCP port to bind. If ``None``, it comes from ``--port`` on the
            command line, which defaults to ``8000``.

    OpenEnv's validator does a naive substring check for ``main()`` in this
    file, so the ``__main__`` guard calls ``main()`` with no arguments. CLI
    flags are therefore read here with ``parse_known_args``, which ignores
    unrecognised argv entries instead of rejecting them.
    """
    import argparse

    import uvicorn

    missing_value = host is None or port is None
    if missing_value:
        cli = argparse.ArgumentParser()
        cli.add_argument("--host", type=str, default="0.0.0.0")
        cli.add_argument("--port", type=int, default=8000)
        parsed, _unrecognised = cli.parse_known_args()
        # Explicit arguments win; only the missing ones are filled from argv.
        host = parsed.host if host is None else host
        port = parsed.port if port is None else port

    uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
    # Keep this a literal, argument-free ``main()`` call. OpenEnv's validator
    # looks for the ``main()`` substring in this file, and ``main`` reads
    # --host/--port from the command line itself.
    main()