Spaces:
Sleeping
Sleeping
| """ | |
| FastAPI server implementing the OpenEnv specification for Agentic RAG Gym. | |
| Endpoints: | |
| - POST /reset → Reset environment, start new episode | |
| - POST /step → Take an action in the environment | |
| - GET /state → Get current environment state | |
| - GET /tasks → List available tasks | |
| - POST /grade → Grade current episode | |
| - GET /health → Health check | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import re | |
| from contextlib import asynccontextmanager | |
| from pathlib import Path | |
| from typing import Any, AsyncGenerator, Dict, Optional | |
| from fastapi import FastAPI, HTTPException, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel, Field | |
| from rag_master.agents import CriticAgent, PlannerAgent, ReasonerAgent, RetrieverAgent, VerifierAgent | |
| from rag_master.config import AppSettings, get_settings | |
| from rag_master.llm_client import OpenAICompatibleClient | |
| from rag_master.logging_config import get_logger, setup_logging | |
| from rag_master.models import AgentRole | |
| from rag_master.orchestrator import Orchestrator | |
| from rag_master.retriever import FAISSRetriever | |
| from rag_master.rewards import CompositeRewardFunction, clamp_score | |
| from domains.aerospace.config import AerospaceDomainConfig | |
| from domains.legal_research.config import LegalResearchDomainConfig | |
| from server.models import ( | |
| Action, | |
| GradeResult, | |
| HealthStatus, | |
| Observation, | |
| ResetResult, | |
| StateResult, | |
| StepResult, | |
| TaskInfo, | |
| ) | |
# Module-level logger shared by every handler in this file.
logger = get_logger(__name__)

# --- Domain registry ---
# Maps the public domain name to the config class that is instantiated when
# an orchestrator for that domain is built. Insertion order is the order in
# which domains are listed to clients.
DOMAIN_REGISTRY = dict(
    aerospace=AerospaceDomainConfig,
    legal_research=LegalResearchDomainConfig,
)

# --- Global state ---
# Lock taken by the handlers around orchestrator creation, reset, step and
# grade calls.
_lock = asyncio.Lock()
# Application settings; stays ``None`` until the lifespan handler loads them.
_settings: Optional[AppSettings] = None
# Name of the domain whose orchestrator serves incoming requests.
_active_domain: str = "aerospace"
# One orchestrator per domain name, filled in as domains are initialized.
_orchestrators: Dict[str, Orchestrator] = dict()
| # Regex pattern for sanitization: strip HTML tags and control chars | |
| _SANITIZE_RE = re.compile(r'<[^>]+>|[\x00-\x08\x0b\x0c\x0e-\x1f]') | |
class ResetRequest(BaseModel):
    """Request body for /reset endpoint."""

    # Optional task identifier; defaults to ``None`` when the client omits it.
    task_id: Optional[str] = Field(
        default=None,
        description="Task ID to load.",
    )
class GradeRequest(BaseModel):
    """Request body for /grade endpoint."""

    # Optional task identifier; defaults to ``None`` when the client omits it.
    task_id: Optional[str] = Field(
        default=None,
        description="Task ID to grade.",
    )
def _sanitize(text: str) -> str:
    """Return *text* with tags and control characters removed.

    Every substring matched by ``_SANITIZE_RE`` is deleted, after which
    leading and trailing whitespace is trimmed from the remainder.
    """
    without_markup = _SANITIZE_RE.sub('', text)
    return without_markup.strip()
async def _initialize_orchestrator(settings: AppSettings, domain_name: str = "aerospace") -> Orchestrator:
    """Build the orchestrator for one domain and index that domain's documents.

    Args:
        settings: Application settings supplying the FAISS index directory,
            embedding model, index dimension and LLM connection details.
        domain_name: Key into ``DOMAIN_REGISTRY`` selecting the domain config.

    Returns:
        The orchestrator wired with the domain config, retriever, LLM client,
        reward function and one agent per role.

    Raises:
        ValueError: If ``domain_name`` is not a key of ``DOMAIN_REGISTRY``.
    """
    if domain_name not in DOMAIN_REGISTRY:
        raise ValueError(f"Unknown domain: {domain_name}. Available: {list(DOMAIN_REGISTRY.keys())}")
    domain_config = DOMAIN_REGISTRY[domain_name]()

    # Each domain gets its own subdirectory so they never share embeddings.
    index_store = FAISSRetriever(
        index_dir=settings.faiss_index_dir / domain_name,
        embedding_model=settings.embedding_model,
        dimension=settings.faiss_dimension,
    )

    # A single LLM client is shared by the reward function and every agent.
    llm_client = OpenAICompatibleClient(
        base_url=settings.api_base_url,
        api_key=settings.llm_api_key,
        model_name=settings.model_name,
    )
    scorer = CompositeRewardFunction(llm_client=llm_client)

    role_agents = {
        AgentRole.RETRIEVER: RetrieverAgent(retriever=index_store, llm=llm_client),
        AgentRole.REASONER: ReasonerAgent(
            llm=llm_client,
            system_prompt=domain_config.get_system_prompt(),
        ),
        AgentRole.CRITIC: CriticAgent(llm=llm_client),
        AgentRole.PLANNER: PlannerAgent(llm=llm_client),
        AgentRole.VERIFIER: VerifierAgent(llm=llm_client),
    }

    environment = Orchestrator(
        domain_config=domain_config,
        retriever=index_store,
        llm_client=llm_client,
        reward_function=scorer,
        agents=role_agents,
    )

    # Populate the retriever with the domain's documents before handing the
    # orchestrator back to the caller.
    documents = domain_config.get_documents()
    indexed_total = await index_store.index_documents(documents)
    logger.info("knowledge_base_indexed", document_count=indexed_total)
    return environment
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Application lifespan handler.

    Startup (before the ``yield``): loads the settings, configures logging
    from the configured log level, and builds the orchestrator for the default
    ``aerospace`` domain. Shutdown (after the ``yield``): logs that the server
    is stopping.

    The function is wrapped with ``asynccontextmanager`` so that the object
    handed to ``FastAPI(lifespan=...)`` is an async context manager factory
    rather than a bare async generator function.

    Args:
        app: The FastAPI application being started (not used in the body).
    """
    # ``_orchestrators`` is only mutated in place, so it needs no ``global``.
    global _active_domain, _settings

    _settings = get_settings()
    setup_logging(_settings.log_level.value)
    logger.info("server_starting", host=_settings.server_host, port=_settings.server_port)

    # Initialize default domain (aerospace)
    _active_domain = "aerospace"
    _orchestrators["aerospace"] = await _initialize_orchestrator(_settings, "aerospace")
    logger.info("orchestrator_initialized", domain="aerospace")

    yield

    logger.info("server_shutting_down")
