import threading
from typing import Dict, Any, Optional

from .models import (
    ContentObservation,
    StepResult,
    ResetResult,
    EnvState,
    ModerationAction,
)
from .tasks import TASKS
from .graders import GRADERS


class ContentModerationEnv:
    """Step-by-step moderation environment over the items of one task.

    An episode is started with :meth:`reset`, which loads the item list of a
    task from ``TASKS``. Each :meth:`step` grades one action against the
    current item using the task's entry in ``GRADERS`` and advances to the
    next item. Episode state lives in the ``self._s`` dict and is read or
    replaced only while holding ``self._lock``.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._s: Dict[str, Any] = {}
        self._clear()

    def _clear(self) -> None:
        """Replace the state with an empty, already-finished episode.

        ``done`` is True so that ``step()`` reports an error result until
        ``reset()`` has been called.
        """
        self._s = {
            "task": None,
            "items": [],
            "idx": 0,
            "total": 0,
            "reward_sum": 0.0,
            "done": True,
            "history": [],
        }

    def _obs(self, item: Dict, idx: int, total: int) -> ContentObservation:
        """Build the observation for ``item``.

        Args:
            item: Item dict. ``content_id`` and ``content_type`` are required
                keys; ``text``, ``image_description``, ``detector_score`` and
                ``metadata`` are optional (``metadata`` defaults to ``{}``).
            idx: Value reported as ``step_num``; callers pass a 1-based
                position.
            total: Value reported as ``total_steps``.

        Raises:
            KeyError: If a required key is missing from ``item``.
        """
        return ContentObservation(
            content_id=item["content_id"],
            content_type=item["content_type"],
            text=item.get("text"),
            image_description=item.get("image_description"),
            detector_score=item.get("detector_score"),
            metadata=item.get("metadata", {}),
            step_num=idx,
            total_steps=total,
        )

    def reset(self, task: str = "text_spam") -> ResetResult:
        """Start a new episode for ``task`` and return its first observation.

        The new episode state is committed only after the first observation
        has been built, so a failing reset leaves the previous state intact.

        Args:
            task: Key into ``TASKS``.

        Returns:
            A ``ResetResult`` wrapping the observation for the first item.

        Raises:
            ValueError: If ``task`` is not a key of ``TASKS``, or if the task
                has no items.
        """
        if task not in TASKS:
            raise ValueError(f"Unknown task '{task}'. Valid: {list(TASKS.keys())}")

        with self._lock:
            task_cfg = TASKS[task]
            # Copy the list so the episode does not alias the task's own list.
            items = list(task_cfg["items"])
            if task == "deepfake_detection":
                # Imported here rather than at module level, as in the
                # original code, so other tasks never load this module.
                from .deepfake_model import precompute_detector_scores

                items = precompute_detector_scores(items)

            total = len(items)
            if total == 0:
                # Without this guard the state would be marked as a live
                # episode with nothing to step through.
                raise ValueError(f"Task '{task}' has no items to moderate.")

            # Build the observation BEFORE touching self._s: if it raises
            # (e.g. a malformed first item) the old state stays valid.
            first_obs = self._obs(items[0], 1, total)

            self._s = {
                "task": task,
                "items": items,
                "idx": 0,
                "total": total,
                "reward_sum": 0.0,
                "done": False,
                "history": [],
            }
            return ResetResult(observation=first_obs)

    def step(self, action: ModerationAction) -> StepResult:
        """Grade ``action`` against the current item and advance the episode.

        Args:
            action: The moderation decision; it is converted with
                ``model_dump()`` before being passed to the grader.

        Returns:
            A ``StepResult`` with the reward rounded to 4 decimals, the next
            observation (``None`` once the episode is done) and an ``info``
            dict holding the graded ``content_id`` and 1-based ``step``.
            If no episode is active, a result with ``reward=0.0``,
            ``done=True`` and an ``error`` entry in ``info`` is returned
            instead of raising.
        """
        with self._lock:
            if self._s["done"]:
                return StepResult(
                    observation=None,
                    reward=0.0,
                    done=True,
                    info={"error": "Episode finished. Call /reset first."},
                )

            idx = self._s["idx"]
            item = self._s["items"][idx]
            task = self._s["task"]
            grader = GRADERS[task]
            action_d = action.model_dump()

            # The deepfake grader additionally receives the item's
            # detector score; all other graders take (action, ground_truth).
            if task == "deepfake_detection":
                reward = grader(
                    action_d, item["ground_truth"], item.get("detector_score")
                )
            else:
                reward = grader(action_d, item["ground_truth"])

            # State is mutated only after grading succeeded.
            self._s["reward_sum"] += reward
            self._s["idx"] += 1
            self._s["history"].append({
                "step": idx + 1,
                "content_id": item["content_id"],
                "action": action_d,
                "reward": round(reward, 4),
                "ground_truth": item["ground_truth"],
            })

            new_idx = self._s["idx"]
            done = new_idx >= self._s["total"]
            self._s["done"] = done

            next_obs: Optional[ContentObservation] = None
            if not done:
                next_item = self._s["items"][new_idx]
                next_obs = self._obs(next_item, new_idx + 1, self._s["total"])

            return StepResult(
                observation=next_obs,
                reward=round(reward, 4),
                done=done,
                info={"content_id": item["content_id"], "step": idx + 1},
            )

    def state(self) -> EnvState:
        """Return a snapshot of the current episode.

        ``step_num`` is the number of items graded so far, ``task`` is
        ``"none"`` when no task is loaded, and ``history`` is a shallow copy
        of the internal list (the entry dicts themselves are shared).
        """
        with self._lock:
            return EnvState(
                task=self._s["task"] or "none",
                step_num=self._s["idx"],
                total_steps=self._s["total"],
                cumulative_reward=round(self._s["reward_sum"], 4),
                done=self._s["done"],
                history=list(self._s["history"]),
            )

    def close(self) -> None:
        """Discard the current episode and return to the cleared state."""
        with self._lock:
            self._clear()