# File: sql_env/server/app.py
# Uploaded with huggingface_hub (commit a001a97, verified); size 3.43 kB.
"""
FastAPI application for the SQLEnv environment.

Exposes the SQLEnvironment over HTTP and WebSocket endpoints,
compatible with the OpenEnv EnvClient.

Usage:
    # Development (with auto-reload):
    uv run uvicorn server.app:app --reload --host 0.0.0.0 --port 8000

    # Via uv:
    uv run server
"""
import os
from pathlib import Path
# Load environment variables from .env file
try:
from dotenv import load_dotenv
env_file = Path(__file__).parent.parent / ".env"
if env_file.exists():
load_dotenv(env_file)
except ImportError:
pass # python-dotenv not installed, use system env vars
from openenv.core.env_server import create_app
try:
from sql_env.models import SQLAction, SQLObservation
from sql_env.server.sql_environment import SQLEnvironment
except ImportError:
# Fallback for Docker where PYTHONPATH=/app/env
from models import SQLAction, SQLObservation # type: ignore[no-redef]
from server.sql_environment import SQLEnvironment # type: ignore[no-redef]
def get_tokenizer():
    """Return the tokenizer used by the SQL environment.

    The model name is read from the ``TOKENIZER_NAME`` environment variable
    and defaults to ``mistralai/Mistral-7B-Instruct-v0.1``.

    Returns:
        The object returned by ``AutoTokenizer.from_pretrained`` when
        ``transformers`` can be imported; otherwise a ``MockTokenizer``
        instance, which is meant for testing only.

    Raises:
        ImportError: If ``transformers`` is unavailable and the mock
            tokenizer module cannot be imported under either the package
            layout or the Docker layout.

    Any non-``ImportError`` exception raised while loading the real
    tokenizer is not handled here and propagates to the caller.
    """
    tokenizer_name = os.environ.get(
        "TOKENIZER_NAME", "mistralai/Mistral-7B-Instruct-v0.1"
    )
    try:
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        print(f"Loaded tokenizer: {tokenizer_name}")
        return tokenizer
    except ImportError:
        print(
            "Warning: transformers not installed, using mock tokenizer for testing only"
        )
        # Follow the same two-layout strategy as this module's own imports:
        # package-qualified first, then the Docker layout where the
        # environment root itself is on PYTHONPATH.
        try:
            from sql_env.server.mock_tokenizer import MockTokenizer
        except ImportError:
            from server.mock_tokenizer import MockTokenizer  # type: ignore[no-redef]

        return MockTokenizer()
def create_sql_environment():
    """Build a SQLEnvironment wired to the tokenizer and data locations.

    The questions file and database directory can be overridden with the
    ``QUESTIONS_PATH`` and ``DB_DIR`` environment variables. When unset they
    default to ``data/questions/student_assessment.json`` and
    ``data/databases`` under the directory two levels above this file.

    Returns:
        A new ``SQLEnvironment`` instance.
    """
    tokenizer = get_tokenizer()

    # Directory two levels above this file; both default paths hang off it.
    base_dir = Path(__file__).parent.parent
    default_questions = (
        base_dir / "data" / "questions" / "student_assessment.json"
    )
    default_db_dir = base_dir / "data" / "databases"

    questions_path = os.environ.get("QUESTIONS_PATH", str(default_questions))
    db_dir = os.environ.get("DB_DIR", str(default_db_dir))

    return SQLEnvironment(
        questions_path=questions_path,
        db_dir=db_dir,
        tokenizer=tokenizer,
    )
# Module-level FastAPI application.
#
# NOTE: the hosted Space is single-session. External users who point TRL's
# GRPOTrainer at https://hjerpe-sql-env.hf.space with num_generations > 1
# will run into openenv-core's default 1-session cap. Lifting it requires:
#   (a) auditing SQLEnvironment for mutable state shared across sessions,
#   (b) declaring SUPPORTS_CONCURRENT_SESSIONS=True on the class, and
#   (c) passing max_concurrent_envs=64 to the create_app call below.
# This is deferred as a post-launch follow-up. Our own training uses an
# in-process SQLEnvironment via SQLEnvTRL, so internal runs are unaffected.
app = create_app(create_sql_environment, SQLAction, SQLObservation, env_name="sql_env")
def _port_from_cli() -> int:
    """Read ``--port`` from the command line, defaulting to 8000."""
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--port", type=int, default=8000)
    return cli.parse_args().port


def main(host: str = "0.0.0.0", port: int | None = None):
    """Run the server with uvicorn.

    Supports both of these invocations:
        uv run server
        python -m sql_env.server.app

    Args:
        host: Interface to bind to.
        port: Port to listen on. When ``None``, the value is taken from the
            ``--port`` command-line option, which defaults to 8000.
    """
    import uvicorn

    # Only consult the command line when the caller did not choose a port.
    listen_port = _port_from_cli() if port is None else port
    uvicorn.run(app, host=host, port=listen_port)


if __name__ == "__main__":
    main()