# The ASGI application. ``lifespan`` supplies the startup/shutdown logic; the
# remaining arguments are descriptive metadata.
app = FastAPI(
    lifespan=lifespan,
    title="Agentic RAG Gym",
    version="1.0.0",
    description="A reinforcement learning environment for training AI agents on Retrieval-Augmented Generation tasks across multiple domains.",
)
# --- Prometheus metrics ---
# The instrumentator package is optional: when it cannot be imported the
# server only logs a warning and runs without a /metrics endpoint.
try:
    from prometheus_fastapi_instrumentator import Instrumentator

    _metrics = Instrumentator(
        excluded_handlers=["/metrics"],
        should_ignore_untemplated=True,
        should_group_status_codes=True,
    ).instrument(app)
    _metrics.expose(app, endpoint="/metrics", include_in_schema=False)
except ImportError:
    logger.warning(
        "prometheus_not_available",
        msg="Install prometheus_fastapi_instrumentator for /metrics",
    )
# CORS: accept any origin, method and header, with credentials disabled.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=False,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Registered as GET /health, the path given in the module docstring; without
# the decorator the handler is never attached to ``app``.
@app.get("/health")
async def health() -> HealthStatus:
    """Health check endpoint."""
    # The status object is built with no arguments, i.e. from its defaults.
    return HealthStatus()
