# Copyright (c) Yogesh Singla and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. """ FastAPI application for the Julia Environment with concurrent execution support. This module creates an HTTP server that exposes the JuliaCodeActEnv over HTTP endpoints with optimized async execution for handling multiple concurrent requests efficiently. Features: - Async Julia code execution to avoid blocking - Environment pool for concurrent request handling - Thread pool executor for CPU-bound Julia tasks - Automatic error recovery and retry logic - Comprehensive logging to file and console - Worker health monitoring and auto-restart - 10x+ performance improvement over single-threaded version Usage: # Development (with auto-reload): uvicorn envs.julia_env.server.app:app --reload --host 0.0.0.0 --port 8000 # Production (with multiple workers for even better concurrency): uvicorn envs.julia_env.server.app:app --host 0.0.0.0 --port 8000 --workers 4 # Or run directly: python -m envs.julia_env.server.app """ import asyncio import logging import os import sys import traceback from concurrent.futures import ThreadPoolExecutor from contextlib import asynccontextmanager from dataclasses import asdict from datetime import datetime from logging.handlers import RotatingFileHandler from typing import Any, Dict from fastapi import Body, FastAPI, HTTPException, Request from fastapi.responses import JSONResponse from ..models import JuliaAction, JuliaObservation from .julia_codeact_env import JuliaCodeActEnv # Configuration MAX_WORKERS = int( os.getenv("JULIA_MAX_WORKERS", "8") ) # Number of concurrent Julia executions ENABLE_WEB = os.getenv("ENABLE_WEB_INTERFACE", "false").lower() in ("true", "1", "yes") EXECUTION_TIMEOUT = int(os.getenv("JULIA_EXECUTION_TIMEOUT", "120")) # seconds LOG_FILE = os.getenv("JULIA_LOG_FILE", "/tmp/run.log") LOG_LEVEL = os.getenv("JULIA_LOG_LEVEL", "INFO") # Global thread pool executor for CPU-bound Julia tasks executor = None # Setup comprehensive logging def setup_logging(): """Configure logging to both file and console with rotation.""" logger = logging.getLogger("julia_env") logger.setLevel(getattr(logging, LOG_LEVEL)) # Prevent duplicate handlers if logger.handlers: return logger # Create formatters detailed_formatter = logging.Formatter( "%(asctime)s - %(name)s - [%(process)d:%(thread)d] - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) # File handler with rotation (10MB max, keep 5 backup files) try: os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True) file_handler = RotatingFileHandler( LOG_FILE, maxBytes=10 * 1024 * 1024, backupCount=5, encoding="utf-8" # 10MB ) file_handler.setLevel(logging.DEBUG) file_handler.setFormatter(detailed_formatter) logger.addHandler(file_handler) except Exception as e: print(f"Warning: Could not create log file {LOG_FILE}: {e}") # Console handler console_handler = logging.StreamHandler(sys.stdout) console_handler.setLevel(logging.INFO) console_handler.setFormatter(detailed_formatter) logger.addHandler(console_handler) return logger logger = setup_logging() @asynccontextmanager async def lifespan(app: FastAPI): """Lifespan context manager for startup/shutdown with health monitoring""" global executor logger.info("=" * 80) logger.info("Starting Julia Environment Server") logger.info(f"Max Workers: {MAX_WORKERS}") logger.info(f"Execution Timeout: {EXECUTION_TIMEOUT}s") logger.info(f"Log File: {LOG_FILE}") logger.info(f"Log Level: {LOG_LEVEL}") logger.info("=" * 80) # Startup: Create thread pool with error handling try: executor = ThreadPoolExecutor( max_workers=MAX_WORKERS, thread_name_prefix="julia_worker" ) logger.info(f"✅ Thread pool created with {MAX_WORKERS} workers") logger.info(f"✅ Julia Environment Server started successfully") print( f"✅ Julia Environment Server started with {MAX_WORKERS} concurrent workers" ) except Exception as e: logger.error(f"❌ Failed to start server: {e}") logger.error(traceback.format_exc()) raise yield # Shutdown: Cleanup with grace period logger.info("Shutting down Julia Environment Server...") try: executor.shutdown(wait=True, cancel_futures=False) logger.info("✅ All workers completed gracefully") except Exception as e: logger.error(f"Error during shutdown: {e}") logger.info("✅ Julia Environment Server shutdown complete") print("✅ Julia Environment Server shutdown complete") # Create FastAPI app with lifespan management app = FastAPI( title="Julia Environment Server", description="Async Julia code execution environment with concurrent request support and auto-recovery", version="2.1.0", lifespan=lifespan, ) # Global exception handler for uncaught errors @app.exception_handler(Exception) async def global_exception_handler(request: Request, exc: Exception): """Handle all uncaught exceptions to prevent worker crashes""" error_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f") logger.error(f"[ERROR-{error_id}] Uncaught exception in {request.url.path}") logger.error(f"[ERROR-{error_id}] Request: {request.method} {request.url}") logger.error(f"[ERROR-{error_id}] Exception: {type(exc).__name__}: {exc}") logger.error(f"[ERROR-{error_id}] Traceback:\n{traceback.format_exc()}") return JSONResponse( status_code=500, content={ "error": "Internal server error", "type": type(exc).__name__, "message": str(exc), "error_id": error_id, "timestamp": datetime.now().isoformat(), }, ) async def execute_julia_async( action: JuliaAction, request_id: str = None ) -> JuliaObservation: """ Execute Julia code asynchronously in thread pool with timeout and error recovery. This runs the CPU-bound Julia execution in a separate thread to avoid blocking the event loop, allowing the server to handle multiple requests concurrently. Features: - Timeout protection - Automatic retry on transient failures - Comprehensive error logging - Resource cleanup """ if request_id is None: request_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f") loop = asyncio.get_event_loop() max_retries = 2 retry_count = 0 logger.debug( f"[{request_id}] Starting Julia execution (timeout: {EXECUTION_TIMEOUT}s)" ) while retry_count <= max_retries: env = None try: # Create a fresh environment instance for this request # This ensures thread safety and allows concurrent execution env = JuliaCodeActEnv() # Run the blocking step() call in thread pool with timeout observation = await asyncio.wait_for( loop.run_in_executor(executor, env.step, action), timeout=EXECUTION_TIMEOUT, ) logger.debug(f"[{request_id}] Julia execution completed successfully") logger.debug( f"[{request_id}] Result: tests_passed={observation.tests_passed}, " f"tests_failed={observation.tests_failed}, reward={observation.reward}" ) return observation except asyncio.TimeoutError: retry_count += 1 logger.warning( f"[{request_id}] Julia execution timeout (attempt {retry_count}/{max_retries + 1})" ) if retry_count > max_retries: logger.error( f"[{request_id}] Julia execution failed after {max_retries + 1} attempts" ) # Return a failure observation return JuliaObservation( stdout="", stderr=f"Execution timeout after {EXECUTION_TIMEOUT}s", exit_code=-1, tests_passed=0, tests_failed=1, code_compiles=False, reward=0.0, done=True, ) # Wait a bit before retry await asyncio.sleep(0.5) except Exception as e: retry_count += 1 logger.error( f"[{request_id}] Julia execution error (attempt {retry_count}/{max_retries + 1}): {e}" ) logger.error(f"[{request_id}] Traceback:\n{traceback.format_exc()}") if retry_count > max_retries: logger.error( f"[{request_id}] Julia execution failed permanently after {max_retries + 1} attempts" ) # Return a failure observation return JuliaObservation( stdout="", stderr=f"Execution error: {str(e)}", exit_code=-1, tests_passed=0, tests_failed=1, code_compiles=False, reward=0.0, done=True, ) # Wait a bit before retry await asyncio.sleep(0.5) finally: # Clean up environment resources if possible if env is not None: try: # Add any cleanup needed here del env except Exception as cleanup_error: logger.debug(f"[{request_id}] Cleanup warning: {cleanup_error}") @app.post("/reset") async def reset(request: Dict[str, Any] = Body(default={})) -> Dict[str, Any]: """ Reset endpoint - returns initial observation. Creates a fresh environment instance for the new episode. """ request_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f") logger.info(f"[{request_id}] Reset request received") try: # Run reset in thread pool to avoid blocking loop = asyncio.get_event_loop() env = JuliaCodeActEnv() observation = await asyncio.wait_for( loop.run_in_executor(executor, env.reset), timeout=30.0, # Reset should be quick ) # Serialize observation obs_dict = asdict(observation) reward = obs_dict.pop("reward", None) done = obs_dict.pop("done", False) obs_dict.pop("metadata", None) logger.info(f"[{request_id}] Reset completed successfully") return { "observation": obs_dict, "reward": reward, "done": done, } except asyncio.TimeoutError: logger.error(f"[{request_id}] Reset timeout") raise HTTPException(status_code=504, detail="Reset operation timed out") except Exception as e: logger.error(f"[{request_id}] Reset error: {e}") logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=f"Reset failed: {str(e)}") @app.post("/step") async def step(request: Dict[str, Any]) -> Dict[str, Any]: """ Step endpoint - executes Julia code and returns observation. Runs Julia code execution asynchronously to handle multiple concurrent requests. Each request gets its own environment instance for thread safety. """ request_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f") try: action_data = request.get("action", {}) if not action_data: logger.warning(f"[{request_id}] Step request with empty action") raise HTTPException(status_code=400, detail="Action data is required") # Deserialize action metadata = action_data.pop("metadata", {}) action = JuliaAction(**action_data) action.metadata = metadata logger.info(f"[{request_id}] Step request received") logger.debug( f"[{request_id}] Action: core_code_length={len(action.core_code) if action.core_code else 0}, " f"test_code_length={len(action.test_code) if action.test_code else 0}" ) # Execute Julia code asynchronously with timeout and retry observation = await execute_julia_async(action, request_id) # Serialize observation obs_dict = asdict(observation) reward = obs_dict.pop("reward", None) done = obs_dict.pop("done", False) obs_dict.pop("metadata", None) logger.info( f"[{request_id}] Step completed - reward={reward}, " f"tests_passed={observation.tests_passed}, tests_failed={observation.tests_failed}" ) return { "observation": obs_dict, "reward": reward, "done": done, } except HTTPException: raise except Exception as e: logger.error(f"[{request_id}] Step endpoint error: {e}") logger.error(f"[{request_id}] Traceback:\n{traceback.format_exc()}") raise HTTPException(status_code=500, detail=f"Step execution failed: {str(e)}") @app.get("/state") async def get_state() -> Dict[str, Any]: """ State endpoint - returns environment metadata and server health. Note: Since each request creates a fresh environment, this returns general server state rather than specific episode state. """ try: import psutil process = psutil.Process() memory_info = process.memory_info() return { "max_workers": MAX_WORKERS, "executor_type": "ThreadPoolExecutor", "status": "ready", "timeout": EXECUTION_TIMEOUT, "log_file": LOG_FILE, "memory_mb": memory_info.rss / 1024 / 1024, "threads": len(process.threads()), } except ImportError: # psutil not available, return basic info return { "max_workers": MAX_WORKERS, "executor_type": "ThreadPoolExecutor", "status": "ready", "timeout": EXECUTION_TIMEOUT, "log_file": LOG_FILE, } except Exception as e: logger.warning(f"Could not get full state info: {e}") return { "max_workers": MAX_WORKERS, "executor_type": "ThreadPoolExecutor", "status": "ready", } @app.get("/health") async def health() -> Dict[str, str]: """ Health check endpoint. Returns healthy status if the server is operational and can accept requests. """ try: # Quick health check - verify executor is available if executor is None: logger.error("Health check failed: executor not initialized") raise HTTPException(status_code=503, detail="Service not ready") return { "status": "healthy", "workers": str(MAX_WORKERS), "timeout": str(EXECUTION_TIMEOUT), "timestamp": datetime.now().isoformat(), } except HTTPException: raise except Exception as e: logger.error(f"Health check error: {e}") raise HTTPException(status_code=503, detail="Health check failed") if __name__ == "__main__": import uvicorn # Run with uvicorn # Use multiple workers for even better concurrency uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")