# ANI00's picture
# first commit
# eb0a4a1
import threading
from typing import Dict, Any, Optional
from .models import ContentObservation, StepResult, ResetResult, EnvState, ModerationAction
from .tasks import TASKS
from .graders import GRADERS
class ContentModerationEnv:
    """Episode-based environment that serves content items one at a time.

    An episode starts with :meth:`reset`, which loads the item list of a task
    from ``TASKS``. Each :meth:`step` grades the submitted action against the
    current item's ``ground_truth`` using the task's entry in ``GRADERS`` and
    advances to the next item. All reads and writes of the episode state in
    the public methods happen while holding ``self._lock``.
    """

    def __init__(self):
        """Create an idle environment (no task loaded, ``done`` is True)."""
        self._lock = threading.Lock()
        self._s: Dict[str, Any] = {}
        self._clear()

    @staticmethod
    def _fresh_state(
        task: Optional[str] = None,
        items: Optional[list] = None,
        done: bool = True,
    ) -> Dict[str, Any]:
        """Build a state dict positioned at the first item with zero reward."""
        items = [] if items is None else items
        return {
            "task": task,
            "items": items,
            "idx": 0,  # index of the next item to grade
            "total": len(items),
            "reward_sum": 0.0,
            "done": done,
            "history": [],
        }

    def _clear(self):
        """Replace the state with the idle one. Does not take the lock itself."""
        self._s = self._fresh_state()

    def _obs(self, item: Dict, idx: int, total: int) -> ContentObservation:
        """Wrap a raw item dict in a ``ContentObservation``.

        ``content_id`` and ``content_type`` are required keys; ``text``,
        ``image_description`` and ``detector_score`` default to ``None`` and
        ``metadata`` to ``{}`` when absent. ``idx`` is passed through as
        ``step_num`` (callers in this class pass a 1-based position) and
        ``total`` as ``total_steps``.
        """
        return ContentObservation(
            content_id=item["content_id"],
            content_type=item["content_type"],
            text=item.get("text"),
            image_description=item.get("image_description"),
            detector_score=item.get("detector_score"),
            metadata=item.get("metadata", {}),
            step_num=idx,
            total_steps=total,
        )

    def reset(self, task: str = "text_spam") -> ResetResult:
        """Start a new episode for ``task`` and return its first observation.

        Args:
            task: Key into ``TASKS`` selecting the item list to serve.

        Returns:
            A ``ResetResult`` holding the observation for the first item
            (``step_num`` 1).

        Raises:
            ValueError: If ``task`` is not a key of ``TASKS``, or if the task
                yields no items. In both cases the current episode state is
                left untouched.
        """
        if task not in TASKS:
            raise ValueError(f"Unknown task '{task}'. Valid: {list(TASKS.keys())}")
        with self._lock:
            # Copy the list so the episode does not alias the list stored in TASKS.
            items = list(TASKS[task]["items"])
            if task == "deepfake_detection":
                # Imported here rather than at module level, so the module is
                # only loaded when this task is actually requested.
                from .deepfake_model import precompute_detector_scores

                items = precompute_detector_scores(items)
            # Validate before committing any state: otherwise an empty task
            # would fail on items[0] below after the environment had already
            # been switched to a live episode with nothing to step through.
            if len(items) == 0:
                raise ValueError(f"Task '{task}' has no items to moderate.")
            self._s = self._fresh_state(task=task, items=items, done=False)
            return ResetResult(observation=self._obs(items[0], 1, len(items)))

    def step(self, action: ModerationAction) -> StepResult:
        """Grade ``action`` against the current item and advance the episode.

        Args:
            action: The moderation decision for the current item. It is
                converted with ``model_dump()`` before being handed to the
                grader and stored in the history.

        Returns:
            A ``StepResult`` with the reward rounded to 4 decimals, the
            ``done`` flag, ``info`` containing the graded item's
            ``content_id`` and 1-based ``step``, and the next observation
            (``None`` once the last item has been graded). If the episode is
            already finished (or was never started), the result has no
            observation, a reward of 0.0, ``done=True`` and an ``error``
            entry in ``info``.
        """
        with self._lock:
            state = self._s
            if state["done"]:
                return StepResult(
                    observation=None,
                    reward=0.0,
                    done=True,
                    info={"error": "Episode finished. Call /reset first."},
                )

            idx = state["idx"]
            item = state["items"][idx]
            task = state["task"]
            grader = GRADERS[task]
            action_d = action.model_dump()

            # The deepfake grader additionally receives the detector score.
            if task == "deepfake_detection":
                reward = grader(action_d, item["ground_truth"], item.get("detector_score"))
            else:
                reward = grader(action_d, item["ground_truth"])
            rounded_reward = round(reward, 4)

            # The running sum keeps the unrounded reward; only reported
            # values are rounded.
            state["reward_sum"] += reward
            state["idx"] += 1
            state["history"].append({
                "step": idx + 1,
                "content_id": item["content_id"],
                "action": action_d,
                "reward": rounded_reward,
                "ground_truth": item["ground_truth"],
            })

            new_idx = state["idx"]
            done = new_idx >= state["total"]
            state["done"] = done

            next_obs: Optional[ContentObservation] = None
            if not done:
                next_obs = self._obs(state["items"][new_idx], new_idx + 1, state["total"])

            return StepResult(
                observation=next_obs,
                reward=rounded_reward,
                done=done,
                info={"content_id": item["content_id"], "step": idx + 1},
            )

    def state(self) -> EnvState:
        """Return a snapshot of the episode progress.

        ``step_num`` is the number of items graded so far, ``task`` is
        ``"none"`` when no task is loaded, and ``history`` is a shallow copy
        of the per-step records.
        """
        with self._lock:
            return EnvState(
                task=self._s["task"] or "none",
                step_num=self._s["idx"],
                total_steps=self._s["total"],
                cumulative_reward=round(self._s["reward_sum"], 4),
                done=self._s["done"],
                history=list(self._s["history"]),
            )

    def close(self):
        """Discard the current episode and return to the idle state."""
        with self._lock:
            self._clear()