Spaces:
Sleeping
Sleeping
| """ | |
| ContentGuardEnv — Policy Enforcement Gateway (v1.0) | |
| ================================================== | |
| This server acts as the primary interface between LLM agents and the | |
| ContentGuard moderation environment. It exposes a standardized OpenEnv | |
| OpenAPI/WebSocket interface to facilitate autonomous training and | |
| benchmarking across Meta community standards. | |
| Key Features: | |
| - Real-time policy-trace streaming via WebSockets. | |
| - Dynamic environment resetting for multi-task RLHF. | |
| - Automated grading & reward calculation. | |
| """ | |
| import os | |
| import json | |
| import time | |
| from typing import Dict, Any, Optional | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| import uvicorn | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import HTMLResponse | |
| from pydantic import BaseModel, Field | |
| from openai import AsyncOpenAI | |
| # Internal module imports (Post-Restructure) | |
| from env import ContentGuardEnv | |
# --- Application & module-level configuration -------------------------------
app = FastAPI(
    title="ContentGuardEnv Gateway",
    description="Operational environment for Meta-scale Trust & Safety agent benchmarking.",
    version="1.0.0",
)
# CORS is wide open (any origin/method/header) — appropriate for a public
# demo Space; tighten before any production deployment.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
# Resource Path Configuration
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
STATIC_DIR = os.path.join(BASE_DIR, "static")
# Mount the dashboard assets only if the directory shipped with the build.
if os.path.exists(STATIC_DIR):
    app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
# In-memory episode registry keyed by episode_id; not persisted, so all
# sessions are lost on process restart.
sessions: Dict[str, ContentGuardEnv] = {}
# LLM Inference Client (Defaulting to Hackathon standard endpoints)
DEFAULT_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
# Prefer a Hugging Face token, then an OpenAI key; may be None.
DEFAULT_API_KEY = os.environ.get("HF_TOKEN") or os.environ.get("OPENAI_API_KEY")
# Cache-busting tag appended to static asset URLs; falls back to the
# process start time when no build identifier is exported.
BUILD_TAG = os.environ.get("HF_SPACE_SHA") or os.environ.get("SPACE_BUILD_TAG") or os.environ.get("BUILD_TAG") or str(int(time.time()))
| def _is_placeholder_api_key(api_key: Optional[str]) -> bool: | |
| if not api_key: | |
| return True | |
| lowered = api_key.strip().lower() | |
| return lowered in {"sk-placeholder", "your_api_key", "changeme"} | |
def _resolve_session_client(cfg: Optional[Dict[str, Any]]) -> tuple[Optional[AsyncOpenAI], str]:
    """Build a runtime client from UI config with provider-aware routing rules."""
    if not cfg:
        return aclient, MODEL_NAME

    supplied_key = (cfg.get("api_key") or "").strip()
    if _is_placeholder_api_key(supplied_key):
        # No runtime key provided: use server default if configured, otherwise deterministic grading fallback.
        return aclient, MODEL_NAME

    endpoint = (cfg.get("base_url") or "").strip() or DEFAULT_BASE_URL
    chosen_model = (cfg.get("model") or "").strip() or MODEL_NAME

    # Re-route when the key's provider prefix contradicts the endpoint.
    is_hf_key = supplied_key.startswith("hf_")
    if is_hf_key and "openai.com" in endpoint:
        endpoint = "https://api-inference.huggingface.co/v1"
    elif supplied_key.startswith("sk-") and "huggingface.co" in endpoint:
        endpoint = "https://api.openai.com/v1"

    # HF keys with no explicit model choice default to a hosted Llama.
    if is_hf_key and not (cfg.get("model") or "").strip():
        chosen_model = "meta-llama/Llama-3-70b-instruct"

    return AsyncOpenAI(api_key=supplied_key, base_url=endpoint), chosen_model
| def _build_demo_action(env: ContentGuardEnv) -> Dict[str, Any]: | |
| """Generate a task-valid deterministic action when live inference is unavailable.""" | |
| gt = env.ground_truth or {} | |
| if env.task_id == "easy": | |
| return {"violation": gt.get("violation", "safe")} | |
| if env.task_id == "medium": | |
| return { | |
| "action": gt.get("action", "no_action"), | |
| "severity": int(gt.get("severity", 3)), | |
| "reasoning": "Deterministic demo fallback due unavailable inference credentials.", | |
| } | |
| return { | |
| "ruling": gt.get("ruling", "upheld"), | |
| "policy_references": gt.get("policy_references", []), | |
| "explanation": "Deterministic fallback path used because model inference is unavailable.", | |
| "user_guidance": "Review platform standards and avoid repeating flagged behavior.", | |
| } | |
# Server-wide default inference client. Stays None when no real (non-placeholder)
# key is present in the environment; callers treat None as "no inference
# available" and fall back to deterministic demo grading.
aclient: Optional[AsyncOpenAI] = None
if not _is_placeholder_api_key(DEFAULT_API_KEY):
    aclient = AsyncOpenAI(api_key=DEFAULT_API_KEY.strip(), base_url=DEFAULT_BASE_URL)
