File size: 3,984 Bytes
eb0a4a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import threading
from typing import Dict, Any, Optional

from .models import ContentObservation, StepResult, ResetResult, EnvState, ModerationAction
from .tasks import TASKS
from .graders import GRADERS


class ContentModerationEnv:
    """Lock-guarded, single-episode content-moderation environment.

    An episode walks through the items of one task (from ``TASKS``) in order.
    Each :meth:`step` grades the submitted action against the current item's
    ``ground_truth`` using ``GRADERS[task]`` and advances to the next item; the
    episode is done once every item has been graded. Every public method holds
    ``self._lock`` while reading or mutating the episode state.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._s: Dict[str, Any] = {}
        self._clear()

    @staticmethod
    def _new_state(
        task: Optional[str] = None, items: Optional[list] = None
    ) -> Dict[str, Any]:
        """Build a fresh episode-state dict.

        With no arguments this is the idle state (no task, no items,
        ``done=True``); with a task and a non-empty item list it is the start
        of a new episode (``done=False``).
        """
        items = [] if items is None else items
        return {
            "task": task,
            "items": items,
            "idx": 0,  # index of the next item to grade == items graded so far
            "total": len(items),
            "reward_sum": 0.0,
            "done": not items,  # nothing to grade means nothing to step through
            "history": [],
        }

    def _clear(self):
        """Reset the episode state to idle. Callers must hold the lock (or be __init__)."""
        self._s = self._new_state()

    def _obs(self, item: Dict, idx: int, total: int) -> ContentObservation:
        """Build the observation for *item*.

        Args:
            item: One entry of the task's item list; ``content_id`` and
                ``content_type`` are required, the other fields are optional.
            idx: 1-based position of *item* within the episode.
            total: Number of items in the episode.
        """
        return ContentObservation(
            content_id=item["content_id"],
            content_type=item["content_type"],
            text=item.get("text"),
            image_description=item.get("image_description"),
            detector_score=item.get("detector_score"),
            metadata=item.get("metadata", {}),
            step_num=idx,
            total_steps=total,
        )

    def reset(self, task: str = "text_spam") -> ResetResult:
        """Start a new episode for *task* and return the first observation.

        Raises:
            ValueError: if *task* is not a key of ``TASKS``, or if the task
                yields no items. The current episode is left untouched in
                both cases.
        """
        if task not in TASKS:
            raise ValueError(f"Unknown task '{task}'. Valid: {list(TASKS.keys())}")

        with self._lock:
            # Copy so the per-episode list does not alias the shared task config.
            items = list(TASKS[task]["items"])

            if task == "deepfake_detection":
                from .deepfake_model import precompute_detector_scores
                items = precompute_detector_scores(items)

            if not items:
                # Validate before swapping state: indexing items[0] on an empty
                # list would otherwise raise IndexError after a broken,
                # not-done empty episode had already been installed.
                raise ValueError(f"Task '{task}' has no items.")

            self._s = self._new_state(task, items)
            return ResetResult(observation=self._obs(items[0], 1, len(items)))

    def step(self, action: ModerationAction) -> StepResult:
        """Grade *action* against the current item and advance the episode.

        Returns a ``StepResult`` whose ``reward`` is the grader's score rounded
        to 4 decimals and whose ``observation`` is the next item, or ``None``
        once the episode is done. Calling this on a finished (or never
        started) episode returns a zero-reward result with an ``error`` entry
        in ``info`` instead of raising.
        """
        with self._lock:
            s = self._s
            if s["done"]:
                return StepResult(
                    observation=None,
                    reward=0.0,
                    done=True,
                    info={"error": "Episode finished. Call /reset first."},
                )

            idx = s["idx"]
            item = s["items"][idx]
            task = s["task"]
            grader = GRADERS[task]
            action_d = action.model_dump()

            # The deepfake grader also takes the precomputed detector score.
            if task == "deepfake_detection":
                reward = grader(action_d, item["ground_truth"], item.get("detector_score"))
            else:
                reward = grader(action_d, item["ground_truth"])
            rounded = round(reward, 4)

            # Accumulate the unrounded reward; only reported values are rounded.
            s["reward_sum"] += reward
            s["idx"] = idx + 1
            s["history"].append({
                "step": idx + 1,
                "content_id": item["content_id"],
                "action": action_d,
                "reward": rounded,
                "ground_truth": item["ground_truth"],
            })

            new_idx = s["idx"]
            done = new_idx >= s["total"]
            s["done"] = done

            next_obs: Optional[ContentObservation] = None
            if not done:
                next_obs = self._obs(s["items"][new_idx], new_idx + 1, s["total"])

            return StepResult(
                observation=next_obs,
                reward=rounded,
                done=done,
                info={"content_id": item["content_id"], "step": idx + 1},
            )

    def state(self) -> EnvState:
        """Return a snapshot of the episode.

        ``step_num`` is the number of items graded so far, ``task`` is
        ``"none"`` when no episode has been started, and ``history`` is a
        shallow copy of the per-step log.
        """
        with self._lock:
            return EnvState(
                task=self._s["task"] or "none",
                step_num=self._s["idx"],
                total_steps=self._s["total"],
                cumulative_reward=round(self._s["reward_sum"], 4),
                done=self._s["done"],
                history=list(self._s["history"]),
            )

    def close(self):
        """Discard the current episode and return to the idle state."""
        with self._lock:
            self._clear()