File size: 7,255 Bytes
2a39e79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abf8abc
 
 
 
 
 
 
 
2a39e79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abf8abc
 
 
2a39e79
 
 
 
 
 
abf8abc
 
 
 
2a39e79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abf8abc
 
 
2a39e79
 
 
abf8abc
 
2a39e79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
"""
models.py
=========
Typed Pydantic data-models for the ContentModerationEnv OpenEnv spec.

These models are used for:
  • runtime validation of agent actions
  • serialisation of environment state
  • documentation (JSON Schema derivable via .model_json_schema())

Requires: pydantic >= 2.0
"""

from __future__ import annotations

from enum import Enum
from typing import List, Optional

from pydantic import BaseModel, Field, field_validator, model_validator


# ── Enumerations ──────────────────────────────────────────────────────────────

class Label(str, Enum):
    """
    Classification label for a piece of content.

    Members subclass ``str``, so each one compares equal to its plain
    string value (e.g. ``Label.safe == "safe"``).
    """

    safe = "safe"
    toxic = "toxic"
    spam = "spam"
    misleading = "misleading"


class ModerationAction(str, Enum):
    """
    Moderation action that can be applied to a piece of content.

    Members subclass ``str``, so each one compares equal to its plain
    string value.
    """

    allow = "allow"
    warn = "warn"
    remove = "remove"
    shadowban = "shadowban"
    escalate = "escalate"


class PlatformPolicy(str, Enum):
    """
    Enforcement level of the platform.

    Members subclass ``str``, so each one compares equal to its plain
    string value.
    """

    strict = "strict"
    moderate = "moderate"
    lenient = "lenient"


class Platform(str, Enum):
    """
    Social media platform context.

    Members subclass ``str``, so each one compares equal to its plain
    string value.
    """

    reddit = "reddit"
    twitter = "twitter"
    youtube = "youtube"
    linkedin = "linkedin"


class Tier(str, Enum):
    """
    Benchmark difficulty tier.

    Members subclass ``str``, so each one compares equal to its plain
    string value.
    """

    easy = "easy"
    medium = "medium"
    hard = "hard"


# ── Observation (state) ───────────────────────────────────────────────────────

class Observation(BaseModel):
    """
    The environment's observation returned by reset() / state().

    The model is frozen: fields cannot be reassigned once an instance
    has been built.

    Attributes
    ----------
    text : str
        The user-generated text content to be reviewed. Required.
    audio_transcript : str | None
        Transcript of any accompanying audio/video (None for text-only).
    visual_tags : list[str]
        Machine-detected visual content tags (empty list if no visual media).
    previous_flags : int
        Number of prior platform violations by this account (≥ 0). Required.
    platform_policy : PlatformPolicy
        Policy enforcement level the moderation decision must respect.
        Required.
    platform : str | None
        Social media platform context: reddit / twitter / youtube / linkedin.
        None for legacy scenarios that predate the platform field.
        Stored as a plain string; it is not validated against an enum here.
    """

    # Immutable — agents must not mutate state.
    model_config = {"frozen": True}

    text: str = Field(
        ...,
        description="User-generated content text",
    )
    audio_transcript: Optional[str] = Field(
        None,
        description="Audio/video transcript (nullable)",
    )
    visual_tags: List[str] = Field(
        default_factory=list,
        description="Detected visual content tags",
    )
    previous_flags: int = Field(
        ...,
        ge=0,
        description="Prior policy violations count",
    )
    platform_policy: PlatformPolicy = Field(
        ...,
        description="Platform enforcement level",
    )
    platform: Optional[str] = Field(
        None,
        description="Platform: reddit/twitter/youtube/linkedin",
    )


# ── Action ────────────────────────────────────────────────────────────────────