# Without a route decorator this handler is never attached to ``app``.
# NOTE(review): the path is not documented anywhere in this file; "/domains"
# is chosen to sit alongside "/domain/switch" — confirm against API clients.
@app.get("/domains")
async def list_domains() -> Dict[str, Any]:
    """List all available domains and the currently active one."""
    return {
        # Registry keys in insertion order.
        "domains": list(DOMAIN_REGISTRY),
        "active": _active_domain,
    }
class SwitchDomainRequest(BaseModel):
    """Request body for /domain/switch endpoint."""

    # Required field: no default value is supplied.
    domain: str = Field(
        description="Domain name to switch to.",
    )
# Registered at the path named in SwitchDomainRequest's docstring; without the
# decorator the handler is never attached to ``app``.
# NOTE(review): POST is assumed because the handler takes a request body.
@app.post("/domain/switch")
async def switch_domain(request: SwitchDomainRequest) -> Dict[str, Any]:
    """Switch the active domain."""
    # Responds 400 when the sanitized name is not in DOMAIN_REGISTRY, and 503
    # when a new orchestrator is needed but settings have not been loaded.
    global _active_domain

    domain_name = _sanitize(request.domain)
    if domain_name not in DOMAIN_REGISTRY:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown domain: {domain_name}. Available: {list(DOMAIN_REGISTRY.keys())}",
        )

    async with _lock:
        # Build the orchestrator on first use only; later switches reuse it.
        if domain_name not in _orchestrators:
            if _settings is None:
                raise HTTPException(status_code=503, detail="Settings not initialized")
            _orchestrators[domain_name] = await _initialize_orchestrator(_settings, domain_name)
            logger.info("orchestrator_initialized", domain=domain_name)
        # Flip the active domain while still holding the lock, so the switch
        # only becomes visible once the orchestrator is in place.
        _active_domain = domain_name

    return {"active": domain_name, "message": f"Switched to {domain_name} domain"}
def _get_orchestrator() -> Orchestrator:
    """Return the orchestrator registered for the active domain.

    Raises:
        HTTPException: 503 when no orchestrator is stored under the active
            domain name.
    """
    current = _orchestrators.get(_active_domain)
    if current is not None:
        return current
    raise HTTPException(status_code=503, detail="Orchestrator not initialized")
# Registered as POST /reset, the route given in the module docstring; without
# the decorator the handler is never attached to ``app``.
@app.post("/reset")
async def reset(request: ResetRequest = ResetRequest()) -> Dict[str, Any]:
    """Reset the environment for a new episode."""
    # Responds 503 when no orchestrator is active and 400 when the
    # orchestrator rejects the reset with a ValueError.
    orchestrator = _get_orchestrator()
    # A missing or empty task_id is forwarded as None.
    task_id = _sanitize(request.task_id) if request.task_id else None
    try:
        async with _lock:
            observation = await orchestrator.reset(task_id=task_id)
    except ValueError as exc:
        # Chain the cause so the original traceback is kept for debugging.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return {"observation": observation, "done": False, "info": {"message": "Episode reset"}}