class ResetRequest(BaseModel):
    """Request body for environment reset: selects the case difficulty tier."""
    # Accepted values per the field description: "easy" | "medium" | "hard".
    task_id: str = Field(default="easy", description="Difficulty tier: easy | medium | hard")
class StepRequest(BaseModel):
    """Request body for a single environment step (one agent decision)."""
    # Free-form decision payload; the expected keys appear to vary by task
    # tier (see _build_demo_action) — validated by ContentGuardEnv.step.
    action: Dict[str, Any] = Field(..., description="Agent moderation decision package")
async def disable_cache_for_dashboard_assets(request: Request, call_next):
    """Force no-cache headers on the dashboard page and its static assets.

    NOTE(review): no @app.middleware("http") decorator is visible in this
    chunk — confirm this hook is actually registered with the application.
    """
    response = await call_next(request)
    requested = request.url.path
    if requested == "/" or requested.startswith("/static/"):
        no_cache_headers = {
            "Cache-Control": "no-store, no-cache, must-revalidate, max-age=0",
            "Pragma": "no-cache",
            "Expires": "0",
        }
        for header_name, header_value in no_cache_headers.items():
            response.headers[header_name] = header_value
    return response
async def serve_dashboard():
    """Serve the monitoring dashboard HTML with cache-busting asset URLs.

    Returns a plain-text notice when the static bundle was not deployed.
    """
    page_path = os.path.join(STATIC_DIR, "index.html")
    if not os.path.exists(page_path):
        return HTMLResponse("ContentGuard Dashboard: Static assets not found. Check /server/static deployment.")
    with open(page_path, "r", encoding="utf-8") as handle:
        page = handle.read()
    # Version only the first occurrence of each asset reference so browsers
    # refetch after a rebuild.
    for asset in ("/static/style.css", "/static/app.js"):
        page = page.replace(asset, f"{asset}?v={BUILD_TAG}", 1)
    return HTMLResponse(content=page)
async def reset_environment(req: ResetRequest = ResetRequest()):
    """Spin up a fresh moderation episode and return its first observation.

    The new environment is registered in the in-memory session table under
    its episode id so later step/state calls can locate it.
    """
    episode = ContentGuardEnv()
    first_observation = episode.reset(task_id=req.task_id)
    sessions[episode.episode_id] = episode
    return first_observation
async def process_step(episode_id: str, req: StepRequest):
    """Apply one agent decision to an active episode and return the graded result.

    Raises HTTPException 404 for an unknown/expired episode and 400 when the
    environment rejects the submitted action.
    """
    active_env = sessions.get(episode_id)
    if not active_env:
        raise HTTPException(404, f"Session '{episode_id}' not active or expired.")
    try:
        # REST path always uses the server-default client/model (session
        # overrides are a WebSocket-only feature).
        return await active_env.step(req.action, client=aclient, model=MODEL_NAME)
    except RuntimeError as exc:
        raise HTTPException(400, f"Policy Engine Conflict: {str(exc)}")
async def get_env_state(episode_id: str):
    """Expose the full internal state of a live moderation episode."""
    tracked = sessions.get(episode_id)
    if not tracked:
        raise HTTPException(404, f"Session '{episode_id}' not found.")
    return tracked.state()
async def check_health():
    """Liveness probe: reports engine identity and active session count."""
    return {
        "status": "operational",
        "active_sessions": len(sessions),
        "engine": "ContentGuardEnv",
    }
