| """FastAPI application for SupplyChainEnv.""" |
|
|
| from __future__ import annotations |
| import inspect |
| import os |
| import sys |
| from pathlib import Path |
|
|
# Put this file's directory and its parent on sys.path so the absolute-import
# fallback further down (``models`` / ``server.supply_chain_environment``)
# can resolve when the relative imports fail.
SERVER_DIR = Path(__file__).resolve().parent
REPO_ROOT = SERVER_DIR.parent
# Each missing entry is inserted at the front, so when both are new
# REPO_ROOT ends up ahead of SERVER_DIR in the search order.
for p in [str(SERVER_DIR), str(REPO_ROOT)]:
    if p not in sys.path:
        sys.path.insert(0, p)
|
|
| from openenv.core.env_server.http_server import create_app |
|
|
| try: |
| from ..models import SupplyChainAction, SupplyChainObservation |
| from .supply_chain_environment import SupplyChainEnvironment |
| except ImportError: |
| from models import SupplyChainAction, SupplyChainObservation |
| from server.supply_chain_environment import SupplyChainEnvironment |
|
|
|
|
def create_supply_chain_environment() -> SupplyChainEnvironment:
    """Return a newly constructed ``SupplyChainEnvironment``.

    This zero-argument factory is what the module hands to ``create_app``;
    every call builds a separate instance with the default constructor.
    """
    return SupplyChainEnvironment()
|
|
|
|
# Module-level application object; ``main()`` passes it to ``uvicorn.run``.
# It is built from the environment factory plus the action/observation
# model classes imported above.
app = create_app(
    create_supply_chain_environment,
    SupplyChainAction,
    SupplyChainObservation,
    env_name="supply_chain_env",
    # NOTE(review): presumably the cap on simultaneously live environment
    # instances -- confirm against the openenv ``create_app`` documentation.
    max_concurrent_envs=50,
)
|
|
|
|
def main(host: str = "0.0.0.0", port: int | None = None) -> None:
    """Serve the module-level ``app`` with uvicorn.

    Args:
        host: Address to bind. The default ``"0.0.0.0"`` listens on all
            interfaces.
        port: TCP port to listen on. When ``None``, the ``API_PORT``
            environment variable is used; if that is unset or blank the
            port falls back to 7860.

    Raises:
        ValueError: If ``API_PORT`` is set to something that is not an
            integer.
    """
    # Local import: uvicorn is only loaded by this module when main() runs.
    import uvicorn

    if port is None:
        # A variable that is set but empty (e.g. ``API_PORT=`` in an env
        # file) is treated the same as an unset one instead of crashing
        # on int("").
        raw_port = os.getenv("API_PORT", "").strip()
        try:
            port = int(raw_port) if raw_port else 7860
        except ValueError as err:
            raise ValueError(
                f"API_PORT must be an integer, got {raw_port!r}"
            ) from err
    uvicorn.run(app, host=host, port=port)
|
|
|
|
# Script entry point: start the server with the default host and port
# resolution when this file is executed directly.
if __name__ == "__main__":
    main()
|
|