class AgentAction(BaseModel):
    """
    The action an agent submits via env.step().

    Required
    --------
    label    : Label
        Content classification.
    action   : ModerationAction
        Moderation action to apply.

    Optional
    --------
    severity  : int in [1, 5]
        Severity rating (hard tier only). String integers such as "3"
        are accepted and converted before range validation.
    rationale : str
        Brief reasoning (not scored).
    """
    label:     Label            = Field(..., description="Content classification")
    action:    ModerationAction = Field(..., description="Moderation action to apply")
    severity:  Optional[int]    = Field(None, ge=1, le=5, description="Severity 1-5 (hard tier only)")
    rationale: Optional[str]    = Field(None, description="Brief reasoning (not scored)")

    @field_validator("severity", mode="before")
    @classmethod
    def coerce_severity(cls, v):
        """
        Accept string integers gracefully.

        Runs before Pydantic's own type/range validation. ``None`` is
        passed through unchanged; anything else is converted with ``int()``.

        Raises
        ------
        ValueError
            If the value cannot be converted to an integer. Pydantic turns
            a ValueError raised here into a ValidationError for the caller.
        """
        if v is None:
            return v
        try:
            return int(v)
        except (TypeError, OverflowError) as exc:
            # Pydantic v2 only wraps ValueError/AssertionError raised by
            # validators. int() raises TypeError for e.g. lists/dicts and
            # OverflowError for infinities; re-raise those as ValueError so
            # callers get a ValidationError instead of a raw exception.
            raise ValueError(
                f"severity must be convertible to an integer, got {type(v).__name__}"
            ) from exc

    def to_env_dict(self) -> dict:
        """
        Convert to the plain dict format expected by ContentModerationEnv.step().

        Returns
        -------
        dict
            Always contains "label" and "action" as plain strings (the enum
            values). "severity" and "rationale" are included only when they
            are not None.
        """
        d: dict = {"label": self.label.value, "action": self.action.value}
        if self.severity is not None:
            d["severity"] = self.severity
        if self.rationale is not None:
            d["rationale"] = self.rationale
        return d


# ── Score breakdown ───────────────────────────────────────────────────────────

class ScoreBreakdown(BaseModel):
    """
    Per-component reward breakdown returned in step() info dict.

    Each component is either None (not set) or a float constrained to
    the range [0.0, 1.0].
    """
    label_correct:     Optional[float] = Field(None, ge=0.0, le=1.0)
    action_correct:    Optional[float] = Field(None, ge=0.0, le=1.0)
    severity_within_1: Optional[float] = Field(None, ge=0.0, le=1.0)

    @property
    def total(self) -> float:
        """
        Sum of all components that are not None.

        Returns
        -------
        float
            The sum of the populated components; 0.0 when every component
            is None.
        """
        components = (self.label_correct, self.action_correct, self.severity_within_1)
        # Seed with 0.0 so the empty case yields a float (matching the
        # annotation) rather than the int 0 that a bare sum() would return.
        return sum((c for c in components if c is not None), 0.0)


# ── Step result ───────────────────────────────────────────────────────────────

class GroundTruth(BaseModel):
    """
    Ground truth record stored in each scenario.

    ``label`` and ``action`` are required. ``severity`` (constrained to
    1..5) and ``rationale`` are optional and default to None.
    """

    label: Label
    action: ModerationAction
    severity: Optional[int] = Field(None, ge=1, le=5)
    rationale: Optional[str] = None


class StepResult(BaseModel):
    """
    Full result returned by env.step().

    Attributes
    ----------
    state : Observation
        Next observation (next post in queue, or final post state).
    reward : float
        Partial-credit score, validated to lie in [-0.3, 1.0]
        (penalties may go negative).
    done : bool
        False until all queue posts processed; True after final step.
    info : StepInfo
        Breakdown, ground truth, submitted action, warnings.
    """

    state: Observation
    reward: float = Field(..., ge=-0.3, le=1.0)
    done: bool
    # StepInfo is declared further down in this module, so this annotation
    # is a forward reference.
    info: StepInfo


class StepInfo(BaseModel):
    """
    Metadata returned inside StepResult.info.

    Every field is required except ``warnings``, which defaults to a
    fresh empty list.
    """

    scenario_id: str
    tier: Tier
    ground_truth: GroundTruth
    score_rubric: dict
    score_breakdown: ScoreBreakdown
    submitted_action: AgentAction
    warnings: List[str] = Field(default_factory=list)


# ── Scenario (internal) ───────────────────────────────────────────────────────

class Scenario(BaseModel):
    """
    One benchmark scenario as stored in moderation_benchmark.json.

    Used internally by the environment loader. All fields are required.
    """

    id: str
    tier: Tier
    state: Observation
    ground_truth: GroundTruth
    score_rubric: dict