SujanMidatani committed
Commit 9e245c9 · verified · Parent: 1cde9d1

Upload folder using huggingface_hub

Files changed (11)
  1. __init__.py +15 -15
  2. client.py +39 -39
  3. models.py +18 -18
  4. openenv.yaml +74 -48
  5. server/__init__.py +2 -2
  6. server/app.py +148 -148
  7. server/env.py +172 -172
  8. server/models.py +67 -67
  9. server/rag/__init__.py +2 -2
  10. server/rag/retriever.py +97 -97
  11. server/server_routes.py +1 -1
__init__.py CHANGED
@@ -1,16 +1,16 @@
(removed and re-added lines are identical; the file content is shown once)

from client import ModerationEnv, ModerationEnvAction, ModerationEnvObservation, ModerationEnvState
from models import Action, ActionType, Content, Observation, PolicyChunk, State, StepType

__all__ = [
    "ModerationEnv",
    "ModerationEnvAction",
    "ModerationEnvObservation",
    "ModerationEnvState",
    "Action",
    "ActionType",
    "Content",
    "Observation",
    "PolicyChunk",
    "State",
    "StepType",
]
client.py CHANGED
@@ -1,40 +1,40 @@
(removed and re-added lines are identical; the file content is shown once)

from __future__ import annotations

from typing import Any

from openenv.core import EnvClient
from openenv.core.client_types import StepResult

try:
    from .models import Action, Observation, State
except ImportError:
    from models import Action, Observation, State


class ModerationEnv(EnvClient[Action, Observation, State]):

    def _step_payload(self, action: Action) -> dict[str, Any]:
        return action.model_dump(mode="json")

    def _parse_result(self, payload: dict[str, Any]) -> StepResult[Observation]:
        observation_payload = payload.get("observation", {})
        return StepResult(
            observation=Observation(**observation_payload),
            reward=payload.get("reward"),
            done=bool(payload.get("done", False)),
        )

    def _parse_state(self, payload: dict[str, Any]) -> State:
        return State(**payload)


ModerationEnvAction = Action
ModerationEnvObservation = Observation
ModerationEnvState = State

__all__ = [
    "ModerationEnv",
    "ModerationEnvAction",
    "ModerationEnvObservation",
    "ModerationEnvState",
]
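
For orientation, a minimal usage sketch for the client above. Only the _step_payload/_parse_result/_parse_state hooks are defined in this diff; the constructor and the public reset()/step() methods live on the EnvClient base class in openenv.core, so the base_url keyword and the return shapes assumed below should be verified against the installed openenv package.

# Hypothetical usage sketch — EnvClient's constructor signature and its public
# reset()/step() methods are assumed, not shown in this diff.
from client import ModerationEnv, ModerationEnvAction

env = ModerationEnv(base_url="http://localhost:8000")  # assumed constructor kwarg

result = env.reset()                # expected to return a StepResult
print(result.observation.message)   # e.g. "Episode started. Case: ..."

# action_type must be one of allow|flag|remove|escalate (see models.ActionType)
result = env.step(ModerationEnvAction(action_type="flag", reason="possible violation"))
print(result.reward, result.done)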
models.py CHANGED
@@ -1,19 +1,19 @@
(removed and re-added lines are identical; the file content is shown once)

from server.models import (
    Action,
    ActionType,
    Content,
    Observation,
    PolicyChunk,
    State,
    StepType,
)

__all__ = [
    "Action",
    "ActionType",
    "Content",
    "Observation",
    "PolicyChunk",
    "State",
    "StepType",
]
openenv.yaml CHANGED
@@ -1,62 +1,88 @@
-spec_version: "1"
-
-name: "openenv-multimodal-moderation"
-type: "environment"
-runtime: "docker"
-app: "server.app:app"
-port: 8000
-
+name: openenv-multimodal-moderation
+version: "1.0.0"
 description: >
   OpenEnv-compliant multimodal content moderation environment with staged
   analysis, policy retrieval, moderation, reviewer validation, and dense reward.
-
+framework: meta-pytorch-openenv
+license: Apache-2.0
+entrypoint: server:app
+episode_flow:
+  - analyze
+  - retrieve_policy
+  - decide
+  - review
+  - finalize
+models:
+  action:
+    fields:
+      action_type: allow|flag|remove|escalate
+      reason: string
+  observation:
+    fields:
+      content: text + image metadata
+      policy: retrieved policy chunks
+      step_type: analyze|retrieve_policy|decide|review|finalize
+      step_count: integer
+  state:
+    fields:
+      episode_id: string
+      step_count: integer
+      done: boolean
+policy_retrieval:
+  source: server/rag/policies.json
+  strategy: keyword_overlap
+  top_k: 3
 tasks:
-  - id: "violence-remove"
-    name: "violence-remove"
-    description: "Agent must identify violent content and issue a remove action."
-    difficulty: "hard"  # Explicit violence detection + removal is most complex
+  - name: violence-remove
+    description: Agent must identify violent content and issue a remove action
     grader:
-      type: "reward_threshold"
+      type: reward_threshold
       threshold: 0.5
-      endpoint: "/episode_summary"
-      reward_field: "total_reward"
-
-  - id: "nudity-flag"
-    name: "nudity-flag"
-    description: "Agent must identify nudity/sexual content and issue a flag action."
-    difficulty: "medium"  # Visual content analysis required
+      endpoint: /episode_summary
+      reward_field: total_reward
+
+  - name: nudity-flag
+    description: Agent must identify nudity/sexual content and issue a flag action
     grader:
-      type: "reward_threshold"
+      type: reward_threshold
       threshold: 0.5
-      endpoint: "/episode_summary"
-      reward_field: "total_reward"
-
-  - id: "harassment-flag"
-    name: "harassment-flag"
-    description: "Agent must identify harassment and issue a flag action."
-    difficulty: "medium"  # Context/pattern recognition needed
+      endpoint: /episode_summary
+      reward_field: total_reward
+
+  - name: harassment-flag
+    description: Agent must identify harassment and issue a flag action
     grader:
-      type: "reward_threshold"
+      type: reward_threshold
       threshold: 0.5
-      endpoint: "/episode_summary"
-      reward_field: "total_reward"
-
-  - id: "misinfo-escalate"
-    name: "misinfo-escalate"
-    description: "Agent must identify misinformation and escalate for review."
-    difficulty: "hard"  # Requires fact-checking nuance + escalation judgment
+      endpoint: /episode_summary
+      reward_field: total_reward
+
+  - name: misinfo-escalate
+    description: Agent must identify misinformation and escalate for review
     grader:
-      type: "reward_threshold"
+      type: reward_threshold
       threshold: 0.5
-      endpoint: "/episode_summary"
-      reward_field: "total_reward"
-
-  - id: "benign-allow"
-    name: "benign-allow"
-    description: "Agent must correctly allow benign/safe content."
-    difficulty: "easy"  # Baseline safe content recognition
+      endpoint: /episode_summary
+      reward_field: total_reward
+
+  - name: benign-allow
+    description: Agent must correctly allow benign/safe content
     grader:
-      type: "reward_threshold"
+      type: reward_threshold
       threshold: 0.5
-      endpoint: "/episode_summary"
-      reward_field: "total_reward"
+      endpoint: /episode_summary
+      reward_field: total_reward
+rewards:
+  analysis_step: 0.2
+  retrieval_step: 0.2
+  correct_decision: 1.0
+  reviewer_agreement: 0.2
+  unsafe_penalty: -0.6
+server:
+  reset: POST /reset
+  step: POST /step
+  state: GET /state
+  state_full: GET /state_full
+  episode_summary: GET /episode_summary
+  schema: GET /schema
+  docs: GET /docs
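
The server block above lists the HTTP surface. A sketch of one full episode over that surface follows; the request bodies match the ResetRequest/StepRequest schemas in server/app.py, while the case_id value is an assumption that the case ids in server/logic.py (not part of this diff) mirror the task names.

# Sketch of one episode against a locally running server, using only the
# routes listed under `server:` above. The case_id is assumed to exist in
# server/logic.py's CASE_IDS.
import requests

BASE = "http://localhost:8000"

requests.post(f"{BASE}/reset", json={"options": {"case_id": "violence-remove"}})

# episode_flow has five stages; each /step call advances exactly one stage,
# and action_type must always be one of allow|flag|remove|escalate.
for _ in range(5):
    step = requests.post(
        f"{BASE}/step",
        json={"action": {"action_type": "remove", "reason": "direct threat of violence"}},
    ).json()
    print(step["observation"]["step_type"], step["reward"], step["done"])

summary = requests.get(f"{BASE}/episode_summary").json()
print(summary["total_reward"], summary["reward_breakdown"])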
server/__init__.py CHANGED
@@ -1,3 +1,3 @@
(removed and re-added lines are identical; the file content is shown once)

from .app import app

__all__ = ["app"]
server/app.py CHANGED
@@ -1,148 +1,148 @@
(removed and re-added lines are identical; the file content is shown once)

from __future__ import annotations

import traceback
from typing import Optional

from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel

try:
    from .models import Action, Observation, State
    from .env import ModerationEnvironment
    from .logic import CASE_IDS
except ImportError:
    from models import Action, Observation, State
    from env import ModerationEnvironment
    from logic import CASE_IDS


# ---------------------------------------------------------------------------
# Single persistent environment — shared across ALL HTTP requests
# ---------------------------------------------------------------------------
_env = ModerationEnvironment()

app = FastAPI(
    title="OpenEnv Multimodal Moderation",
    description="Multimodal content moderation RL environment",
    version="1.0.0",
)


# ---------------------------------------------------------------------------
# Request schemas
# ---------------------------------------------------------------------------

class ResetOptions(BaseModel):
    case_id: Optional[str] = None
    seed: Optional[int] = None
    episode_id: Optional[str] = None


class ResetRequest(BaseModel):
    options: Optional[ResetOptions] = None


class StepRequest(BaseModel):
    action: Action


# ---------------------------------------------------------------------------
# Core OpenEnv endpoints
# ---------------------------------------------------------------------------

@app.post("/reset")
async def reset(req: Optional[ResetRequest] = None) -> JSONResponse:
    try:
        opts = (req.options if req and req.options else None) or ResetOptions()
        obs: Observation = _env.reset(
            seed=opts.seed,
            episode_id=opts.episode_id,
            case_id=opts.case_id or "",
        )
        return JSONResponse({
            "observation": obs.model_dump(mode="json"),
            "reward": 0.0,
            "done": False,
        })
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/step")
async def step(req: StepRequest) -> JSONResponse:
    try:
        obs: Observation = _env.step(req.action)
        return JSONResponse({
            "observation": obs.model_dump(mode="json"),
            "reward": obs.reward,
            "done": obs.done,
        })
    except RuntimeError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/state")
async def get_state() -> JSONResponse:
    return JSONResponse(_env.state.model_dump(mode="json"))


@app.get("/schema")
async def schema() -> JSONResponse:
    return JSONResponse({
        "action": Action.model_json_schema(),
        "observation": Observation.model_json_schema(),
        "state": State.model_json_schema(),
    })


# ---------------------------------------------------------------------------
# /episode_summary — read by the reward_threshold graders in openenv.yaml
# ---------------------------------------------------------------------------

@app.get("/episode_summary")
async def episode_summary() -> JSONResponse:
    state = _env.state
    breakdown = dict(state.reward_breakdown or {})
    total_reward = round(sum(breakdown.values()), 4)
    return JSONResponse({
        "episode_id": state.episode_id,
        "step_count": state.step_count,
        "done": state.done,
        "total_reward": total_reward,
        "reward_breakdown": breakdown,
        "final_action": state.final_action,
        "reviewer_note": state.reviewer_note,
    })


# ---------------------------------------------------------------------------
# Helper endpoints
# ---------------------------------------------------------------------------

@app.get("/cases")
async def list_cases() -> JSONResponse:
    return JSONResponse({"cases": CASE_IDS})


@app.get("/state_full")
async def state_full() -> JSONResponse:
    return JSONResponse(_env.state.model_dump(mode="json"))


@app.get("/health")
async def health() -> JSONResponse:
    return JSONResponse({"status": "ok"})


def main(host: str = "0.0.0.0", port: int = 8000) -> None:
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()
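
Because the module holds a single in-process ModerationEnvironment, the app can be smoke-tested without Docker via FastAPI's TestClient. A sketch follows; it assumes server/logic.py and server/rag are importable (logic.py is not included in this diff) and that httpx is installed, which TestClient requires.

# In-process smoke test sketch; requires server.logic (not in this diff).
from fastapi.testclient import TestClient

from server.app import app

client = TestClient(app)

assert client.get("/health").json() == {"status": "ok"}

reset = client.post("/reset", json={"options": {"seed": 0}}).json()
print(reset["observation"]["step_type"])   # first stage is "analyze"

step = client.post(
    "/step", json={"action": {"action_type": "flag", "reason": "needs human review"}}
).json()
print(step["reward"], step["done"])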
server/env.py CHANGED
@@ -1,173 +1,173 @@
(removed and re-added lines are identical; the file content is shown once)

from __future__ import annotations

import uuid
from typing import Any, Dict, Optional

from openenv.core.env_server.interfaces import Environment

try:
    from .models import Action, ActionType, Content, Observation, PolicyChunk, State, StepType
    from .logic import (
        CASE_IDS,
        get_case,
        get_expected_action,
        compute_step_reward,
    )
    from .rag.retriever import retrieve_policy_chunks
except ImportError:
    from models import Action, ActionType, Content, Observation, PolicyChunk, State, StepType
    from logic import (
        CASE_IDS,
        get_case,
        get_expected_action,
        compute_step_reward,
    )
    from rag.retriever import retrieve_policy_chunks


# Episode step flow — each step() call advances to the next stage
EPISODE_FLOW = ["analyze", "retrieve_policy", "decide", "review", "finalize"]


class ModerationEnvironment(Environment):
    """OpenEnv environment for multimodal content moderation."""

    def __init__(self) -> None:
        super().__init__()
        self._state = State()
        self._case: Optional[Dict[str, Any]] = None
        self._current_step_index: int = 0

    # ------------------------------------------------------------------
    # OpenEnv interface
    # ------------------------------------------------------------------

    def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs) -> Observation:
        eid = episode_id or str(uuid.uuid4())

        # Determine which case to use
        # Allow caller to pass case_id via kwargs (used by inference.py)
        case_id = kwargs.get("case_id")
        if case_id and case_id in CASE_IDS:
            chosen_id = case_id
        elif seed is not None:
            chosen_id = CASE_IDS[seed % len(CASE_IDS)]
        else:
            import random
            chosen_id = random.choice(CASE_IDS)

        self._case = get_case(chosen_id)
        self._current_step_index = 0

        self._state = State(
            episode_id=eid,
            step_count=0,
            done=False,
            selected_case_id=chosen_id,
            reward_breakdown={
                "analysis_step": 0.0,
                "retrieval_step": 0.0,
                "correct_decision": 0.0,
                "reviewer_agreement": 0.0,
                "unsafe_penalty": 0.0,
            },
            final_action=None,
            reviewer_note=None,
            action_history=[],
            retrieved_policy_chunks=[],
        )

        content = Content(**self._case["content"])
        return Observation(
            content=content,
            policy=[],
            step_type=StepType.analyze,
            step_count=0,
            message=f"Episode started. Case: {chosen_id}. Begin with analysis.",
            reward=0.0,
            done=False,
        )

    def step(self, action: Action, **kwargs) -> Observation:
        if self._case is None:
            raise RuntimeError("Call reset() before step()")

        if self._state.done:
            return Observation(
                step_type=StepType.finalize,
                step_count=self._state.step_count,
                message="Episode already finished.",
                reward=0.0,
                done=True,
            )

        step_name = EPISODE_FLOW[self._current_step_index]
        reward = compute_step_reward(step_name, action.action_type.value, self._case)

        # Record reward into breakdown
        breakdown = self._state.reward_breakdown
        if step_name == "analyze":
            breakdown["analysis_step"] += reward
        elif step_name == "retrieve_policy":
            breakdown["retrieval_step"] += reward
        elif step_name == "decide":
            if reward > 0:
                breakdown["correct_decision"] += reward
            else:
                breakdown["unsafe_penalty"] += reward
        elif step_name == "review":
            breakdown["reviewer_agreement"] += reward

        # Record action history
        self._state.action_history.append({
            "step": step_name,
            "action_type": action.action_type.value,
            "reason": action.reason,
            "reward": reward,
        })

        self._state.step_count += 1
        self._current_step_index += 1

        # Build observation for next step
        policy_chunks: list[PolicyChunk] = []
        message = ""
        next_step_type = StepType.finalize

        if step_name == "retrieve_policy":
            # Actually retrieve now that we're done with retrieve_policy
            raw_chunks = retrieve_policy_chunks(self._case["content"].get("text", ""), top_k=3)
            policy_chunks = [PolicyChunk(**c) for c in raw_chunks]
            self._state.retrieved_policy_chunks = policy_chunks
            message = "Policy retrieved. Now make your moderation decision."
        elif step_name == "analyze":
            message = "Analysis complete. Retrieve relevant policy next."
        elif step_name == "decide":
            self._state.final_action = action.action_type.value
            message = "Decision recorded. Awaiting reviewer validation."
        elif step_name == "review":
            self._state.reviewer_note = action.reason or "Reviewer note recorded."
            message = "Review complete. Finalizing episode."
        elif step_name == "finalize":
            message = "Episode finalized."

        done = self._current_step_index >= len(EPISODE_FLOW)
        self._state.done = done

        # Determine next step type for observation
        if not done and self._current_step_index < len(EPISODE_FLOW):
            next_step_type = StepType(EPISODE_FLOW[self._current_step_index])

        return Observation(
            content=Content(**self._case["content"]),
            policy=policy_chunks or self._state.retrieved_policy_chunks,
            step_type=next_step_type,
            step_count=self._state.step_count,
            message=message,
            reward=reward,
            done=done,
        )

    @property
    def state(self) -> State:
        return self._state
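
The environment can also be driven directly in-process, without the FastAPI layer. The sketch below relies on server/logic.py and the RAG retriever being importable (only the latter appears in this diff); the fixed flag action is just a placeholder policy.

# Driving ModerationEnvironment directly. Case selection and step rewards come
# from server/logic.py, which is not part of this diff.
from server.env import ModerationEnvironment
from server.models import Action, ActionType

env = ModerationEnvironment()
obs = env.reset(seed=0)
print(obs.step_type, obs.message)

# Five stages: analyze -> retrieve_policy -> decide -> review -> finalize.
while not obs.done:
    obs = env.step(Action(action_type=ActionType.flag, reason="possible policy violation"))
    print(obs.step_type, obs.reward, obs.done)

print(env.state.reward_breakdown, env.state.final_action)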
server/models.py CHANGED
@@ -1,68 +1,68 @@
(removed and re-added lines are identical; the file content is shown once)

from __future__ import annotations

from enum import Enum
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field


class ActionType(str, Enum):
    allow = "allow"
    flag = "flag"
    remove = "remove"
    escalate = "escalate"


class StepType(str, Enum):
    analyze = "analyze"
    retrieve_policy = "retrieve_policy"
    decide = "decide"
    review = "review"
    finalize = "finalize"


class Content(BaseModel):
    text: str = ""
    image_url: Optional[str] = None
    image_description: Optional[str] = None


class PolicyChunk(BaseModel):
    policy_id: str = ""
    text: str = ""
    score: float = 0.0


class Action(BaseModel):
    action_type: ActionType
    reason: str = ""


class Observation(BaseModel):
    content: Optional[Content] = None
    policy: List[PolicyChunk] = Field(default_factory=list)
    step_type: StepType = StepType.analyze
    step_count: int = 0
    message: str = ""
    reward: float = 0.0
    done: bool = False


class State(BaseModel):
    episode_id: str = ""
    step_count: int = 0
    done: bool = False
    selected_case_id: Optional[str] = None
    reward_breakdown: Dict[str, float] = Field(
        default_factory=lambda: {
            "analysis_step": 0.0,
            "retrieval_step": 0.0,
            "correct_decision": 0.0,
            "reviewer_agreement": 0.0,
            "unsafe_penalty": 0.0,
        }
    )
    final_action: Optional[str] = None
    reviewer_note: Optional[str] = None
    action_history: List[Dict[str, Any]] = Field(default_factory=list)
    retrieved_policy_chunks: List[PolicyChunk] = Field(default_factory=list)
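
A short round-trip sketch for these models, using the pydantic v2 API (model_dump / model_validate) that app.py and client.py already rely on; the sample values are illustrative only.

# Illustrative round trip through the pydantic models above.
from server.models import Action, ActionType, Observation

action = Action(action_type=ActionType.escalate, reason="unverifiable medical claim")
print(action.model_dump(mode="json"))  # {'action_type': 'escalate', 'reason': ...}

obs = Observation.model_validate({
    "policy": [{"policy_id": "misinfo-01", "text": "Potential misinformation ...", "score": 0.4}],
    "step_type": "decide",
    "message": "Policy retrieved. Now make your moderation decision.",
})
print(obs.policy[0].policy_id, obs.step_type, obs.done)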
server/rag/__init__.py CHANGED
@@ -1,3 +1,3 @@
(removed and re-added lines are identical; the file content is shown once)

from .retriever import retrieve_policy_chunks

__all__ = ["retrieve_policy_chunks"]
server/rag/retriever.py CHANGED
@@ -1,98 +1,98 @@
(removed and re-added lines are identical; the file content is shown once)

from __future__ import annotations

import json
import os
from typing import Any, Dict, List


_POLICIES_PATH = os.path.join(os.path.dirname(__file__), "policies.json")

_policies_cache: List[Dict[str, Any]] = []


def _load_policies() -> List[Dict[str, Any]]:
    global _policies_cache
    if _policies_cache:
        return _policies_cache
    if os.path.exists(_POLICIES_PATH):
        with open(_POLICIES_PATH, "r", encoding="utf-8") as f:
            _policies_cache = json.load(f)
    else:
        # Fallback built-in policies when policies.json is absent
        _policies_cache = [
            {
                "chunk_id": "violence-01",
                "title": "Violence and Credible Threats",
                "text": "Content that includes direct threats, instructions to kill, or praise for murder must be removed immediately. Escalate only when context is ambiguous or safety risk cannot be resolved with the available evidence."
            },
            {
                "chunk_id": "violence-02",
                "title": "Violence Context Review",
                "text": "Fictional, quoted, or documentary references to violence may remain allowed when the speaker is clearly not endorsing harm. When intent is uncertain, flag or escalate instead of allowing outright."
            },
            {
                "chunk_id": "nudity-01",
                "title": "Adult Nudity",
                "text": "Mentions or metadata indicating nudity should be flagged for human review when the material may be adult sexual content. Remove only if the content is explicit or paired with exploitative cues."
            },
            {
                "chunk_id": "nudity-02",
                "title": "Suggestive Imagery Metadata",
                "text": "Image tags such as exposed body parts, provocative posing, or synthetic nude prompts increase risk and should bias the system toward flagging rather than allowing."
            },
            {
                "chunk_id": "self-harm-01",
                "title": "Self-Harm and Crisis",
                "text": "Content asking for methods of self-harm or celebrating imminent self-harm should be escalated or removed depending on immediacy. Support-seeking or recovery discussion should generally remain allowed."
            },
            {
                "chunk_id": "harassment-01",
                "title": "Harassment and Abuse",
                "text": "Targeted abuse, slurs, and repeated demeaning language should be flagged or removed based on severity. Benign disagreement or criticism alone is not enough for removal."
            },
            {
                "chunk_id": "spam-01",
                "title": "Spam and Low-Quality Promotion",
                "text": "Commercial spam, scam bait, or repetitive promotional text should usually be flagged. Remove only when the content is clearly malicious or part of coordinated abuse."
            },
            {
                "chunk_id": "misinfo-01",
                "title": "Misinformation Handling",
                "text": "Potential misinformation without direct physical harm usually requires flagging or escalation, not automatic removal. High-risk medical or civic deception should lean toward escalation."
            },
            {
                "chunk_id": "graphic-01",
                "title": "Graphic Media",
                "text": "Graphic injury descriptions or metadata indicating gore should be removed or escalated when public safety and age exposure concerns are high."
            },
            {
                "chunk_id": "default-01",
                "title": "Default Safe Handling",
                "text": "When evidence is weak and no policy trigger is present, prefer allowing the content. Use escalation only when confidence is low or policy signals conflict."
            }
        ]
    return _policies_cache


def _keyword_score(text: str, policy_text: str) -> float:
    """Simple overlap score: fraction of content words found in policy text."""
    content_words = set(text.lower().split())
    policy_words = set(policy_text.lower().split())
    if not content_words:
        return 0.0
    return len(content_words & policy_words) / len(content_words)


def retrieve_policy_chunks(query_text: str, top_k: int = 3) -> List[Dict[str, Any]]:
    """Return top_k policy chunks most relevant to query_text."""
    policies = _load_policies()
    scored = [
        {
            "policy_id": p["chunk_id"],
            "text": p["text"],
            "score": _keyword_score(query_text, p["text"]),
        }
        for p in policies
    ]
    scored.sort(key=lambda x: x["score"], reverse=True)
    return scored[:top_k]
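
Because _load_policies() falls back to the built-in list when policies.json is absent, the retriever can be exercised standalone:

# Standalone call to the keyword-overlap retriever defined above.
from server.rag.retriever import retrieve_policy_chunks

chunks = retrieve_policy_chunks(
    "post contains a direct threat and instructions to kill", top_k=3
)
for c in chunks:
    print(round(c["score"], 3), c["policy_id"], c["text"][:60])

Note that the score is plain word overlap, so short queries built from common words can produce ties; swapping in an embedding-based retriever would be a natural extension.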
server/server_routes.py CHANGED
@@ -13,7 +13,7 @@ def register_routes(app, env) -> None:
     async def episode_summary() -> JSONResponse:
         state = env.state
         breakdown = state.reward_breakdown or {}
-        total_reward = round(sum(breakdown.values()), 4)
+        total_reward = max(0.01, min(0.99, float(sum(breakdown.values()))))
         return JSONResponse({
             "episode_id": state.episode_id,
             "step_count": state.step_count,