Spaces:
Sleeping
Sleeping
"""
models.py
=========
Typed Pydantic data-models for the ContentModerationEnv OpenEnv spec.

These models are used for:
  • runtime validation of agent actions
  • serialisation of environment state
  • documentation (JSON Schema derivable via .model_json_schema())

Requires: pydantic >= 2.0
"""
from __future__ import annotations

from enum import Enum
from typing import List, Optional

from pydantic import BaseModel, Field, field_validator, model_validator
# ── Enumerations ──────────────────────────────────────────────────────────────
class Label(str, Enum):
    """Content classification label.

    The ``str`` mix-in makes every member a genuine string, so a member
    compares equal to its raw value (``Label.safe == "safe"``) and can be
    recovered from that value (``Label("safe")``).
    """

    safe = "safe"
    toxic = "toxic"
    spam = "spam"
    misleading = "misleading"
class ModerationAction(str, Enum):
    """Moderation action to apply to the content.

    Members are plain strings (``str`` mix-in); each member's value is
    identical to its name.
    """

    allow = "allow"
    warn = "warn"
    remove = "remove"
    shadowban = "shadowban"
    escalate = "escalate"
class PlatformPolicy(str, Enum):
    """Enforcement level of the platform.

    Members are plain strings (``str`` mix-in); each member's value is
    identical to its name.
    """

    strict = "strict"
    moderate = "moderate"
    lenient = "lenient"
class Platform(str, Enum):
    """Social media platform context.

    Members are plain strings (``str`` mix-in); each member's value is
    identical to its name.
    """

    reddit = "reddit"
    twitter = "twitter"
    youtube = "youtube"
    linkedin = "linkedin"
class Tier(str, Enum):
    """Benchmark difficulty tier.

    Members are plain strings (``str`` mix-in); each member's value is
    identical to its name.
    """

    easy = "easy"
    medium = "medium"
    hard = "hard"
# ── Observation (state) ───────────────────────────────────────────────────────
class Observation(BaseModel):
    """
    The environment's observation returned by reset() / state().

    Attributes
    ----------
    text : str
        The user-generated text content to be reviewed. Required.
    audio_transcript : str | None
        Transcript of any accompanying audio/video (None for text-only).
    visual_tags : list[str]
        Machine-detected visual content tags (empty list if no visual media).
    previous_flags : int
        Number of prior platform violations by this account (≥ 0). Required.
    platform_policy : PlatformPolicy
        Policy enforcement level the moderation decision must respect.
        Required.
    platform : str | None
        Social media platform context: reddit / twitter / youtube / linkedin.
        None for legacy scenarios that predate the platform field.
    """

    # Immutable — agents must not mutate state. Pydantic rejects attribute
    # assignment on a frozen model.
    model_config = {"frozen": True}

    text: str = Field(..., description="User-generated content text")
    audio_transcript: Optional[str] = Field(
        None,
        description="Audio/video transcript (nullable)",
    )
    visual_tags: List[str] = Field(
        default_factory=list,
        description="Detected visual content tags",
    )
    previous_flags: int = Field(
        ...,
        ge=0,
        description="Prior policy violations count",
    )
    platform_policy: PlatformPolicy = Field(
        ...,
        description="Platform enforcement level",
    )
    # NOTE(review): typed as a free-form string, not the Platform enum, so the
    # value is not validated against the four documented platforms.
    platform: Optional[str] = Field(
        None,
        description="Platform: reddit/twitter/youtube/linkedin",
    )
