Upload folder using huggingface_hub
Browse files- __init__.py +15 -15
- client.py +39 -39
- models.py +18 -18
- openenv.yaml +74 -48
- server/__init__.py +2 -2
- server/app.py +148 -148
- server/env.py +172 -172
- server/models.py +67 -67
- server/rag/__init__.py +2 -2
- server/rag/retriever.py +97 -97
- server/server_routes.py +1 -1
__init__.py
CHANGED
|
@@ -1,16 +1,16 @@
|
|
| 1 |
-
from client import ModerationEnv, ModerationEnvAction, ModerationEnvObservation, ModerationEnvState
|
| 2 |
-
from models import Action, ActionType, Content, Observation, PolicyChunk, State, StepType
|
| 3 |
-
|
| 4 |
-
__all__ = [
|
| 5 |
-
"ModerationEnv",
|
| 6 |
-
"ModerationEnvAction",
|
| 7 |
-
"ModerationEnvObservation",
|
| 8 |
-
"ModerationEnvState",
|
| 9 |
-
"Action",
|
| 10 |
-
"ActionType",
|
| 11 |
-
"Content",
|
| 12 |
-
"Observation",
|
| 13 |
-
"PolicyChunk",
|
| 14 |
-
"State",
|
| 15 |
-
"StepType",
|
| 16 |
]
|
|
|
|
| 1 |
+
from client import ModerationEnv, ModerationEnvAction, ModerationEnvObservation, ModerationEnvState
|
| 2 |
+
from models import Action, ActionType, Content, Observation, PolicyChunk, State, StepType
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"ModerationEnv",
|
| 6 |
+
"ModerationEnvAction",
|
| 7 |
+
"ModerationEnvObservation",
|
| 8 |
+
"ModerationEnvState",
|
| 9 |
+
"Action",
|
| 10 |
+
"ActionType",
|
| 11 |
+
"Content",
|
| 12 |
+
"Observation",
|
| 13 |
+
"PolicyChunk",
|
| 14 |
+
"State",
|
| 15 |
+
"StepType",
|
| 16 |
]
|
client.py
CHANGED
|
@@ -1,40 +1,40 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from typing import Any
|
| 4 |
-
|
| 5 |
-
from openenv.core import EnvClient
|
| 6 |
-
from openenv.core.client_types import StepResult
|
| 7 |
-
|
| 8 |
-
try:
|
| 9 |
-
from .models import Action, Observation, State
|
| 10 |
-
except ImportError:
|
| 11 |
-
from models import Action, Observation, State
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
class ModerationEnv(EnvClient[Action, Observation, State]):
|
| 15 |
-
|
| 16 |
-
def _step_payload(self, action: Action) -> dict[str, Any]:
|
| 17 |
-
return action.model_dump(mode="json")
|
| 18 |
-
|
| 19 |
-
def _parse_result(self, payload: dict[str, Any]) -> StepResult[Observation]:
|
| 20 |
-
observation_payload = payload.get("observation", {})
|
| 21 |
-
return StepResult(
|
| 22 |
-
observation=Observation(**observation_payload),
|
| 23 |
-
reward=payload.get("reward"),
|
| 24 |
-
done=bool(payload.get("done", False)),
|
| 25 |
-
)
|
| 26 |
-
|
| 27 |
-
def _parse_state(self, payload: dict[str, Any]) -> State:
|
| 28 |
-
return State(**payload)
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
ModerationEnvAction = Action
|
| 32 |
-
ModerationEnvObservation = Observation
|
| 33 |
-
ModerationEnvState = State
|
| 34 |
-
|
| 35 |
-
__all__ = [
|
| 36 |
-
"ModerationEnv",
|
| 37 |
-
"ModerationEnvAction",
|
| 38 |
-
"ModerationEnvObservation",
|
| 39 |
-
"ModerationEnvState",
|
| 40 |
]
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
from openenv.core import EnvClient
|
| 6 |
+
from openenv.core.client_types import StepResult
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
from .models import Action, Observation, State
|
| 10 |
+
except ImportError:
|
| 11 |
+
from models import Action, Observation, State
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ModerationEnv(EnvClient[Action, Observation, State]):
|
| 15 |
+
|
| 16 |
+
def _step_payload(self, action: Action) -> dict[str, Any]:
|
| 17 |
+
return action.model_dump(mode="json")
|
| 18 |
+
|
| 19 |
+
def _parse_result(self, payload: dict[str, Any]) -> StepResult[Observation]:
|
| 20 |
+
observation_payload = payload.get("observation", {})
|
| 21 |
+
return StepResult(
|
| 22 |
+
observation=Observation(**observation_payload),
|
| 23 |
+
reward=payload.get("reward"),
|
| 24 |
+
done=bool(payload.get("done", False)),
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
def _parse_state(self, payload: dict[str, Any]) -> State:
|
| 28 |
+
return State(**payload)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
ModerationEnvAction = Action
|
| 32 |
+
ModerationEnvObservation = Observation
|
| 33 |
+
ModerationEnvState = State
|
| 34 |
+
|
| 35 |
+
__all__ = [
|
| 36 |
+
"ModerationEnv",
|
| 37 |
+
"ModerationEnvAction",
|
| 38 |
+
"ModerationEnvObservation",
|
| 39 |
+
"ModerationEnvState",
|
| 40 |
]
|
models.py
CHANGED
|
@@ -1,19 +1,19 @@
|
|
| 1 |
-
from server.models import (
|
| 2 |
-
Action,
|
| 3 |
-
ActionType,
|
| 4 |
-
Content,
|
| 5 |
-
Observation,
|
| 6 |
-
PolicyChunk,
|
| 7 |
-
State,
|
| 8 |
-
StepType,
|
| 9 |
-
)
|
| 10 |
-
|
| 11 |
-
__all__ = [
|
| 12 |
-
"Action",
|
| 13 |
-
"ActionType",
|
| 14 |
-
"Content",
|
| 15 |
-
"Observation",
|
| 16 |
-
"PolicyChunk",
|
| 17 |
-
"State",
|
| 18 |
-
"StepType",
|
| 19 |
]
|
|
|
|
| 1 |
+
from server.models import (
|
| 2 |
+
Action,
|
| 3 |
+
ActionType,
|
| 4 |
+
Content,
|
| 5 |
+
Observation,
|
| 6 |
+
PolicyChunk,
|
| 7 |
+
State,
|
| 8 |
+
StepType,
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
__all__ = [
|
| 12 |
+
"Action",
|
| 13 |
+
"ActionType",
|
| 14 |
+
"Content",
|
| 15 |
+
"Observation",
|
| 16 |
+
"PolicyChunk",
|
| 17 |
+
"State",
|
| 18 |
+
"StepType",
|
| 19 |
]
|
openenv.yaml
CHANGED
|
@@ -1,62 +1,88 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
name: "openenv-multimodal-moderation"
|
| 4 |
-
type: "environment"
|
| 5 |
-
runtime: "docker"
|
| 6 |
-
app: "server.app:app"
|
| 7 |
-
port: 8000
|
| 8 |
-
|
| 9 |
description: >
|
| 10 |
OpenEnv-compliant multimodal content moderation environment with staged
|
| 11 |
analysis, policy retrieval, moderation, reviewer validation, and dense reward.
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
tasks:
|
| 14 |
-
-
|
| 15 |
-
|
| 16 |
-
description: "Agent must identify violent content and issue a remove action."
|
| 17 |
-
difficulty: "hard" # Explicit violence detection + removal is most complex
|
| 18 |
grader:
|
| 19 |
-
type:
|
| 20 |
threshold: 0.5
|
| 21 |
-
endpoint:
|
| 22 |
-
reward_field:
|
| 23 |
-
|
| 24 |
-
-
|
| 25 |
-
|
| 26 |
-
description: "Agent must identify nudity/sexual content and issue a flag action."
|
| 27 |
-
difficulty: "medium" # Visual content analysis required
|
| 28 |
grader:
|
| 29 |
-
type:
|
| 30 |
threshold: 0.5
|
| 31 |
-
endpoint:
|
| 32 |
-
reward_field:
|
| 33 |
-
|
| 34 |
-
-
|
| 35 |
-
|
| 36 |
-
description: "Agent must identify harassment and issue a flag action."
|
| 37 |
-
difficulty: "medium" # Context/pattern recognition needed
|
| 38 |
grader:
|
| 39 |
-
type:
|
| 40 |
threshold: 0.5
|
| 41 |
-
endpoint:
|
| 42 |
-
reward_field:
|
| 43 |
-
|
| 44 |
-
-
|
| 45 |
-
|
| 46 |
-
description: "Agent must identify misinformation and escalate for review."
|
| 47 |
-
difficulty: "hard" # Requires fact-checking nuance + escalation judgment
|
| 48 |
grader:
|
| 49 |
-
type:
|
| 50 |
threshold: 0.5
|
| 51 |
-
endpoint:
|
| 52 |
-
reward_field:
|
| 53 |
-
|
| 54 |
-
-
|
| 55 |
-
|
| 56 |
-
description: "Agent must correctly allow benign/safe content."
|
| 57 |
-
difficulty: "easy" # Baseline safe content recognition
|
| 58 |
grader:
|
| 59 |
-
type:
|
| 60 |
threshold: 0.5
|
| 61 |
-
endpoint:
|
| 62 |
-
reward_field:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: openenv-multimodal-moderation
|
| 2 |
+
version: "1.0.0"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
description: >
|
| 4 |
OpenEnv-compliant multimodal content moderation environment with staged
|
| 5 |
analysis, policy retrieval, moderation, reviewer validation, and dense reward.
|
| 6 |
+
framework: meta-pytorch-openenv
|
| 7 |
+
license: Apache-2.0
|
| 8 |
+
entrypoint: server:app
|
| 9 |
+
episode_flow:
|
| 10 |
+
- analyze
|
| 11 |
+
- retrieve_policy
|
| 12 |
+
- decide
|
| 13 |
+
- review
|
| 14 |
+
- finalize
|
| 15 |
+
models:
|
| 16 |
+
action:
|
| 17 |
+
fields:
|
| 18 |
+
action_type: allow|flag|remove|escalate
|
| 19 |
+
reason: string
|
| 20 |
+
observation:
|
| 21 |
+
fields:
|
| 22 |
+
content: text + image metadata
|
| 23 |
+
policy: retrieved policy chunks
|
| 24 |
+
step_type: analyze|retrieve_policy|decide|review|finalize
|
| 25 |
+
step_count: integer
|
| 26 |
+
state:
|
| 27 |
+
fields:
|
| 28 |
+
episode_id: string
|
| 29 |
+
step_count: integer
|
| 30 |
+
done: boolean
|
| 31 |
+
policy_retrieval:
|
| 32 |
+
source: server/rag/policies.json
|
| 33 |
+
strategy: keyword_overlap
|
| 34 |
+
top_k: 3
|
| 35 |
tasks:
|
| 36 |
+
- name: violence-remove
|
| 37 |
+
description: Agent must identify violent content and issue a remove action
|
|
|
|
|
|
|
| 38 |
grader:
|
| 39 |
+
type: reward_threshold
|
| 40 |
threshold: 0.5
|
| 41 |
+
endpoint: /episode_summary
|
| 42 |
+
reward_field: total_reward
|
| 43 |
+
|
| 44 |
+
- name: nudity-flag
|
| 45 |
+
description: Agent must identify nudity/sexual content and issue a flag action
|
|
|
|
|
|
|
| 46 |
grader:
|
| 47 |
+
type: reward_threshold
|
| 48 |
threshold: 0.5
|
| 49 |
+
endpoint: /episode_summary
|
| 50 |
+
reward_field: total_reward
|
| 51 |
+
|
| 52 |
+
- name: harassment-flag
|
| 53 |
+
description: Agent must identify harassment and issue a flag action
|
|
|
|
|
|
|
| 54 |
grader:
|
| 55 |
+
type: reward_threshold
|
| 56 |
threshold: 0.5
|
| 57 |
+
endpoint: /episode_summary
|
| 58 |
+
reward_field: total_reward
|
| 59 |
+
|
| 60 |
+
- name: misinfo-escalate
|
| 61 |
+
description: Agent must identify misinformation and escalate for review
|
|
|
|
|
|
|
| 62 |
grader:
|
| 63 |
+
type: reward_threshold
|
| 64 |
threshold: 0.5
|
| 65 |
+
endpoint: /episode_summary
|
| 66 |
+
reward_field: total_reward
|
| 67 |
+
|
| 68 |
+
- name: benign-allow
|
| 69 |
+
description: Agent must correctly allow benign/safe content
|
|
|
|
|
|
|
| 70 |
grader:
|
| 71 |
+
type: reward_threshold
|
| 72 |
threshold: 0.5
|
| 73 |
+
endpoint: /episode_summary
|
| 74 |
+
reward_field: total_reward
|
| 75 |
+
rewards:
|
| 76 |
+
analysis_step: 0.2
|
| 77 |
+
retrieval_step: 0.2
|
| 78 |
+
correct_decision: 1.0
|
| 79 |
+
reviewer_agreement: 0.2
|
| 80 |
+
unsafe_penalty: -0.6
|
| 81 |
+
server:
|
| 82 |
+
reset: POST /reset
|
| 83 |
+
step: POST /step
|
| 84 |
+
state: GET /state
|
| 85 |
+
state_full: GET /state_full
|
| 86 |
+
episode_summary: GET /episode_summary
|
| 87 |
+
schema: GET /schema
|
| 88 |
+
docs: GET /docs
|
server/__init__.py
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
-
from .app import app
|
| 2 |
-
|
| 3 |
__all__ = ["app"]
|
|
|
|
| 1 |
+
from .app import app
|
| 2 |
+
|
| 3 |
__all__ = ["app"]
|
server/app.py
CHANGED
|
@@ -1,148 +1,148 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import traceback
|
| 4 |
-
from typing import Optional
|
| 5 |
-
|
| 6 |
-
from fastapi import FastAPI, HTTPException
|
| 7 |
-
from fastapi.responses import JSONResponse
|
| 8 |
-
from pydantic import BaseModel
|
| 9 |
-
|
| 10 |
-
try:
|
| 11 |
-
from .models import Action, Observation, State
|
| 12 |
-
from .env import ModerationEnvironment
|
| 13 |
-
from .logic import CASE_IDS
|
| 14 |
-
except ImportError:
|
| 15 |
-
from models import Action, Observation, State
|
| 16 |
-
from env import ModerationEnvironment
|
| 17 |
-
from logic import CASE_IDS
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
# ---------------------------------------------------------------------------
|
| 21 |
-
# Single persistent environment — shared across ALL HTTP requests
|
| 22 |
-
# ---------------------------------------------------------------------------
|
| 23 |
-
_env = ModerationEnvironment()
|
| 24 |
-
|
| 25 |
-
app = FastAPI(
|
| 26 |
-
title="OpenEnv Multimodal Moderation",
|
| 27 |
-
description="Multimodal content moderation RL environment",
|
| 28 |
-
version="1.0.0",
|
| 29 |
-
)
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
# ---------------------------------------------------------------------------
|
| 33 |
-
# Request schemas
|
| 34 |
-
# ---------------------------------------------------------------------------
|
| 35 |
-
|
| 36 |
-
class ResetOptions(BaseModel):
|
| 37 |
-
case_id: Optional[str] = None
|
| 38 |
-
seed: Optional[int] = None
|
| 39 |
-
episode_id: Optional[str] = None
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
class ResetRequest(BaseModel):
|
| 43 |
-
options: Optional[ResetOptions] = None
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
class StepRequest(BaseModel):
|
| 47 |
-
action: Action
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
# ---------------------------------------------------------------------------
|
| 51 |
-
# Core OpenEnv endpoints
|
| 52 |
-
# ---------------------------------------------------------------------------
|
| 53 |
-
|
| 54 |
-
@app.post("/reset")
|
| 55 |
-
async def reset(req: Optional[ResetRequest] = None) -> JSONResponse:
|
| 56 |
-
try:
|
| 57 |
-
opts = (req.options if req and req.options else None) or ResetOptions()
|
| 58 |
-
obs: Observation = _env.reset(
|
| 59 |
-
seed=opts.seed,
|
| 60 |
-
episode_id=opts.episode_id,
|
| 61 |
-
case_id=opts.case_id or "",
|
| 62 |
-
)
|
| 63 |
-
return JSONResponse({
|
| 64 |
-
"observation": obs.model_dump(mode="json"),
|
| 65 |
-
"reward": 0.0,
|
| 66 |
-
"done": False,
|
| 67 |
-
})
|
| 68 |
-
except Exception as e:
|
| 69 |
-
traceback.print_exc()
|
| 70 |
-
raise HTTPException(status_code=500, detail=str(e))
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
@app.post("/step")
|
| 74 |
-
async def step(req: StepRequest) -> JSONResponse:
|
| 75 |
-
try:
|
| 76 |
-
obs: Observation = _env.step(req.action)
|
| 77 |
-
return JSONResponse({
|
| 78 |
-
"observation": obs.model_dump(mode="json"),
|
| 79 |
-
"reward": obs.reward,
|
| 80 |
-
"done": obs.done,
|
| 81 |
-
})
|
| 82 |
-
except RuntimeError as e:
|
| 83 |
-
raise HTTPException(status_code=400, detail=str(e))
|
| 84 |
-
except Exception as e:
|
| 85 |
-
traceback.print_exc()
|
| 86 |
-
raise HTTPException(status_code=500, detail=str(e))
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
@app.get("/state")
|
| 90 |
-
async def get_state() -> JSONResponse:
|
| 91 |
-
return JSONResponse(_env.state.model_dump(mode="json"))
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
@app.get("/schema")
|
| 95 |
-
async def schema() -> JSONResponse:
|
| 96 |
-
return JSONResponse({
|
| 97 |
-
"action": Action.model_json_schema(),
|
| 98 |
-
"observation": Observation.model_json_schema(),
|
| 99 |
-
"state": State.model_json_schema(),
|
| 100 |
-
})
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
# ---------------------------------------------------------------------------
|
| 104 |
-
# /episode_summary — read by the reward_threshold graders in openenv.yaml
|
| 105 |
-
# ---------------------------------------------------------------------------
|
| 106 |
-
|
| 107 |
-
@app.get("/episode_summary")
|
| 108 |
-
async def episode_summary() -> JSONResponse:
|
| 109 |
-
state = _env.state
|
| 110 |
-
breakdown = dict(state.reward_breakdown or {})
|
| 111 |
-
total_reward = round(sum(breakdown.values()), 4)
|
| 112 |
-
return JSONResponse({
|
| 113 |
-
"episode_id": state.episode_id,
|
| 114 |
-
"step_count": state.step_count,
|
| 115 |
-
"done": state.done,
|
| 116 |
-
"total_reward": total_reward,
|
| 117 |
-
"reward_breakdown": breakdown,
|
| 118 |
-
"final_action": state.final_action,
|
| 119 |
-
"reviewer_note": state.reviewer_note,
|
| 120 |
-
})
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
# ---------------------------------------------------------------------------
|
| 124 |
-
# Helper endpoints
|
| 125 |
-
# ---------------------------------------------------------------------------
|
| 126 |
-
|
| 127 |
-
@app.get("/cases")
|
| 128 |
-
async def list_cases() -> JSONResponse:
|
| 129 |
-
return JSONResponse({"cases": CASE_IDS})
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
@app.get("/state_full")
|
| 133 |
-
async def state_full() -> JSONResponse:
|
| 134 |
-
return JSONResponse(_env.state.model_dump(mode="json"))
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
@app.get("/health")
|
| 138 |
-
async def health() -> JSONResponse:
|
| 139 |
-
return JSONResponse({"status": "ok"})
|
| 140 |
-
|
| 141 |
-
def main(host: str = "0.0.0.0", port: int = 8000) -> None:
|
| 142 |
-
import uvicorn
|
| 143 |
-
|
| 144 |
-
uvicorn.run(app, host=host, port=port)
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
if __name__ == "__main__":
|
| 148 |
-
main()
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import traceback
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
from fastapi import FastAPI, HTTPException
|
| 7 |
+
from fastapi.responses import JSONResponse
|
| 8 |
+
from pydantic import BaseModel
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
from .models import Action, Observation, State
|
| 12 |
+
from .env import ModerationEnvironment
|
| 13 |
+
from .logic import CASE_IDS
|
| 14 |
+
except ImportError:
|
| 15 |
+
from models import Action, Observation, State
|
| 16 |
+
from env import ModerationEnvironment
|
| 17 |
+
from logic import CASE_IDS
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# ---------------------------------------------------------------------------
|
| 21 |
+
# Single persistent environment — shared across ALL HTTP requests
|
| 22 |
+
# ---------------------------------------------------------------------------
|
| 23 |
+
_env = ModerationEnvironment()
|
| 24 |
+
|
| 25 |
+
app = FastAPI(
|
| 26 |
+
title="OpenEnv Multimodal Moderation",
|
| 27 |
+
description="Multimodal content moderation RL environment",
|
| 28 |
+
version="1.0.0",
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# ---------------------------------------------------------------------------
|
| 33 |
+
# Request schemas
|
| 34 |
+
# ---------------------------------------------------------------------------
|
| 35 |
+
|
| 36 |
+
class ResetOptions(BaseModel):
|
| 37 |
+
case_id: Optional[str] = None
|
| 38 |
+
seed: Optional[int] = None
|
| 39 |
+
episode_id: Optional[str] = None
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class ResetRequest(BaseModel):
|
| 43 |
+
options: Optional[ResetOptions] = None
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class StepRequest(BaseModel):
|
| 47 |
+
action: Action
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
# Core OpenEnv endpoints
|
| 52 |
+
# ---------------------------------------------------------------------------
|
| 53 |
+
|
| 54 |
+
@app.post("/reset")
|
| 55 |
+
async def reset(req: Optional[ResetRequest] = None) -> JSONResponse:
|
| 56 |
+
try:
|
| 57 |
+
opts = (req.options if req and req.options else None) or ResetOptions()
|
| 58 |
+
obs: Observation = _env.reset(
|
| 59 |
+
seed=opts.seed,
|
| 60 |
+
episode_id=opts.episode_id,
|
| 61 |
+
case_id=opts.case_id or "",
|
| 62 |
+
)
|
| 63 |
+
return JSONResponse({
|
| 64 |
+
"observation": obs.model_dump(mode="json"),
|
| 65 |
+
"reward": 0.0,
|
| 66 |
+
"done": False,
|
| 67 |
+
})
|
| 68 |
+
except Exception as e:
|
| 69 |
+
traceback.print_exc()
|
| 70 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@app.post("/step")
|
| 74 |
+
async def step(req: StepRequest) -> JSONResponse:
|
| 75 |
+
try:
|
| 76 |
+
obs: Observation = _env.step(req.action)
|
| 77 |
+
return JSONResponse({
|
| 78 |
+
"observation": obs.model_dump(mode="json"),
|
| 79 |
+
"reward": obs.reward,
|
| 80 |
+
"done": obs.done,
|
| 81 |
+
})
|
| 82 |
+
except RuntimeError as e:
|
| 83 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 84 |
+
except Exception as e:
|
| 85 |
+
traceback.print_exc()
|
| 86 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@app.get("/state")
|
| 90 |
+
async def get_state() -> JSONResponse:
|
| 91 |
+
return JSONResponse(_env.state.model_dump(mode="json"))
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
@app.get("/schema")
|
| 95 |
+
async def schema() -> JSONResponse:
|
| 96 |
+
return JSONResponse({
|
| 97 |
+
"action": Action.model_json_schema(),
|
| 98 |
+
"observation": Observation.model_json_schema(),
|
| 99 |
+
"state": State.model_json_schema(),
|
| 100 |
+
})
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# ---------------------------------------------------------------------------
|
| 104 |
+
# /episode_summary — read by the reward_threshold graders in openenv.yaml
|
| 105 |
+
# ---------------------------------------------------------------------------
|
| 106 |
+
|
| 107 |
+
@app.get("/episode_summary")
|
| 108 |
+
async def episode_summary() -> JSONResponse:
|
| 109 |
+
state = _env.state
|
| 110 |
+
breakdown = dict(state.reward_breakdown or {})
|
| 111 |
+
total_reward = round(sum(breakdown.values()), 4)
|
| 112 |
+
return JSONResponse({
|
| 113 |
+
"episode_id": state.episode_id,
|
| 114 |
+
"step_count": state.step_count,
|
| 115 |
+
"done": state.done,
|
| 116 |
+
"total_reward": total_reward,
|
| 117 |
+
"reward_breakdown": breakdown,
|
| 118 |
+
"final_action": state.final_action,
|
| 119 |
+
"reviewer_note": state.reviewer_note,
|
| 120 |
+
})
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# ---------------------------------------------------------------------------
|
| 124 |
+
# Helper endpoints
|
| 125 |
+
# ---------------------------------------------------------------------------
|
| 126 |
+
|
| 127 |
+
@app.get("/cases")
|
| 128 |
+
async def list_cases() -> JSONResponse:
|
| 129 |
+
return JSONResponse({"cases": CASE_IDS})
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@app.get("/state_full")
|
| 133 |
+
async def state_full() -> JSONResponse:
|
| 134 |
+
return JSONResponse(_env.state.model_dump(mode="json"))
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
@app.get("/health")
|
| 138 |
+
async def health() -> JSONResponse:
|
| 139 |
+
return JSONResponse({"status": "ok"})
|
| 140 |
+
|
| 141 |
+
def main(host: str = "0.0.0.0", port: int = 8000) -> None:
|
| 142 |
+
import uvicorn
|
| 143 |
+
|
| 144 |
+
uvicorn.run(app, host=host, port=port)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
if __name__ == "__main__":
|
| 148 |
+
main()
|
server/env.py
CHANGED
|
@@ -1,173 +1,173 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import uuid
|
| 4 |
-
from typing import Any, Dict, Optional
|
| 5 |
-
|
| 6 |
-
from openenv.core.env_server.interfaces import Environment
|
| 7 |
-
|
| 8 |
-
try:
|
| 9 |
-
from .models import Action, ActionType, Content, Observation, PolicyChunk, State, StepType
|
| 10 |
-
from .logic import (
|
| 11 |
-
CASE_IDS,
|
| 12 |
-
get_case,
|
| 13 |
-
get_expected_action,
|
| 14 |
-
compute_step_reward,
|
| 15 |
-
)
|
| 16 |
-
from .rag.retriever import retrieve_policy_chunks
|
| 17 |
-
except ImportError:
|
| 18 |
-
from models import Action, ActionType, Content, Observation, PolicyChunk, State, StepType
|
| 19 |
-
from logic import (
|
| 20 |
-
CASE_IDS,
|
| 21 |
-
get_case,
|
| 22 |
-
get_expected_action,
|
| 23 |
-
compute_step_reward,
|
| 24 |
-
)
|
| 25 |
-
from rag.retriever import retrieve_policy_chunks
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
# Episode step flow — each step() call advances to the next stage
|
| 29 |
-
EPISODE_FLOW = ["analyze", "retrieve_policy", "decide", "review", "finalize"]
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
class ModerationEnvironment(Environment):
|
| 33 |
-
"""OpenEnv environment for multimodal content moderation."""
|
| 34 |
-
|
| 35 |
-
def __init__(self) -> None:
|
| 36 |
-
super().__init__()
|
| 37 |
-
self._state = State()
|
| 38 |
-
self._case: Optional[Dict[str, Any]] = None
|
| 39 |
-
self._current_step_index: int = 0
|
| 40 |
-
|
| 41 |
-
# ------------------------------------------------------------------
|
| 42 |
-
# OpenEnv interface
|
| 43 |
-
# ------------------------------------------------------------------
|
| 44 |
-
|
| 45 |
-
def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs) -> Observation:
|
| 46 |
-
eid = episode_id or str(uuid.uuid4())
|
| 47 |
-
|
| 48 |
-
# Determine which case to use
|
| 49 |
-
# Allow caller to pass case_id via kwargs (used by inference.py)
|
| 50 |
-
case_id = kwargs.get("case_id")
|
| 51 |
-
if case_id and case_id in CASE_IDS:
|
| 52 |
-
chosen_id = case_id
|
| 53 |
-
elif seed is not None:
|
| 54 |
-
chosen_id = CASE_IDS[seed % len(CASE_IDS)]
|
| 55 |
-
else:
|
| 56 |
-
import random
|
| 57 |
-
chosen_id = random.choice(CASE_IDS)
|
| 58 |
-
|
| 59 |
-
self._case = get_case(chosen_id)
|
| 60 |
-
self._current_step_index = 0
|
| 61 |
-
|
| 62 |
-
self._state = State(
|
| 63 |
-
episode_id=eid,
|
| 64 |
-
step_count=0,
|
| 65 |
-
done=False,
|
| 66 |
-
selected_case_id=chosen_id,
|
| 67 |
-
reward_breakdown={
|
| 68 |
-
"analysis_step": 0.0,
|
| 69 |
-
"retrieval_step": 0.0,
|
| 70 |
-
"correct_decision": 0.0,
|
| 71 |
-
"reviewer_agreement": 0.0,
|
| 72 |
-
"unsafe_penalty": 0.0,
|
| 73 |
-
},
|
| 74 |
-
final_action=None,
|
| 75 |
-
reviewer_note=None,
|
| 76 |
-
action_history=[],
|
| 77 |
-
retrieved_policy_chunks=[],
|
| 78 |
-
)
|
| 79 |
-
|
| 80 |
-
content = Content(**self._case["content"])
|
| 81 |
-
return Observation(
|
| 82 |
-
content=content,
|
| 83 |
-
policy=[],
|
| 84 |
-
step_type=StepType.analyze,
|
| 85 |
-
step_count=0,
|
| 86 |
-
message=f"Episode started. Case: {chosen_id}. Begin with analysis.",
|
| 87 |
-
reward=0.0,
|
| 88 |
-
done=False,
|
| 89 |
-
)
|
| 90 |
-
|
| 91 |
-
def step(self, action: Action, **kwargs) -> Observation:
|
| 92 |
-
if self._case is None:
|
| 93 |
-
raise RuntimeError("Call reset() before step()")
|
| 94 |
-
|
| 95 |
-
if self._state.done:
|
| 96 |
-
return Observation(
|
| 97 |
-
step_type=StepType.finalize,
|
| 98 |
-
step_count=self._state.step_count,
|
| 99 |
-
message="Episode already finished.",
|
| 100 |
-
reward=0.0,
|
| 101 |
-
done=True,
|
| 102 |
-
)
|
| 103 |
-
|
| 104 |
-
step_name = EPISODE_FLOW[self._current_step_index]
|
| 105 |
-
reward = compute_step_reward(step_name, action.action_type.value, self._case)
|
| 106 |
-
|
| 107 |
-
# Record reward into breakdown
|
| 108 |
-
breakdown = self._state.reward_breakdown
|
| 109 |
-
if step_name == "analyze":
|
| 110 |
-
breakdown["analysis_step"] += reward
|
| 111 |
-
elif step_name == "retrieve_policy":
|
| 112 |
-
breakdown["retrieval_step"] += reward
|
| 113 |
-
elif step_name == "decide":
|
| 114 |
-
if reward > 0:
|
| 115 |
-
breakdown["correct_decision"] += reward
|
| 116 |
-
else:
|
| 117 |
-
breakdown["unsafe_penalty"] += reward
|
| 118 |
-
elif step_name == "review":
|
| 119 |
-
breakdown["reviewer_agreement"] += reward
|
| 120 |
-
|
| 121 |
-
# Record action history
|
| 122 |
-
self._state.action_history.append({
|
| 123 |
-
"step": step_name,
|
| 124 |
-
"action_type": action.action_type.value,
|
| 125 |
-
"reason": action.reason,
|
| 126 |
-
"reward": reward,
|
| 127 |
-
})
|
| 128 |
-
|
| 129 |
-
self._state.step_count += 1
|
| 130 |
-
self._current_step_index += 1
|
| 131 |
-
|
| 132 |
-
# Build observation for next step
|
| 133 |
-
policy_chunks: list[PolicyChunk] = []
|
| 134 |
-
message = ""
|
| 135 |
-
next_step_type = StepType.finalize
|
| 136 |
-
|
| 137 |
-
if step_name == "retrieve_policy":
|
| 138 |
-
# Actually retrieve now that we're done with retrieve_policy
|
| 139 |
-
raw_chunks = retrieve_policy_chunks(self._case["content"].get("text", ""), top_k=3)
|
| 140 |
-
policy_chunks = [PolicyChunk(**c) for c in raw_chunks]
|
| 141 |
-
self._state.retrieved_policy_chunks = policy_chunks
|
| 142 |
-
message = "Policy retrieved. Now make your moderation decision."
|
| 143 |
-
elif step_name == "analyze":
|
| 144 |
-
message = "Analysis complete. Retrieve relevant policy next."
|
| 145 |
-
elif step_name == "decide":
|
| 146 |
-
self._state.final_action = action.action_type.value
|
| 147 |
-
message = "Decision recorded. Awaiting reviewer validation."
|
| 148 |
-
elif step_name == "review":
|
| 149 |
-
self._state.reviewer_note = action.reason or "Reviewer note recorded."
|
| 150 |
-
message = "Review complete. Finalizing episode."
|
| 151 |
-
elif step_name == "finalize":
|
| 152 |
-
message = "Episode finalized."
|
| 153 |
-
|
| 154 |
-
done = self._current_step_index >= len(EPISODE_FLOW)
|
| 155 |
-
self._state.done = done
|
| 156 |
-
|
| 157 |
-
# Determine next step type for observation
|
| 158 |
-
if not done and self._current_step_index < len(EPISODE_FLOW):
|
| 159 |
-
next_step_type = StepType(EPISODE_FLOW[self._current_step_index])
|
| 160 |
-
|
| 161 |
-
return Observation(
|
| 162 |
-
content=Content(**self._case["content"]),
|
| 163 |
-
policy=policy_chunks or self._state.retrieved_policy_chunks,
|
| 164 |
-
step_type=next_step_type,
|
| 165 |
-
step_count=self._state.step_count,
|
| 166 |
-
message=message,
|
| 167 |
-
reward=reward,
|
| 168 |
-
done=done,
|
| 169 |
-
)
|
| 170 |
-
|
| 171 |
-
@property
|
| 172 |
-
def state(self) -> State:
|
| 173 |
return self._state
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import uuid
|
| 4 |
+
from typing import Any, Dict, Optional
|
| 5 |
+
|
| 6 |
+
from openenv.core.env_server.interfaces import Environment
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
from .models import Action, ActionType, Content, Observation, PolicyChunk, State, StepType
|
| 10 |
+
from .logic import (
|
| 11 |
+
CASE_IDS,
|
| 12 |
+
get_case,
|
| 13 |
+
get_expected_action,
|
| 14 |
+
compute_step_reward,
|
| 15 |
+
)
|
| 16 |
+
from .rag.retriever import retrieve_policy_chunks
|
| 17 |
+
except ImportError:
|
| 18 |
+
from models import Action, ActionType, Content, Observation, PolicyChunk, State, StepType
|
| 19 |
+
from logic import (
|
| 20 |
+
CASE_IDS,
|
| 21 |
+
get_case,
|
| 22 |
+
get_expected_action,
|
| 23 |
+
compute_step_reward,
|
| 24 |
+
)
|
| 25 |
+
from rag.retriever import retrieve_policy_chunks
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Episode step flow — each step() call advances to the next stage
|
| 29 |
+
EPISODE_FLOW = ["analyze", "retrieve_policy", "decide", "review", "finalize"]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class ModerationEnvironment(Environment):
|
| 33 |
+
"""OpenEnv environment for multimodal content moderation."""
|
| 34 |
+
|
| 35 |
+
def __init__(self) -> None:
|
| 36 |
+
super().__init__()
|
| 37 |
+
self._state = State()
|
| 38 |
+
self._case: Optional[Dict[str, Any]] = None
|
| 39 |
+
self._current_step_index: int = 0
|
| 40 |
+
|
| 41 |
+
# ------------------------------------------------------------------
|
| 42 |
+
# OpenEnv interface
|
| 43 |
+
# ------------------------------------------------------------------
|
| 44 |
+
|
| 45 |
+
def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs) -> Observation:
|
| 46 |
+
eid = episode_id or str(uuid.uuid4())
|
| 47 |
+
|
| 48 |
+
# Determine which case to use
|
| 49 |
+
# Allow caller to pass case_id via kwargs (used by inference.py)
|
| 50 |
+
case_id = kwargs.get("case_id")
|
| 51 |
+
if case_id and case_id in CASE_IDS:
|
| 52 |
+
chosen_id = case_id
|
| 53 |
+
elif seed is not None:
|
| 54 |
+
chosen_id = CASE_IDS[seed % len(CASE_IDS)]
|
| 55 |
+
else:
|
| 56 |
+
import random
|
| 57 |
+
chosen_id = random.choice(CASE_IDS)
|
| 58 |
+
|
| 59 |
+
self._case = get_case(chosen_id)
|
| 60 |
+
self._current_step_index = 0
|
| 61 |
+
|
| 62 |
+
self._state = State(
|
| 63 |
+
episode_id=eid,
|
| 64 |
+
step_count=0,
|
| 65 |
+
done=False,
|
| 66 |
+
selected_case_id=chosen_id,
|
| 67 |
+
reward_breakdown={
|
| 68 |
+
"analysis_step": 0.0,
|
| 69 |
+
"retrieval_step": 0.0,
|
| 70 |
+
"correct_decision": 0.0,
|
| 71 |
+
"reviewer_agreement": 0.0,
|
| 72 |
+
"unsafe_penalty": 0.0,
|
| 73 |
+
},
|
| 74 |
+
final_action=None,
|
| 75 |
+
reviewer_note=None,
|
| 76 |
+
action_history=[],
|
| 77 |
+
retrieved_policy_chunks=[],
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
content = Content(**self._case["content"])
|
| 81 |
+
return Observation(
|
| 82 |
+
content=content,
|
| 83 |
+
policy=[],
|
| 84 |
+
step_type=StepType.analyze,
|
| 85 |
+
step_count=0,
|
| 86 |
+
message=f"Episode started. Case: {chosen_id}. Begin with analysis.",
|
| 87 |
+
reward=0.0,
|
| 88 |
+
done=False,
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
def step(self, action: Action, **kwargs) -> Observation:
|
| 92 |
+
if self._case is None:
|
| 93 |
+
raise RuntimeError("Call reset() before step()")
|
| 94 |
+
|
| 95 |
+
if self._state.done:
|
| 96 |
+
return Observation(
|
| 97 |
+
step_type=StepType.finalize,
|
| 98 |
+
step_count=self._state.step_count,
|
| 99 |
+
message="Episode already finished.",
|
| 100 |
+
reward=0.0,
|
| 101 |
+
done=True,
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
step_name = EPISODE_FLOW[self._current_step_index]
|
| 105 |
+
reward = compute_step_reward(step_name, action.action_type.value, self._case)
|
| 106 |
+
|
| 107 |
+
# Record reward into breakdown
|
| 108 |
+
breakdown = self._state.reward_breakdown
|
| 109 |
+
if step_name == "analyze":
|
| 110 |
+
breakdown["analysis_step"] += reward
|
| 111 |
+
elif step_name == "retrieve_policy":
|
| 112 |
+
breakdown["retrieval_step"] += reward
|
| 113 |
+
elif step_name == "decide":
|
| 114 |
+
if reward > 0:
|
| 115 |
+
breakdown["correct_decision"] += reward
|
| 116 |
+
else:
|
| 117 |
+
breakdown["unsafe_penalty"] += reward
|
| 118 |
+
elif step_name == "review":
|
| 119 |
+
breakdown["reviewer_agreement"] += reward
|
| 120 |
+
|
| 121 |
+
# Record action history
|
| 122 |
+
self._state.action_history.append({
|
| 123 |
+
"step": step_name,
|
| 124 |
+
"action_type": action.action_type.value,
|
| 125 |
+
"reason": action.reason,
|
| 126 |
+
"reward": reward,
|
| 127 |
+
})
|
| 128 |
+
|
| 129 |
+
self._state.step_count += 1
|
| 130 |
+
self._current_step_index += 1
|
| 131 |
+
|
| 132 |
+
# Build observation for next step
|
| 133 |
+
policy_chunks: list[PolicyChunk] = []
|
| 134 |
+
message = ""
|
| 135 |
+
next_step_type = StepType.finalize
|
| 136 |
+
|
| 137 |
+
if step_name == "retrieve_policy":
|
| 138 |
+
# Actually retrieve now that we're done with retrieve_policy
|
| 139 |
+
raw_chunks = retrieve_policy_chunks(self._case["content"].get("text", ""), top_k=3)
|
| 140 |
+
policy_chunks = [PolicyChunk(**c) for c in raw_chunks]
|
| 141 |
+
self._state.retrieved_policy_chunks = policy_chunks
|
| 142 |
+
message = "Policy retrieved. Now make your moderation decision."
|
| 143 |
+
elif step_name == "analyze":
|
| 144 |
+
message = "Analysis complete. Retrieve relevant policy next."
|
| 145 |
+
elif step_name == "decide":
|
| 146 |
+
self._state.final_action = action.action_type.value
|
| 147 |
+
message = "Decision recorded. Awaiting reviewer validation."
|
| 148 |
+
elif step_name == "review":
|
| 149 |
+
self._state.reviewer_note = action.reason or "Reviewer note recorded."
|
| 150 |
+
message = "Review complete. Finalizing episode."
|
| 151 |
+
elif step_name == "finalize":
|
| 152 |
+
message = "Episode finalized."
|
| 153 |
+
|
| 154 |
+
done = self._current_step_index >= len(EPISODE_FLOW)
|
| 155 |
+
self._state.done = done
|
| 156 |
+
|
| 157 |
+
# Determine next step type for observation
|
| 158 |
+
if not done and self._current_step_index < len(EPISODE_FLOW):
|
| 159 |
+
next_step_type = StepType(EPISODE_FLOW[self._current_step_index])
|
| 160 |
+
|
| 161 |
+
return Observation(
|
| 162 |
+
content=Content(**self._case["content"]),
|
| 163 |
+
policy=policy_chunks or self._state.retrieved_policy_chunks,
|
| 164 |
+
step_type=next_step_type,
|
| 165 |
+
step_count=self._state.step_count,
|
| 166 |
+
message=message,
|
| 167 |
+
reward=reward,
|
| 168 |
+
done=done,
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
@property
|
| 172 |
+
def state(self) -> State:
|
| 173 |
return self._state
|
server/models.py
CHANGED
|
@@ -1,68 +1,68 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from enum import Enum
|
| 4 |
-
from typing import Any, Dict, List, Optional
|
| 5 |
-
|
| 6 |
-
from pydantic import BaseModel, Field
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
class ActionType(str, Enum):
|
| 10 |
-
allow = "allow"
|
| 11 |
-
flag = "flag"
|
| 12 |
-
remove = "remove"
|
| 13 |
-
escalate = "escalate"
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
class StepType(str, Enum):
|
| 17 |
-
analyze = "analyze"
|
| 18 |
-
retrieve_policy = "retrieve_policy"
|
| 19 |
-
decide = "decide"
|
| 20 |
-
review = "review"
|
| 21 |
-
finalize = "finalize"
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
class Content(BaseModel):
|
| 25 |
-
text: str = ""
|
| 26 |
-
image_url: Optional[str] = None
|
| 27 |
-
image_description: Optional[str] = None
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
class PolicyChunk(BaseModel):
|
| 31 |
-
policy_id: str = ""
|
| 32 |
-
text: str = ""
|
| 33 |
-
score: float = 0.0
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
class Action(BaseModel):
|
| 37 |
-
action_type: ActionType
|
| 38 |
-
reason: str = ""
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
class Observation(BaseModel):
|
| 42 |
-
content: Optional[Content] = None
|
| 43 |
-
policy: List[PolicyChunk] = Field(default_factory=list)
|
| 44 |
-
step_type: StepType = StepType.analyze
|
| 45 |
-
step_count: int = 0
|
| 46 |
-
message: str = ""
|
| 47 |
-
reward: float = 0.0
|
| 48 |
-
done: bool = False
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
class State(BaseModel):
|
| 52 |
-
episode_id: str = ""
|
| 53 |
-
step_count: int = 0
|
| 54 |
-
done: bool = False
|
| 55 |
-
selected_case_id: Optional[str] = None
|
| 56 |
-
reward_breakdown: Dict[str, float] = Field(
|
| 57 |
-
default_factory=lambda: {
|
| 58 |
-
"analysis_step": 0.0,
|
| 59 |
-
"retrieval_step": 0.0,
|
| 60 |
-
"correct_decision": 0.0,
|
| 61 |
-
"reviewer_agreement": 0.0,
|
| 62 |
-
"unsafe_penalty": 0.0,
|
| 63 |
-
}
|
| 64 |
-
)
|
| 65 |
-
final_action: Optional[str] = None
|
| 66 |
-
reviewer_note: Optional[str] = None
|
| 67 |
-
action_history: List[Dict[str, Any]] = Field(default_factory=list)
|
| 68 |
retrieved_policy_chunks: List[PolicyChunk] = Field(default_factory=list)
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from enum import Enum
|
| 4 |
+
from typing import Any, Dict, List, Optional
|
| 5 |
+
|
| 6 |
+
from pydantic import BaseModel, Field
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ActionType(str, Enum):
|
| 10 |
+
allow = "allow"
|
| 11 |
+
flag = "flag"
|
| 12 |
+
remove = "remove"
|
| 13 |
+
escalate = "escalate"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class StepType(str, Enum):
|
| 17 |
+
analyze = "analyze"
|
| 18 |
+
retrieve_policy = "retrieve_policy"
|
| 19 |
+
decide = "decide"
|
| 20 |
+
review = "review"
|
| 21 |
+
finalize = "finalize"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Content(BaseModel):
|
| 25 |
+
text: str = ""
|
| 26 |
+
image_url: Optional[str] = None
|
| 27 |
+
image_description: Optional[str] = None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class PolicyChunk(BaseModel):
|
| 31 |
+
policy_id: str = ""
|
| 32 |
+
text: str = ""
|
| 33 |
+
score: float = 0.0
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class Action(BaseModel):
|
| 37 |
+
action_type: ActionType
|
| 38 |
+
reason: str = ""
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class Observation(BaseModel):
|
| 42 |
+
content: Optional[Content] = None
|
| 43 |
+
policy: List[PolicyChunk] = Field(default_factory=list)
|
| 44 |
+
step_type: StepType = StepType.analyze
|
| 45 |
+
step_count: int = 0
|
| 46 |
+
message: str = ""
|
| 47 |
+
reward: float = 0.0
|
| 48 |
+
done: bool = False
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class State(BaseModel):
|
| 52 |
+
episode_id: str = ""
|
| 53 |
+
step_count: int = 0
|
| 54 |
+
done: bool = False
|
| 55 |
+
selected_case_id: Optional[str] = None
|
| 56 |
+
reward_breakdown: Dict[str, float] = Field(
|
| 57 |
+
default_factory=lambda: {
|
| 58 |
+
"analysis_step": 0.0,
|
| 59 |
+
"retrieval_step": 0.0,
|
| 60 |
+
"correct_decision": 0.0,
|
| 61 |
+
"reviewer_agreement": 0.0,
|
| 62 |
+
"unsafe_penalty": 0.0,
|
| 63 |
+
}
|
| 64 |
+
)
|
| 65 |
+
final_action: Optional[str] = None
|
| 66 |
+
reviewer_note: Optional[str] = None
|
| 67 |
+
action_history: List[Dict[str, Any]] = Field(default_factory=list)
|
| 68 |
retrieved_policy_chunks: List[PolicyChunk] = Field(default_factory=list)
|
server/rag/__init__.py
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
-
from .retriever import retrieve_policy_chunks
|
| 2 |
-
|
| 3 |
__all__ = ["retrieve_policy_chunks"]
|
|
|
|
| 1 |
+
from .retriever import retrieve_policy_chunks
|
| 2 |
+
|
| 3 |
__all__ = ["retrieve_policy_chunks"]
|
server/rag/retriever.py
CHANGED
|
@@ -1,98 +1,98 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import json
|
| 4 |
-
import os
|
| 5 |
-
from typing import Any, Dict, List
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
_POLICIES_PATH = os.path.join(os.path.dirname(__file__), "policies.json")
|
| 9 |
-
|
| 10 |
-
_policies_cache: List[Dict[str, Any]] = []
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
def _load_policies() -> List[Dict[str, Any]]:
|
| 14 |
-
global _policies_cache
|
| 15 |
-
if _policies_cache:
|
| 16 |
-
return _policies_cache
|
| 17 |
-
if os.path.exists(_POLICIES_PATH):
|
| 18 |
-
with open(_POLICIES_PATH, "r", encoding="utf-8") as f:
|
| 19 |
-
_policies_cache = json.load(f)
|
| 20 |
-
else:
|
| 21 |
-
# Fallback built-in policies when policies.json is absent
|
| 22 |
-
_policies_cache = [
|
| 23 |
-
{
|
| 24 |
-
"chunk_id": "violence-01",
|
| 25 |
-
"title": "Violence and Credible Threats",
|
| 26 |
-
"text": "Content that includes direct threats, instructions to kill, or praise for murder must be removed immediately. Escalate only when context is ambiguous or safety risk cannot be resolved with the available evidence."
|
| 27 |
-
},
|
| 28 |
-
{
|
| 29 |
-
"chunk_id": "violence-02",
|
| 30 |
-
"title": "Violence Context Review",
|
| 31 |
-
"text": "Fictional, quoted, or documentary references to violence may remain allowed when the speaker is clearly not endorsing harm. When intent is uncertain, flag or escalate instead of allowing outright."
|
| 32 |
-
},
|
| 33 |
-
{
|
| 34 |
-
"chunk_id": "nudity-01",
|
| 35 |
-
"title": "Adult Nudity",
|
| 36 |
-
"text": "Mentions or metadata indicating nudity should be flagged for human review when the material may be adult sexual content. Remove only if the content is explicit or paired with exploitative cues."
|
| 37 |
-
},
|
| 38 |
-
{
|
| 39 |
-
"chunk_id": "nudity-02",
|
| 40 |
-
"title": "Suggestive Imagery Metadata",
|
| 41 |
-
"text": "Image tags such as exposed body parts, provocative posing, or synthetic nude prompts increase risk and should bias the system toward flagging rather than allowing."
|
| 42 |
-
},
|
| 43 |
-
{
|
| 44 |
-
"chunk_id": "self-harm-01",
|
| 45 |
-
"title": "Self-Harm and Crisis",
|
| 46 |
-
"text": "Content asking for methods of self-harm or celebrating imminent self-harm should be escalated or removed depending on immediacy. Support-seeking or recovery discussion should generally remain allowed."
|
| 47 |
-
},
|
| 48 |
-
{
|
| 49 |
-
"chunk_id": "harassment-01",
|
| 50 |
-
"title": "Harassment and Abuse",
|
| 51 |
-
"text": "Targeted abuse, slurs, and repeated demeaning language should be flagged or removed based on severity. Benign disagreement or criticism alone is not enough for removal."
|
| 52 |
-
},
|
| 53 |
-
{
|
| 54 |
-
"chunk_id": "spam-01",
|
| 55 |
-
"title": "Spam and Low-Quality Promotion",
|
| 56 |
-
"text": "Commercial spam, scam bait, or repetitive promotional text should usually be flagged. Remove only when the content is clearly malicious or part of coordinated abuse."
|
| 57 |
-
},
|
| 58 |
-
{
|
| 59 |
-
"chunk_id": "misinfo-01",
|
| 60 |
-
"title": "Misinformation Handling",
|
| 61 |
-
"text": "Potential misinformation without direct physical harm usually requires flagging or escalation, not automatic removal. High-risk medical or civic deception should lean toward escalation."
|
| 62 |
-
},
|
| 63 |
-
{
|
| 64 |
-
"chunk_id": "graphic-01",
|
| 65 |
-
"title": "Graphic Media",
|
| 66 |
-
"text": "Graphic injury descriptions or metadata indicating gore should be removed or escalated when public safety and age exposure concerns are high."
|
| 67 |
-
},
|
| 68 |
-
{
|
| 69 |
-
"chunk_id": "default-01",
|
| 70 |
-
"title": "Default Safe Handling",
|
| 71 |
-
"text": "When evidence is weak and no policy trigger is present, prefer allowing the content. Use escalation only when confidence is low or policy signals conflict."
|
| 72 |
-
}
|
| 73 |
-
]
|
| 74 |
-
return _policies_cache
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
def _keyword_score(text: str, policy_text: str) -> float:
|
| 78 |
-
"""Simple overlap score: fraction of content words found in policy text."""
|
| 79 |
-
content_words = set(text.lower().split())
|
| 80 |
-
policy_words = set(policy_text.lower().split())
|
| 81 |
-
if not content_words:
|
| 82 |
-
return 0.0
|
| 83 |
-
return len(content_words & policy_words) / len(content_words)
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
def retrieve_policy_chunks(query_text: str, top_k: int = 3) -> List[Dict[str, Any]]:
|
| 87 |
-
"""Return top_k policy chunks most relevant to query_text."""
|
| 88 |
-
policies = _load_policies()
|
| 89 |
-
scored = [
|
| 90 |
-
{
|
| 91 |
-
"policy_id": p["chunk_id"],
|
| 92 |
-
"text": p["text"],
|
| 93 |
-
"score": _keyword_score(query_text, p["text"]),
|
| 94 |
-
}
|
| 95 |
-
for p in policies
|
| 96 |
-
]
|
| 97 |
-
scored.sort(key=lambda x: x["score"], reverse=True)
|
| 98 |
return scored[:top_k]
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
from typing import Any, Dict, List
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
_POLICIES_PATH = os.path.join(os.path.dirname(__file__), "policies.json")
|
| 9 |
+
|
| 10 |
+
_policies_cache: List[Dict[str, Any]] = []
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _load_policies() -> List[Dict[str, Any]]:
|
| 14 |
+
global _policies_cache
|
| 15 |
+
if _policies_cache:
|
| 16 |
+
return _policies_cache
|
| 17 |
+
if os.path.exists(_POLICIES_PATH):
|
| 18 |
+
with open(_POLICIES_PATH, "r", encoding="utf-8") as f:
|
| 19 |
+
_policies_cache = json.load(f)
|
| 20 |
+
else:
|
| 21 |
+
# Fallback built-in policies when policies.json is absent
|
| 22 |
+
_policies_cache = [
|
| 23 |
+
{
|
| 24 |
+
"chunk_id": "violence-01",
|
| 25 |
+
"title": "Violence and Credible Threats",
|
| 26 |
+
"text": "Content that includes direct threats, instructions to kill, or praise for murder must be removed immediately. Escalate only when context is ambiguous or safety risk cannot be resolved with the available evidence."
|
| 27 |
+
},
|
| 28 |
+
{
|
| 29 |
+
"chunk_id": "violence-02",
|
| 30 |
+
"title": "Violence Context Review",
|
| 31 |
+
"text": "Fictional, quoted, or documentary references to violence may remain allowed when the speaker is clearly not endorsing harm. When intent is uncertain, flag or escalate instead of allowing outright."
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"chunk_id": "nudity-01",
|
| 35 |
+
"title": "Adult Nudity",
|
| 36 |
+
"text": "Mentions or metadata indicating nudity should be flagged for human review when the material may be adult sexual content. Remove only if the content is explicit or paired with exploitative cues."
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"chunk_id": "nudity-02",
|
| 40 |
+
"title": "Suggestive Imagery Metadata",
|
| 41 |
+
"text": "Image tags such as exposed body parts, provocative posing, or synthetic nude prompts increase risk and should bias the system toward flagging rather than allowing."
|
| 42 |
+
},
|
| 43 |
+
{
|
| 44 |
+
"chunk_id": "self-harm-01",
|
| 45 |
+
"title": "Self-Harm and Crisis",
|
| 46 |
+
"text": "Content asking for methods of self-harm or celebrating imminent self-harm should be escalated or removed depending on immediacy. Support-seeking or recovery discussion should generally remain allowed."
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"chunk_id": "harassment-01",
|
| 50 |
+
"title": "Harassment and Abuse",
|
| 51 |
+
"text": "Targeted abuse, slurs, and repeated demeaning language should be flagged or removed based on severity. Benign disagreement or criticism alone is not enough for removal."
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
"chunk_id": "spam-01",
|
| 55 |
+
"title": "Spam and Low-Quality Promotion",
|
| 56 |
+
"text": "Commercial spam, scam bait, or repetitive promotional text should usually be flagged. Remove only when the content is clearly malicious or part of coordinated abuse."
|
| 57 |
+
},
|
| 58 |
+
{
|
| 59 |
+
"chunk_id": "misinfo-01",
|
| 60 |
+
"title": "Misinformation Handling",
|
| 61 |
+
"text": "Potential misinformation without direct physical harm usually requires flagging or escalation, not automatic removal. High-risk medical or civic deception should lean toward escalation."
|
| 62 |
+
},
|
| 63 |
+
{
|
| 64 |
+
"chunk_id": "graphic-01",
|
| 65 |
+
"title": "Graphic Media",
|
| 66 |
+
"text": "Graphic injury descriptions or metadata indicating gore should be removed or escalated when public safety and age exposure concerns are high."
|
| 67 |
+
},
|
| 68 |
+
{
|
| 69 |
+
"chunk_id": "default-01",
|
| 70 |
+
"title": "Default Safe Handling",
|
| 71 |
+
"text": "When evidence is weak and no policy trigger is present, prefer allowing the content. Use escalation only when confidence is low or policy signals conflict."
|
| 72 |
+
}
|
| 73 |
+
]
|
| 74 |
+
return _policies_cache
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _keyword_score(text: str, policy_text: str) -> float:
|
| 78 |
+
"""Simple overlap score: fraction of content words found in policy text."""
|
| 79 |
+
content_words = set(text.lower().split())
|
| 80 |
+
policy_words = set(policy_text.lower().split())
|
| 81 |
+
if not content_words:
|
| 82 |
+
return 0.0
|
| 83 |
+
return len(content_words & policy_words) / len(content_words)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def retrieve_policy_chunks(query_text: str, top_k: int = 3) -> List[Dict[str, Any]]:
|
| 87 |
+
"""Return top_k policy chunks most relevant to query_text."""
|
| 88 |
+
policies = _load_policies()
|
| 89 |
+
scored = [
|
| 90 |
+
{
|
| 91 |
+
"policy_id": p["chunk_id"],
|
| 92 |
+
"text": p["text"],
|
| 93 |
+
"score": _keyword_score(query_text, p["text"]),
|
| 94 |
+
}
|
| 95 |
+
for p in policies
|
| 96 |
+
]
|
| 97 |
+
scored.sort(key=lambda x: x["score"], reverse=True)
|
| 98 |
return scored[:top_k]
|
server/server_routes.py
CHANGED
|
@@ -13,7 +13,7 @@ def register_routes(app, env) -> None:
|
|
| 13 |
async def episode_summary() -> JSONResponse:
|
| 14 |
state = env.state
|
| 15 |
breakdown = state.reward_breakdown or {}
|
| 16 |
-
total_reward =
|
| 17 |
return JSONResponse({
|
| 18 |
"episode_id": state.episode_id,
|
| 19 |
"step_count": state.step_count,
|
|
|
|
| 13 |
async def episode_summary() -> JSONResponse:
|
| 14 |
state = env.state
|
| 15 |
breakdown = state.reward_breakdown or {}
|
| 16 |
+
total_reward = max(0.01, min(0.99, float(sum(breakdown.values()))))
|
| 17 |
return JSONResponse({
|
| 18 |
"episode_id": state.episode_id,
|
| 19 |
"step_count": state.step_count,
|