Spaces:
Sleeping
Sleeping
| # Copyright (c) Yogesh Singla and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| FastAPI application for the Julia Environment with concurrent execution support. | |
| This module creates an HTTP server that exposes the JuliaCodeActEnv | |
| over HTTP endpoints with optimized async execution for handling multiple | |
| concurrent requests efficiently. | |
| Features: | |
| - Async Julia code execution to avoid blocking | |
| - Environment pool for concurrent request handling | |
| - Thread pool executor for CPU-bound Julia tasks | |
| - Automatic error recovery and retry logic | |
| - Comprehensive logging to file and console | |
| - Worker health monitoring and auto-restart | |
| - 10x+ performance improvement over single-threaded version | |
| Usage: | |
| # Development (with auto-reload): | |
| uvicorn envs.julia_env.server.app:app --reload --host 0.0.0.0 --port 8000 | |
| # Production (with multiple workers for even better concurrency): | |
| uvicorn envs.julia_env.server.app:app --host 0.0.0.0 --port 8000 --workers 4 | |
| # Or run directly: | |
| python -m envs.julia_env.server.app | |
| """ | |
| import asyncio | |
| import logging | |
| import os | |
| import sys | |
| import traceback | |
| from concurrent.futures import ThreadPoolExecutor | |
| from contextlib import asynccontextmanager | |
| from dataclasses import asdict | |
| from datetime import datetime | |
| from logging.handlers import RotatingFileHandler | |
| from typing import Any, Dict | |
| from fastapi import Body, FastAPI, HTTPException, Request | |
| from fastapi.responses import JSONResponse | |
| from ..models import JuliaAction, JuliaObservation | |
| from .julia_codeact_env import JuliaCodeActEnv | |
| # Configuration | |
| MAX_WORKERS = int( | |
| os.getenv("JULIA_MAX_WORKERS", "8") | |
| ) # Number of concurrent Julia executions | |
| ENABLE_WEB = os.getenv("ENABLE_WEB_INTERFACE", "false").lower() in ("true", "1", "yes") | |
| EXECUTION_TIMEOUT = int(os.getenv("JULIA_EXECUTION_TIMEOUT", "120")) # seconds | |
| LOG_FILE = os.getenv("JULIA_LOG_FILE", "/tmp/run.log") | |
| LOG_LEVEL = os.getenv("JULIA_LOG_LEVEL", "INFO") | |
| # Global thread pool executor for CPU-bound Julia tasks | |
| executor = None | |
| # Setup comprehensive logging | |
| def setup_logging(): | |
| """Configure logging to both file and console with rotation.""" | |
| logger = logging.getLogger("julia_env") | |
| logger.setLevel(getattr(logging, LOG_LEVEL)) | |
| # Prevent duplicate handlers | |
| if logger.handlers: | |
| return logger | |
| # Create formatters | |
| detailed_formatter = logging.Formatter( | |
| "%(asctime)s - %(name)s - [%(process)d:%(thread)d] - %(levelname)s - %(message)s", | |
| datefmt="%Y-%m-%d %H:%M:%S", | |
| ) | |
| # File handler with rotation (10MB max, keep 5 backup files) | |
| try: | |
| os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True) | |
| file_handler = RotatingFileHandler( | |
| LOG_FILE, maxBytes=10 * 1024 * 1024, backupCount=5, encoding="utf-8" # 10MB | |
| ) | |
| file_handler.setLevel(logging.DEBUG) | |
| file_handler.setFormatter(detailed_formatter) | |
| logger.addHandler(file_handler) | |
| except Exception as e: | |
| print(f"Warning: Could not create log file {LOG_FILE}: {e}") | |
| # Console handler | |
| console_handler = logging.StreamHandler(sys.stdout) | |
| console_handler.setLevel(logging.INFO) | |
| console_handler.setFormatter(detailed_formatter) | |
| logger.addHandler(console_handler) | |
| return logger | |
| logger = setup_logging() | |
| async def lifespan(app: FastAPI): | |
| """Lifespan context manager for startup/shutdown with health monitoring""" | |
| global executor | |
| logger.info("=" * 80) | |
| logger.info("Starting Julia Environment Server") | |
| logger.info(f"Max Workers: {MAX_WORKERS}") | |
| logger.info(f"Execution Timeout: {EXECUTION_TIMEOUT}s") | |
| logger.info(f"Log File: {LOG_FILE}") | |
| logger.info(f"Log Level: {LOG_LEVEL}") | |
| logger.info("=" * 80) | |
| # Startup: Create thread pool with error handling | |
| try: | |
| executor = ThreadPoolExecutor( | |
| max_workers=MAX_WORKERS, thread_name_prefix="julia_worker" | |
| ) | |
| logger.info(f"✅ Thread pool created with {MAX_WORKERS} workers") | |
| logger.info(f"✅ Julia Environment Server started successfully") | |
| print( | |
| f"✅ Julia Environment Server started with {MAX_WORKERS} concurrent workers" | |
| ) | |
| except Exception as e: | |
| logger.error(f"❌ Failed to start server: {e}") | |
| logger.error(traceback.format_exc()) | |
| raise | |
| yield | |
| # Shutdown: Cleanup with grace period | |
| logger.info("Shutting down Julia Environment Server...") | |
| try: | |
| executor.shutdown(wait=True, cancel_futures=False) | |
| logger.info("✅ All workers completed gracefully") | |
| except Exception as e: | |
| logger.error(f"Error during shutdown: {e}") | |
| logger.info("✅ Julia Environment Server shutdown complete") | |
| print("✅ Julia Environment Server shutdown complete") | |
| # Create FastAPI app with lifespan management | |
| app = FastAPI( | |
| title="Julia Environment Server", | |
| description="Async Julia code execution environment with concurrent request support and auto-recovery", | |
| version="2.1.0", | |
| lifespan=lifespan, | |
| ) | |
| # Global exception handler for uncaught errors | |
| async def global_exception_handler(request: Request, exc: Exception): | |
| """Handle all uncaught exceptions to prevent worker crashes""" | |
| error_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f") | |
| logger.error(f"[ERROR-{error_id}] Uncaught exception in {request.url.path}") | |
| logger.error(f"[ERROR-{error_id}] Request: {request.method} {request.url}") | |
| logger.error(f"[ERROR-{error_id}] Exception: {type(exc).__name__}: {exc}") | |
| logger.error(f"[ERROR-{error_id}] Traceback:\n{traceback.format_exc()}") | |
| return JSONResponse( | |
| status_code=500, | |
| content={ | |
| "error": "Internal server error", | |
| "type": type(exc).__name__, | |
| "message": str(exc), | |
| "error_id": error_id, | |
| "timestamp": datetime.now().isoformat(), | |
| }, | |
| ) | |
| async def execute_julia_async( | |
| action: JuliaAction, request_id: str = None | |
| ) -> JuliaObservation: | |
| """ | |
| Execute Julia code asynchronously in thread pool with timeout and error recovery. | |
| This runs the CPU-bound Julia execution in a separate thread to avoid | |
| blocking the event loop, allowing the server to handle multiple requests | |
| concurrently. | |
| Features: | |
| - Timeout protection | |
| - Automatic retry on transient failures | |
| - Comprehensive error logging | |
| - Resource cleanup | |
| """ | |
| if request_id is None: | |
| request_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f") | |
| loop = asyncio.get_event_loop() | |
| max_retries = 2 | |
| retry_count = 0 | |
| logger.debug( | |
| f"[{request_id}] Starting Julia execution (timeout: {EXECUTION_TIMEOUT}s)" | |
| ) | |
| while retry_count <= max_retries: | |
| env = None | |
| try: | |
| # Create a fresh environment instance for this request | |
| # This ensures thread safety and allows concurrent execution | |
| env = JuliaCodeActEnv() | |
| # Run the blocking step() call in thread pool with timeout | |
| observation = await asyncio.wait_for( | |
| loop.run_in_executor(executor, env.step, action), | |
| timeout=EXECUTION_TIMEOUT, | |
| ) | |
| logger.debug(f"[{request_id}] Julia execution completed successfully") | |
| logger.debug( | |
| f"[{request_id}] Result: tests_passed={observation.tests_passed}, " | |
| f"tests_failed={observation.tests_failed}, reward={observation.reward}" | |
| ) | |
| return observation | |
| except asyncio.TimeoutError: | |
| retry_count += 1 | |
| logger.warning( | |
| f"[{request_id}] Julia execution timeout (attempt {retry_count}/{max_retries + 1})" | |
| ) | |
| if retry_count > max_retries: | |
| logger.error( | |
| f"[{request_id}] Julia execution failed after {max_retries + 1} attempts" | |
| ) | |
| # Return a failure observation | |
| return JuliaObservation( | |
| stdout="", | |
| stderr=f"Execution timeout after {EXECUTION_TIMEOUT}s", | |
| exit_code=-1, | |
| tests_passed=0, | |
| tests_failed=1, | |
| code_compiles=False, | |
| reward=0.0, | |
| done=True, | |
| ) | |
| # Wait a bit before retry | |
| await asyncio.sleep(0.5) | |
| except Exception as e: | |
| retry_count += 1 | |
| logger.error( | |
| f"[{request_id}] Julia execution error (attempt {retry_count}/{max_retries + 1}): {e}" | |
| ) | |
| logger.error(f"[{request_id}] Traceback:\n{traceback.format_exc()}") | |
| if retry_count > max_retries: | |
| logger.error( | |
| f"[{request_id}] Julia execution failed permanently after {max_retries + 1} attempts" | |
| ) | |
| # Return a failure observation | |
| return JuliaObservation( | |
| stdout="", | |
| stderr=f"Execution error: {str(e)}", | |
| exit_code=-1, | |
| tests_passed=0, | |
| tests_failed=1, | |
| code_compiles=False, | |
| reward=0.0, | |
| done=True, | |
| ) | |
| # Wait a bit before retry | |
| await asyncio.sleep(0.5) | |
| finally: | |
| # Clean up environment resources if possible | |
| if env is not None: | |
| try: | |
| # Add any cleanup needed here | |
| del env | |
| except Exception as cleanup_error: | |
| logger.debug(f"[{request_id}] Cleanup warning: {cleanup_error}") | |
| async def reset(request: Dict[str, Any] = Body(default={})) -> Dict[str, Any]: | |
| """ | |
| Reset endpoint - returns initial observation. | |
| Creates a fresh environment instance for the new episode. | |
| """ | |
| request_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f") | |
| logger.info(f"[{request_id}] Reset request received") | |
| try: | |
| # Run reset in thread pool to avoid blocking | |
| loop = asyncio.get_event_loop() | |
| env = JuliaCodeActEnv() | |
| observation = await asyncio.wait_for( | |
| loop.run_in_executor(executor, env.reset), | |
| timeout=30.0, # Reset should be quick | |
| ) | |
| # Serialize observation | |
| obs_dict = asdict(observation) | |
| reward = obs_dict.pop("reward", None) | |
| done = obs_dict.pop("done", False) | |
| obs_dict.pop("metadata", None) | |
| logger.info(f"[{request_id}] Reset completed successfully") | |
| return { | |
| "observation": obs_dict, | |
| "reward": reward, | |
| "done": done, | |
| } | |
| except asyncio.TimeoutError: | |
| logger.error(f"[{request_id}] Reset timeout") | |
| raise HTTPException(status_code=504, detail="Reset operation timed out") | |
| except Exception as e: | |
| logger.error(f"[{request_id}] Reset error: {e}") | |
| logger.error(traceback.format_exc()) | |
| raise HTTPException(status_code=500, detail=f"Reset failed: {str(e)}") | |
| async def step(request: Dict[str, Any]) -> Dict[str, Any]: | |
| """ | |
| Step endpoint - executes Julia code and returns observation. | |
| Runs Julia code execution asynchronously to handle multiple concurrent requests. | |
| Each request gets its own environment instance for thread safety. | |
| """ | |
| request_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f") | |
| try: | |
| action_data = request.get("action", {}) | |
| if not action_data: | |
| logger.warning(f"[{request_id}] Step request with empty action") | |
| raise HTTPException(status_code=400, detail="Action data is required") | |
| # Deserialize action | |
| metadata = action_data.pop("metadata", {}) | |
| action = JuliaAction(**action_data) | |
| action.metadata = metadata | |
| logger.info(f"[{request_id}] Step request received") | |
| logger.debug( | |
| f"[{request_id}] Action: core_code_length={len(action.core_code) if action.core_code else 0}, " | |
| f"test_code_length={len(action.test_code) if action.test_code else 0}" | |
| ) | |
| # Execute Julia code asynchronously with timeout and retry | |
| observation = await execute_julia_async(action, request_id) | |
| # Serialize observation | |
| obs_dict = asdict(observation) | |
| reward = obs_dict.pop("reward", None) | |
| done = obs_dict.pop("done", False) | |
| obs_dict.pop("metadata", None) | |
| logger.info( | |
| f"[{request_id}] Step completed - reward={reward}, " | |
| f"tests_passed={observation.tests_passed}, tests_failed={observation.tests_failed}" | |
| ) | |
| return { | |
| "observation": obs_dict, | |
| "reward": reward, | |
| "done": done, | |
| } | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"[{request_id}] Step endpoint error: {e}") | |
| logger.error(f"[{request_id}] Traceback:\n{traceback.format_exc()}") | |
| raise HTTPException(status_code=500, detail=f"Step execution failed: {str(e)}") | |
| async def get_state() -> Dict[str, Any]: | |
| """ | |
| State endpoint - returns environment metadata and server health. | |
| Note: Since each request creates a fresh environment, this returns | |
| general server state rather than specific episode state. | |
| """ | |
| try: | |
| import psutil | |
| process = psutil.Process() | |
| memory_info = process.memory_info() | |
| return { | |
| "max_workers": MAX_WORKERS, | |
| "executor_type": "ThreadPoolExecutor", | |
| "status": "ready", | |
| "timeout": EXECUTION_TIMEOUT, | |
| "log_file": LOG_FILE, | |
| "memory_mb": memory_info.rss / 1024 / 1024, | |
| "threads": len(process.threads()), | |
| } | |
| except ImportError: | |
| # psutil not available, return basic info | |
| return { | |
| "max_workers": MAX_WORKERS, | |
| "executor_type": "ThreadPoolExecutor", | |
| "status": "ready", | |
| "timeout": EXECUTION_TIMEOUT, | |
| "log_file": LOG_FILE, | |
| } | |
| except Exception as e: | |
| logger.warning(f"Could not get full state info: {e}") | |
| return { | |
| "max_workers": MAX_WORKERS, | |
| "executor_type": "ThreadPoolExecutor", | |
| "status": "ready", | |
| } | |
| async def health() -> Dict[str, str]: | |
| """ | |
| Health check endpoint. | |
| Returns healthy status if the server is operational and can accept requests. | |
| """ | |
| try: | |
| # Quick health check - verify executor is available | |
| if executor is None: | |
| logger.error("Health check failed: executor not initialized") | |
| raise HTTPException(status_code=503, detail="Service not ready") | |
| return { | |
| "status": "healthy", | |
| "workers": str(MAX_WORKERS), | |
| "timeout": str(EXECUTION_TIMEOUT), | |
| "timestamp": datetime.now().isoformat(), | |
| } | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Health check error: {e}") | |
| raise HTTPException(status_code=503, detail="Health check failed") | |
| if __name__ == "__main__": | |
| import uvicorn | |
| # Run with uvicorn | |
| # Use multiple workers for even better concurrency | |
| uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info") | |