# Registered as POST /step, the route given in the module docstring; without
# the decorator the handler is never attached to ``app``.
@app.post("/step")
async def step(action: Action) -> Dict[str, Any]:
    """Take an action in the environment."""
    # Responds 503 when no orchestrator is active and 400 when the
    # orchestrator raises a RuntimeError.
    orchestrator = _get_orchestrator()

    action_dict: Dict[str, Any] = {"type": action.type.value}
    if action.query:
        action_dict["query"] = _sanitize(action.query)
    if action.answer:
        action_dict["answer"] = _sanitize(action.answer)

    # Merge extra parameters but block overriding type/query/answer.
    # ``or {}`` tolerates a model that allows ``parameters`` to be None.
    reserved_keys = ("type", "query", "answer")
    action_dict.update(
        (key, value)
        for key, value in (action.parameters or {}).items()
        if key not in reserved_keys
    )

    try:
        async with _lock:
            result = await orchestrator.step(action_dict)
    except RuntimeError as exc:
        # Chain the cause so the original traceback is kept for debugging.
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    # Defense-in-depth: ensure reward fields are strictly within [0.01, 0.99].
    # A ``None`` reward is left alone, matching the nested reward fields below,
    # instead of failing on ``float(None)``.
    if result.get("reward") is not None:
        result["reward"] = clamp_score(float(result["reward"]))
    info = result.get("info") or {}
    for reward_key in ("step_reward", "episode_reward"):
        if info.get(reward_key) is not None:
            info[reward_key] = clamp_score(float(info[reward_key]))
    return result
# Registered as GET /state, the route given in the module docstring; without
# the decorator the handler is never attached to ``app``.
@app.get("/state")
async def state() -> Dict[str, Any]:
    """Get the current environment state."""
    # Responds 503 when no orchestrator is registered for the active domain.
    return _get_orchestrator().state()
# Registered as GET /tasks, the route given in the module docstring; without
# the decorator the handler is never attached to ``app``.
@app.get("/tasks")
async def list_tasks() -> Dict[str, Any]:
    """List all available tasks."""
    # Responds 503 when no orchestrator is registered for the active domain.
    orchestrator = _get_orchestrator()
    # Each task is projected onto TaskInfo and dumped to a plain dict so the
    # response carries only the fields named here.
    tasks = [
        TaskInfo(
            task_id=task.task_id,
            name=task.name,
            description=task.description,
            difficulty=task.difficulty.value,
            max_steps=task.max_steps,
        ).model_dump()
        for task in orchestrator.tasks
    ]
    return {"tasks": tasks}
# Registered as POST /grade, the route given in the module docstring; without
# the decorator the handler is never attached to ``app``.
@app.post("/grade")
async def grade(request: GradeRequest = GradeRequest()) -> Dict[str, Any]:
    """Grade the current episode."""
    # Responds 503 when no orchestrator is active and 400 when grading fails.
    orchestrator = _get_orchestrator()
    # Sanitize the task id the same way the reset handler does; a missing or
    # empty value is forwarded as None.
    task_id = _sanitize(request.task_id) if request.task_id else None
    try:
        async with _lock:
            score = await orchestrator.grade(task_id=task_id)
            # Read the state under the same lock so that the reported task and
            # episode ids belong to the episode that was just graded.
            state_data = orchestrator.state()
        return GradeResult(
            task_id=state_data.get("task_id", ""),
            score=clamp_score(float(score)),
            episode_id=state_data.get("episode_id", ""),
        ).model_dump()
    except HTTPException:
        # Keep an explicit HTTP error's own status code instead of
        # rewriting it to 400 below.
        raise
    except Exception as exc:
        # Any other failure is reported to the client as a bad request; the
        # cause is chained so the original traceback is kept.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
# Registered for the base ``Exception`` class; without the decorator the
# handler is never installed on ``app``.
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Global exception handler.

    Logs the error message together with the request path and answers with a
    generic 500 body, so the exception text is not sent to the client.

    Args:
        request: The request whose handling raised the exception.
        exc: The exception that was raised.

    Returns:
        A JSON response with status 500 and a fixed ``detail`` message.
    """
    logger.error("unhandled_exception", error=str(exc), path=request.url.path)
    return JSONResponse(
        status_code=500,
        content={"detail": "Internal server error"},
    )