# ── Action ────────────────────────────────────────────────────────────────────
class AgentAction(BaseModel):
    """
    The action an agent submits via env.step().

    Required
    --------
    label : Label
    action : ModerationAction

    Optional (scored only in hard tier)
    ------------------------------------
    severity : int in [1, 5]
    rationale : str
    """

    label: Label = Field(..., description="Content classification")
    action: ModerationAction = Field(..., description="Moderation action to apply")
    severity: Optional[int] = Field(
        None,
        ge=1,
        le=5,
        description="Severity 1-5 (hard tier only)",
    )
    rationale: Optional[str] = Field(None, description="Brief reasoning (not scored)")

    # Registered as a "before" validator so it sees the raw input ahead of
    # Pydantic's own int parsing and the ge/le range check. Without the
    # decorators this method was never invoked during validation.
    @field_validator("severity", mode="before")
    @classmethod
    def coerce_severity(cls, v):
        """Accept string integers gracefully.

        Parameters
        ----------
        v : Any
            Raw ``severity`` input as supplied by the caller.

        Returns
        -------
        Any
            ``int(v)`` when ``v`` is a string (surrounding whitespace
            ignored); otherwise ``v`` unchanged, so ``None`` and non-string
            values go through Pydantic's normal integer validation.

        Raises
        ------
        ValueError
            If ``v`` is a string that does not parse as an integer. Inside
            model validation Pydantic reports this as a validation error.
        """
        # Only strings are coerced: calling int() on a float would silently
        # truncate values such as 3.7 that should be rejected instead.
        if isinstance(v, str):
            return int(v.strip())
        return v

    def to_env_dict(self) -> dict:
        """Convert to the plain dict format expected by ContentModerationEnv.step().

        Returns
        -------
        dict
            Always contains ``"label"`` and ``"action"`` as raw string values.
            ``"severity"`` and ``"rationale"`` are included only when they are
            not ``None``.
        """
        payload: dict = {
            "label": self.label.value,
            "action": self.action.value,
        }
        if self.severity is not None:
            payload["severity"] = self.severity
        if self.rationale is not None:
            payload["rationale"] = self.rationale
        return payload
# ── Score breakdown ───────────────────────────────────────────────────────────
class ScoreBreakdown(BaseModel):
    """Per-component reward breakdown returned in step() info dict.

    Each component is optional and, when present, constrained to [0.0, 1.0].
    """

    label_correct: Optional[float] = Field(None, ge=0.0, le=1.0)
    action_correct: Optional[float] = Field(None, ge=0.0, le=1.0)
    severity_within_1: Optional[float] = Field(None, ge=0.0, le=1.0)

    def total(self) -> float:
        """Return the sum of all components that are set.

        Components equal to ``None`` are skipped. The sum starts from ``0.0``
        so the result is a float even when every component is ``None``
        (a bare ``sum()`` would return the int ``0`` in that case).
        """
        components = (
            self.label_correct,
            self.action_correct,
            self.severity_within_1,
        )
        return sum((c for c in components if c is not None), 0.0)
# ── Step result ───────────────────────────────────────────────────────────────
class GroundTruth(BaseModel):
    """Ground truth record stored in each scenario.

    ``label`` and ``action`` are required. ``severity`` (an integer in
    [1, 5]) and ``rationale`` are optional and default to ``None``.
    """

    label: Label
    action: ModerationAction
    severity: Optional[int] = Field(None, ge=1, le=5)
    rationale: Optional[str] = None
# StepInfo is declared before StepResult so that StepResult's ``info`` field
# refers to an already-defined class instead of a deferred forward reference.
class StepInfo(BaseModel):
    """Metadata returned inside StepResult.info."""

    scenario_id: str
    tier: Tier
    ground_truth: GroundTruth
    score_rubric: dict
    score_breakdown: ScoreBreakdown
    submitted_action: AgentAction
    warnings: List[str] = Field(default_factory=list)  # fresh list per instance


class StepResult(BaseModel):
    """
    Full result returned by env.step().

    state  — next observation (next post in queue, or final post state)
    reward — partial-credit score in [-0.3, 1.0] (penalties may go negative)
    done   — False until all queue posts processed; True after final step
    info   — breakdown, ground truth, submitted action, warnings
    """

    state: Observation
    reward: float = Field(..., ge=-0.3, le=1.0)
    done: bool
    info: StepInfo
# ── Scenario (internal) ───────────────────────────────────────────────────────
class Scenario(BaseModel):
    """
    One benchmark scenario as stored in moderation_benchmark.json.

    Used internally by the environment loader. All five fields are required.
    """

    id: str  # field name kept as-is; it is part of the model's schema
    tier: Tier
    state: Observation
    ground_truth: GroundTruth
    score_rubric: dict