async def policy_trace_socket(websocket: WebSocket):
    """Streams real-time reasoning traces and environment telemetry.

    JSON command protocol over the socket:
      - {"action": "reset", "task_id": "..."}  -> start a new episode
      - {"action": "step", "data": {...}}      -> grade one agent action
      - {"action": "run_agent"}                -> let the LLM play the episode
      - {"action": "state"}                    -> dump internal episode state
    Any message may also carry a "config" object (api_key/base_url/model)
    that swaps the inference client for this connection only.

    NOTE(review): no @app.websocket(...) decorator is visible in this chunk;
    confirm the route is registered with the application elsewhere.
    """
    await websocket.accept()
    # Per-connection state: the current episode plus the client/model used
    # for inference (defaults to the server-wide client).
    env: ContentGuardEnv | None = None
    session_client: AsyncOpenAI | None = aclient
    session_model: str = MODEL_NAME
    try:
        while True:
            raw = await websocket.receive_text()
            try:
                msg = json.loads(raw)
            except json.JSONDecodeError:
                await websocket.send_json({"type": "error", "message": "Malformed WebSocket packet."})
                continue
            cmd = msg.get("action")
            # Universal Credential Injector (Session-based): any packet may
            # re-point this connection at different credentials/model.
            if "config" in msg:
                session_client, session_model = _resolve_session_client(msg.get("config"))
            if cmd == "reset":
                env = ContentGuardEnv()
                try:
                    obs = env.reset(task_id=msg.get("task_id", "easy"))
                    # Explicitly dump Pydantic model for WebSocket JSON serialization
                    await websocket.send_json({"type": "reset", "observation": obs.model_dump()})
                except ValueError as e:
                    # Surface reset failures (presumably an invalid task_id) to the client.
                    await websocket.send_json({"type": "error", "message": str(e)})
            elif cmd == "step":
                if not env:
                    await websocket.send_json({"type": "error", "message": "State conflict: Submit reset before step."})
                    continue
                try:
                    result = await env.step(msg.get("data", {}), client=session_client, model=session_model)
                    await websocket.send_json({"type": "step", "result": result})
                except RuntimeError as e:
                    await websocket.send_json({"type": "error", "message": f"Execution halted: {str(e)}"})
            elif cmd == "run_agent":
                if not env:
                    await websocket.send_json({"type": "error", "message": "Session inactive."})
                    continue
                if env.done:
                    await websocket.send_json({"type": "error", "message": "Episode finished. Call reset() for a new case."})
                    continue
                try:
                    await websocket.send_json({"type": "stream", "content": f"[START] ep={env.episode_id} task={env.task_id}\n"})
                    if session_client is None:
                        # No credentials at all: jump straight to the demo
                        # fallback handled by the except block below.
                        raise RuntimeError("No API credentials configured.")
                    sys_prompt = "Expert Safety Moderator. Respond with JSON only. Strictly align with platform policies."
                    user_prompt = f"Policy Task: {env._task_config['description']}\n\nEvidence:\n{json.dumps(env.case)}\n\nSubmit ruling in JSON."
                    # Stream tokens so the dashboard can render the trace live.
                    stream = await session_client.chat.completions.create(
                        model=session_model,
                        messages=[{"role": "system", "content": sys_prompt}, {"role": "user", "content": user_prompt}],
                        temperature=0.0,
                        stream=True
                    )
                    full_response = ""
                    async for chunk in stream:
                        content = chunk.choices[0].delta.content
                        if content:
                            full_response += content
                            await websocket.send_json({"type": "stream", "content": content})
                    if not full_response.strip():
                        raise ValueError("Model returned an empty response.")
                    # Clean/Parse Output: strip a ```json ... ``` markdown fence if present.
                    js_str = full_response.strip()
                    if js_str.startswith("```"):
                        js_str = js_str.split("```")[1]
                        if js_str.startswith("json"): js_str = js_str[4:]
                    action = json.loads(js_str.strip())
                    await websocket.send_json({"type": "stream", "content": f"\n\n[STEP] Policy Ingested: {json.dumps(action)}\n"})
                    result = await env.step(action, client=session_client, model=session_model)
                    await websocket.send_json({"type": "step", "result": result})
                    await websocket.send_json({"type": "stream", "content": f"[END] Result: Success. Reward: {result['reward']:.4f}\n"})
                except Exception as e:
                    # Any failure on the live path (auth, network, unparsable
                    # JSON) degrades to a deterministic demo ruling instead of
                    # killing the session.
                    err_text = str(e)
                    lowered_err = err_text.lower()
                    # Mask raw provider auth errors so key details never reach the client.
                    if "invalid_api_key" in lowered_err or "incorrect api key" in lowered_err or "api key" in lowered_err or "401" in lowered_err:
                        err_text = "Authentication failed for the configured provider."
                    await websocket.send_json({"type": "stream", "content": f"\n\n[NOTICE] Inference Unavailable: {err_text}\nInitiating Passive Grader demo...\n"})
                    if env.done:
                        # The failed live attempt may already have consumed the episode.
                        await websocket.send_json({"type": "error", "message": "Episode finished. Call reset() for a new case."})
                        continue
                    sim_action = _build_demo_action(env)
                    try:
                        result = await env.step(sim_action, client=None, model=session_model)
                        await websocket.send_json({"type": "step", "result": result})
                        await websocket.send_json({"type": "stream", "content": f"\n[DEMO] Passive Ruling Emitted. Final Reward: {result['reward']:.4f}\n"})
                    except RuntimeError as step_error:
                        await websocket.send_json({"type": "error", "message": str(step_error)})
            elif cmd == "state":
                if env: await websocket.send_json({"type": "state", "state": env.state()})
    except WebSocketDisconnect:
        # Client disconnected; per-connection state is simply dropped.
        pass
def main():
    """Boot the ASGI gateway on the standard Spaces port (7860)."""
    print("🚀 Initializing ContentGuard Policy Gateway...")
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,
        reload=False,
    )


if __name__ == "__main__":
    main()