williyam's picture
fix: clamp scores to [0.01, 0.99] in API boundary, UI, inference, and tests
a8780b7
"""
FastAPI server implementing the OpenEnv specification for Agentic RAG Gym.
Endpoints:
- POST /reset → Reset environment, start new episode
- POST /step → Take an action in the environment
- GET /state → Get current environment state
- GET /tasks → List available tasks
- POST /grade → Grade current episode
- GET /health → Health check
"""
from __future__ import annotations
import asyncio
import re
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, AsyncGenerator, Dict, Optional
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from rag_master.agents import CriticAgent, PlannerAgent, ReasonerAgent, RetrieverAgent, VerifierAgent
from rag_master.config import AppSettings, get_settings
from rag_master.llm_client import OpenAICompatibleClient
from rag_master.logging_config import get_logger, setup_logging
from rag_master.models import AgentRole
from rag_master.orchestrator import Orchestrator
from rag_master.retriever import FAISSRetriever
from rag_master.rewards import CompositeRewardFunction, clamp_score
from domains.aerospace.config import AerospaceDomainConfig
from domains.legal_research.config import LegalResearchDomainConfig
from server.models import (
Action,
GradeResult,
HealthStatus,
Observation,
ResetResult,
StateResult,
StepResult,
TaskInfo,
)
logger = get_logger(__name__)
# --- Domain registry ---
DOMAIN_REGISTRY = {
"aerospace": AerospaceDomainConfig,
"legal_research": LegalResearchDomainConfig,
}
# --- Global state ---
_orchestrators: Dict[str, Orchestrator] = {}
_active_domain: str = "aerospace"
_settings: Optional[AppSettings] = None
_lock = asyncio.Lock()
# Regex pattern for sanitization: strip HTML tags and control chars
_SANITIZE_RE = re.compile(r'<[^>]+>|[\x00-\x08\x0b\x0c\x0e-\x1f]')
class ResetRequest(BaseModel):
"""Request body for /reset endpoint."""
task_id: Optional[str] = Field(default=None, description="Task ID to load.")
class GradeRequest(BaseModel):
"""Request body for /grade endpoint."""
task_id: Optional[str] = Field(default=None, description="Task ID to grade.")
def _sanitize(text: str) -> str:
"""Sanitize user input: strip HTML tags and control characters."""
return _SANITIZE_RE.sub('', text).strip()
async def _initialize_orchestrator(settings: AppSettings, domain_name: str = "aerospace") -> Orchestrator:
"""Initialize the orchestrator with all components for a given domain."""
domain_cls = DOMAIN_REGISTRY.get(domain_name)
if domain_cls is None:
raise ValueError(f"Unknown domain: {domain_name}. Available: {list(DOMAIN_REGISTRY.keys())}")
domain = domain_cls()
# Each domain gets its own subdirectory so they never share embeddings.
domain_index_dir = settings.faiss_index_dir / domain_name
retriever = FAISSRetriever(
index_dir=domain_index_dir,
embedding_model=settings.embedding_model,
dimension=settings.faiss_dimension,
)
llm = OpenAICompatibleClient(
base_url=settings.api_base_url,
api_key=settings.llm_api_key,
model_name=settings.model_name,
)
reward_fn = CompositeRewardFunction(llm_client=llm)
agents = {
AgentRole.RETRIEVER: RetrieverAgent(retriever=retriever, llm=llm),
AgentRole.REASONER: ReasonerAgent(llm=llm, system_prompt=domain.get_system_prompt()),
AgentRole.CRITIC: CriticAgent(llm=llm),
AgentRole.PLANNER: PlannerAgent(llm=llm),
AgentRole.VERIFIER: VerifierAgent(llm=llm),
}
orchestrator = Orchestrator(
domain_config=domain,
retriever=retriever,
llm_client=llm,
reward_function=reward_fn,
agents=agents,
)
docs = domain.get_documents()
count = await retriever.index_documents(docs)
logger.info("knowledge_base_indexed", document_count=count)
return orchestrator
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Application lifespan handler."""
global _orchestrators, _active_domain, _settings
_settings = get_settings()
setup_logging(_settings.log_level.value)
logger.info("server_starting", host=_settings.server_host, port=_settings.server_port)
# Initialize default domain (aerospace)
_active_domain = "aerospace"
_orchestrators["aerospace"] = await _initialize_orchestrator(_settings, "aerospace")
logger.info("orchestrator_initialized", domain="aerospace")
yield
logger.info("server_shutting_down")
app = FastAPI(
title="Agentic RAG Gym",
description="A reinforcement learning environment for training AI agents on Retrieval-Augmented Generation tasks across multiple domains.",
version="1.0.0",
lifespan=lifespan,
)
# --- Prometheus metrics ---
try:
from prometheus_fastapi_instrumentator import Instrumentator
Instrumentator(
should_group_status_codes=True,
should_ignore_untemplated=True,
excluded_handlers=["/metrics"],
).instrument(app).expose(app, endpoint="/metrics", include_in_schema=False)
except ImportError:
logger.warning("prometheus_not_available", msg="Install prometheus_fastapi_instrumentator for /metrics")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=False,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/health", response_model=HealthStatus)
async def health() -> HealthStatus:
"""Health check endpoint."""
return HealthStatus()
@app.get("/domains")
async def list_domains() -> Dict[str, Any]:
"""List all available domains and the currently active one."""
return {
"domains": list(DOMAIN_REGISTRY.keys()),
"active": _active_domain,
}
class SwitchDomainRequest(BaseModel):
"""Request body for /domain/switch endpoint."""
domain: str = Field(description="Domain name to switch to.")
@app.post("/domain/switch")
async def switch_domain(request: SwitchDomainRequest) -> Dict[str, Any]:
"""Switch the active domain."""
global _active_domain
domain_name = _sanitize(request.domain)
if domain_name not in DOMAIN_REGISTRY:
raise HTTPException(
status_code=400,
detail=f"Unknown domain: {domain_name}. Available: {list(DOMAIN_REGISTRY.keys())}",
)
async with _lock:
if domain_name not in _orchestrators:
if _settings is None:
raise HTTPException(status_code=503, detail="Settings not initialized")
_orchestrators[domain_name] = await _initialize_orchestrator(_settings, domain_name)
logger.info("orchestrator_initialized", domain=domain_name)
_active_domain = domain_name
return {"active": _active_domain, "message": f"Switched to {domain_name} domain"}
def _get_orchestrator() -> Orchestrator:
"""Get the currently active orchestrator."""
orch = _orchestrators.get(_active_domain)
if orch is None:
raise HTTPException(status_code=503, detail="Orchestrator not initialized")
return orch
@app.post("/reset")
async def reset(request: ResetRequest = ResetRequest()) -> Dict[str, Any]:
"""Reset the environment for a new episode."""
orchestrator = _get_orchestrator()
task_id = _sanitize(request.task_id) if request.task_id else None
try:
async with _lock:
observation = await orchestrator.reset(task_id=task_id)
return {"observation": observation, "done": False, "info": {"message": "Episode reset"}}
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc))
@app.post("/step")
async def step(action: Action) -> Dict[str, Any]:
"""Take an action in the environment."""
orchestrator = _get_orchestrator()
action_dict: Dict[str, Any] = {"type": action.type.value}
if action.query:
action_dict["query"] = _sanitize(action.query)
if action.answer:
action_dict["answer"] = _sanitize(action.answer)
# Merge extra parameters but block overriding type/query/answer
for k, v in action.parameters.items():
if k not in ("type", "query", "answer"):
action_dict[k] = v
try:
async with _lock:
result = await orchestrator.step(action_dict)
# Defense-in-depth: ensure reward fields are strictly within [0.01, 0.99].
if "reward" in result:
result["reward"] = clamp_score(float(result["reward"]))
info = result.get("info") or {}
if "step_reward" in info and info["step_reward"] is not None:
info["step_reward"] = clamp_score(float(info["step_reward"]))
if "episode_reward" in info and info["episode_reward"] is not None:
info["episode_reward"] = clamp_score(float(info["episode_reward"]))
return result
except RuntimeError as exc:
raise HTTPException(status_code=400, detail=str(exc))
@app.get("/state")
async def state() -> Dict[str, Any]:
"""Get the current environment state."""
orchestrator = _get_orchestrator()
return orchestrator.state()
@app.get("/tasks")
async def list_tasks() -> Dict[str, Any]:
"""List all available tasks."""
orchestrator = _get_orchestrator()
tasks = [
TaskInfo(
task_id=t.task_id,
name=t.name,
description=t.description,
difficulty=t.difficulty.value,
max_steps=t.max_steps,
).model_dump()
for t in orchestrator.tasks
]
return {"tasks": tasks}
@app.post("/grade")
async def grade(request: GradeRequest = GradeRequest()) -> Dict[str, Any]:
"""Grade the current episode."""
orchestrator = _get_orchestrator()
try:
async with _lock:
score = await orchestrator.grade(task_id=request.task_id)
state_data = orchestrator.state()
return GradeResult(
task_id=state_data.get("task_id", ""),
score=clamp_score(float(score)),
episode_id=state_data.get("episode_id", ""),
).model_dump()
except Exception as exc:
raise HTTPException(status_code=400, detail=str(exc))
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception) -> JSONResponse:
"""Global exception handler."""
logger.error("unhandled_exception", error=str(exc), path=request.url.path)
return JSONResponse(
status_code=500,
content={"detail": "Internal server error"},
)