Spaces:
Sleeping
Sleeping
| import threading | |
| from typing import Dict, Any, Optional | |
| from .models import ContentObservation, StepResult, ResetResult, EnvState, ModerationAction | |
| from .tasks import TASKS | |
| from .graders import GRADERS | |
class ContentModerationEnv:
    """Single-episode content-moderation environment.

    An episode walks through the items of one task from ``TASKS`` in order.
    ``reset`` starts an episode and returns the first observation; each
    ``step`` grades one action with the task's grader from ``GRADERS`` and
    returns the next observation until the items are exhausted.

    All episode state lives in the ``_s`` dict. The public methods
    (``reset``, ``step``, ``state``, ``close``) read and write it only while
    holding ``_lock``. ``threading.Lock`` is not reentrant, so these methods
    must not call one another while the lock is held.
    """

    def __init__(self):
        """Create an environment with no active episode (``done`` is True)."""
        self._lock = threading.Lock()
        self._s: Dict[str, Any] = {}
        self._clear()

    @staticmethod
    def _new_state(
        task: Optional[str] = None,
        items: Optional[list] = None,
        done: bool = True,
    ) -> Dict[str, Any]:
        """Build a fresh episode-state dict; defaults describe "no episode"."""
        items = [] if items is None else items
        return {
            "task": task,
            "items": items,
            "idx": 0,  # index of the next item to grade
            "total": len(items),
            "reward_sum": 0.0,  # unrounded running total
            "done": done,
            "history": [],
        }

    def _clear(self):
        """Drop any episode state.

        Does not take ``_lock`` itself; ``close`` calls it with the lock held
        and ``__init__`` calls it before the instance is shared.
        """
        self._s = self._new_state()

    def _obs(self, item: Dict, idx: int, total: int) -> ContentObservation:
        """Wrap a raw task item in a ``ContentObservation``.

        Args:
            item: Task item dict. ``content_id`` and ``content_type`` are
                required; ``text``, ``image_description``, ``detector_score``
                and ``metadata`` are optional.
            idx: Step number to report on the observation (callers in this
                class pass a 1-based value).
            total: Total number of steps in the episode.

        Raises:
            KeyError: If ``content_id`` or ``content_type`` is missing.
        """
        return ContentObservation(
            content_id=item["content_id"],
            content_type=item["content_type"],
            text=item.get("text"),
            image_description=item.get("image_description"),
            detector_score=item.get("detector_score"),
            metadata=item.get("metadata", {}),
            step_num=idx,
            total_steps=total,
        )

    def reset(self, task: str = "text_spam") -> ResetResult:
        """Start a new episode for ``task`` and return its first observation.

        The new episode is committed only after the first observation has
        been built, so a failure here leaves the previous episode untouched.

        Args:
            task: Key into ``TASKS``.

        Returns:
            A ``ResetResult`` carrying the observation for step 1.

        Raises:
            ValueError: If ``task`` is not a key of ``TASKS``, or if the task
                has no items (an empty episode could never be stepped).
        """
        if task not in TASKS:
            raise ValueError(f"Unknown task '{task}'. Valid: {list(TASKS.keys())}")
        with self._lock:
            task_cfg = TASKS[task]
            # Copy the list so the episode does not alias the task's own list.
            items = list(task_cfg["items"])
            if task == "deepfake_detection":
                # Imported here, so this module only loads `.deepfake_model`
                # when the deepfake task is actually requested.
                from .deepfake_model import precompute_detector_scores

                items = precompute_detector_scores(items)

            total = len(items)
            if total == 0:
                # Previously this fell through to `items[0]` *after* the state
                # had been replaced, raising IndexError and leaving a live
                # episode with zero items that also crashed `step`.
                raise ValueError(f"Task '{task}' has no items.")

            # Build the observation first: if it fails, the state is not yet
            # overwritten.
            first_obs = self._obs(items[0], 1, total)
            self._s = self._new_state(task=task, items=items, done=False)
            return ResetResult(observation=first_obs)

    def step(self, action: ModerationAction) -> StepResult:
        """Grade ``action`` against the current item and advance the episode.

        Grading happens before any state is modified, so an exception raised
        by the grader leaves the episode where it was.

        Args:
            action: The moderation decision for the current item; it is
                passed to the grader as ``action.model_dump()``.

        Returns:
            A ``StepResult`` with the reward (rounded to 4 decimals), the
            ``done`` flag, and the next observation (``None`` once the
            episode is finished). If no episode is active, an error result
            with ``reward=0.0`` and ``done=True`` is returned instead of
            raising.
        """
        with self._lock:
            state = self._s
            if state["done"]:
                return StepResult(
                    observation=None,
                    reward=0.0,
                    done=True,
                    info={"error": "Episode finished. Call /reset first."},
                )

            idx = state["idx"]
            item = state["items"][idx]
            task = state["task"]
            grader = GRADERS[task]
            action_d = action.model_dump()

            # The deepfake grader additionally receives the detector score.
            if task == "deepfake_detection":
                reward = grader(action_d, item["ground_truth"], item.get("detector_score"))
            else:
                reward = grader(action_d, item["ground_truth"])

            step_no = idx + 1  # 1-based number of the step just graded
            rounded_reward = round(reward, 4)

            state["reward_sum"] += reward
            state["idx"] = step_no
            state["history"].append({
                "step": step_no,
                "content_id": item["content_id"],
                "action": action_d,
                "reward": rounded_reward,
                "ground_truth": item["ground_truth"],
            })

            done = step_no >= state["total"]
            state["done"] = done

            next_obs: Optional[ContentObservation] = None
            if not done:
                # `step_no` is also the 0-based index of the next item.
                next_obs = self._obs(state["items"][step_no], step_no + 1, state["total"])

            return StepResult(
                observation=next_obs,
                reward=rounded_reward,
                done=done,
                info={"content_id": item["content_id"], "step": step_no},
            )

    def state(self) -> EnvState:
        """Return a snapshot of the current episode.

        ``step_num`` here is the number of items already graded, whereas the
        ``step_num`` on observations is the 1-based number of the item being
        shown. ``task`` is reported as ``"none"`` when no task is set.
        """
        with self._lock:
            snapshot = self._s
            return EnvState(
                task=snapshot["task"] or "none",
                step_num=snapshot["idx"],
                total_steps=snapshot["total"],
                cumulative_reward=round(snapshot["reward_sum"], 4),
                done=snapshot["done"],
                # Shallow copy: the list is detached from internal state, the
                # entry dicts inside it are not.
                history=list(snapshot["history"]),
            )

    def close(self):
        """Discard the current episode; ``reset`` must be called to start again."""
        with self._lock:
            self._clear()