File size: 7,653 Bytes
eb0a4a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
833bc29
 
eb0a4a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
059feb0
eb0a4a1
 
 
059feb0
eb0a4a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
"""
Content Moderation Inference Script
Env vars: API_BASE_URL, MODEL_NAME, HF_TOKEN (or API_KEY), SERVER_URL, TASK_NAME, USE_LOCAL_MODEL
"""

import json
import os
import textwrap
from typing import Dict, Any, List, Optional

from dotenv import load_dotenv
import requests

# Runs before the os.getenv() calls below, so anything it adds to the process
# environment is visible to them.
load_dotenv()

# Base URL and model identifier passed to the OpenAI-compatible client in
# call_api_model(); MODEL_NAME is also printed in the [START] log line.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
HF_TOKEN = os.getenv("HF_TOKEN")
# HF_TOKEN wins; API_KEY is only consulted when HF_TOKEN is unset or empty.
API_KEY = HF_TOKEN or os.getenv("API_KEY")
# Trailing slashes are stripped so endpoints can be built as f"{SERVER_URL}/reset".
SERVER_URL = os.getenv("SERVER_URL", "http://localhost:7860").rstrip("/")
# Task identifier sent to the server's /reset endpoint.
TASK_NAME = os.getenv("TASK_NAME", "text_spam")
# Environment name printed in the [START] log line.
BENCHMARK = "content_moderation_env"
# Upper bound on the number of /step requests issued per episode.
MAX_STEPS = 10
# Only the string "true" (any letter case) selects the local-model path.
USE_LOCAL = os.getenv("USE_LOCAL_MODEL", "false").lower() == "true"
# Minimum normalized score for an episode to be reported as success=true.
SUCCESS_SCORE_THRESHOLD = 0.5

# Decisions and labels accepted from the model. parse_llm_response() replaces
# any other decision with "escalate" and drops any other label.
VALID_DECISIONS = {"approve", "reject", "escalate", "flag"}
VALID_LABELS = [
    "spam", "scam", "phishing", "pharmaceutical_spam",
    "hate_speech", "violence", "harassment", "misinformation",
    "adult_content", "deepfake", "political_manipulation", "fraud",
]

# System message sent with every model request (both the local and the API
# path). dedent() + strip() remove the indentation handling and the leading /
# trailing newlines introduced by the triple-quoted literal.
# NOTE(review): the prompt asks for "labels" but never enumerates VALID_LABELS,
# so any label the model invents is silently filtered out when the reply is
# parsed -- consider listing the allowed labels here.
SYSTEM_PROMPT = textwrap.dedent("""
You are an expert content moderation AI. For each piece of content, analyze it carefully and respond ONLY with a valid JSON object.

Required JSON fields:
- "decision": one of "approve", "reject", "escalate", "flag"
- "reason": brief explanation (max 2 sentences)
- "confidence": float between 0.0 and 1.0
- "labels": list of applicable labels

Decision guidelines:
- approve: content is safe and legitimate
- reject: content clearly violates policy
- flag: content is borderline or needs human review
- escalate: uncertain, requires specialist review

Output ONLY the JSON object, no other text.
""").strip()


def log_start(task: str, env: str, model: str):
    """Print the ``[START]`` line that opens an episode log.

    Args:
        task: Task identifier for the episode.
        env: Name of the environment/benchmark.
        model: Name of the model being evaluated.
    """
    header = f"[START] task={task} env={env} model={model}"
    print(header, flush=True)


def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]):
    """Print one ``[STEP]`` line describing a single environment step.

    Args:
        step: 1-based step index.
        action: Already-serialized action string.
        reward: Reward for the step; printed with two decimals.
        done: Episode-finished flag; printed lower-cased.
        error: Error text, or a falsy value, which is printed as ``null``.
    """
    done_text = str(done).lower()
    error_text = error or "null"
    print(
        f"[STEP] step={step} action={action} reward={reward:.2f} "
        f"done={done_text} error={error_text}",
        flush=True,
    )


def log_end(success: bool, steps: int, score: float, rewards: List[float]):
    """Print the ``[END]`` line that closes an episode log.

    Args:
        success: Whether the episode counted as a success; printed lower-cased.
        steps: Number of steps taken.
        score: Final score; printed with three decimals.
        rewards: Per-step rewards; printed comma-separated with two decimals.
    """
    reward_fields = [f"{value:.2f}" for value in rewards]
    joined_rewards = ",".join(reward_fields)
    success_text = str(success).lower()
    print(
        f"[END] success={success_text} steps={steps} score={score:.3f} rewards={joined_rewards}",
        flush=True,
    )


def build_prompt(obs: Dict[str, Any]) -> str:
    """Render an observation dict as the user message for the model.

    The content id and type lines are always present (with ``unknown`` /
    ``text`` fallbacks). Text, image description, detector score and metadata
    lines are added only when the observation provides them. The prompt ends
    with a blank line followed by the step counter.

    Args:
        obs: Observation dictionary from the environment server.

    Returns:
        The newline-joined prompt text.
    """
    lines = [
        f"Content ID: {obs.get('content_id', 'unknown')}",
        f"Type: {obs.get('content_type', 'text')}",
    ]

    body_text = obs.get("text")
    if body_text:
        lines.append(f"Text: {body_text}")

    image_description = obs.get("image_description")
    if image_description:
        lines.append(f"Image analysis: {image_description}")

    # A score of 0 is meaningful, so only a missing/None score is skipped.
    detector_score = obs.get("detector_score")
    if detector_score is not None:
        lines.append(
            f"Deepfake detector score (higher = more likely fake): {detector_score:.3f}"
        )

    metadata = obs.get("metadata", {})
    if metadata:
        pairs = [f"{key}={value}" for key, value in metadata.items()]
        lines.append("Metadata: " + ", ".join(pairs))

    # The leading "\n" yields an empty line before the step counter.
    lines.append(f"\nStep {obs.get('step_num', '?')} of {obs.get('total_steps', '?')}")
    return "\n".join(lines)


def _default_action() -> Dict:
    return {"decision": "escalate", "reason": "Unable to analyze content.", "confidence": 0.3, "labels": []}


# Pipelines already built in this process, keyed by model name. Building one
# loads the model, so it must not be repeated for every moderation step.
_LOCAL_PIPELINE_CACHE: Dict[str, Any] = {}


def _get_local_pipeline(model_name: str):
    """Return the text-generation pipeline for *model_name*, building it once."""
    cached = _LOCAL_PIPELINE_CACHE.get(model_name)
    if cached is None:
        # Function-scope import: importing this module does not require
        # transformers unless the local path is actually used.
        from transformers import pipeline

        cached = pipeline("text-generation", model=model_name)
        _LOCAL_PIPELINE_CACHE[model_name] = cached
    return cached


def call_local_model(prompt: str) -> Dict:
    """Ask a locally loaded transformers model for a moderation decision.

    The model is selected by ``MODEL_NAME`` (the same name reported in the
    ``[START]`` log line) and its pipeline is reused across calls.

    Args:
        prompt: User message describing the content to moderate.

    Returns:
        The action dict produced by ``parse_llm_response``.

    Raises:
        Exception: Anything raised while importing transformers, loading the
            model or generating is propagated to the caller.
    """
    pipe = _get_local_pipeline(MODEL_NAME)
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    output = pipe(messages, max_new_tokens=256, temperature=0.2, do_sample=True)
    text = output[0]["generated_text"]
    # "generated_text" may be a list of message dicts rather than a plain
    # string; in that case the last entry's content is used as the reply.
    if isinstance(text, list):
        text = text[-1].get("content", "")
    return parse_llm_response(text)


def call_api_model(prompt: str) -> Dict:
    """Ask the remote chat-completions endpoint for a moderation decision.

    Args:
        prompt: User message describing the content to moderate.

    Returns:
        The action dict produced by ``parse_llm_response`` from the first
        choice's message content (an absent/empty content is treated as "").
    """
    from openai import OpenAI

    api_client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY or "hf_default")
    chat_messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    response = api_client.chat.completions.create(
        model=MODEL_NAME,
        messages=chat_messages,
        temperature=0.2,
        max_tokens=256,
    )
    reply = response.choices[0].message.content
    if not reply:
        reply = ""
    return parse_llm_response(reply.strip())


def _extract_json_object(text: str) -> Optional[Dict]:
    """Return the first JSON object embedded in *text*, or None if there is none."""
    decoder = json.JSONDecoder()
    pos = text.find("{")
    while pos != -1:
        try:
            # raw_decode stops at the end of the JSON value, so prose (even
            # prose containing braces) after the object does not break parsing.
            candidate, _ = decoder.raw_decode(text, pos)
        except (ValueError, RecursionError):  # JSONDecodeError is a ValueError
            pass
        else:
            if isinstance(candidate, dict):
                return candidate
        pos = text.find("{", pos + 1)
    return None


def _coerce_confidence(raw: Any, default: float = 0.5) -> float:
    """Clamp *raw* to [0.0, 1.0]; fall back to *default* when it is not a usable number."""
    if isinstance(raw, str):
        try:
            raw = float(raw.strip())
        except ValueError:
            return default
    if not isinstance(raw, (int, float)):
        return default
    if raw != raw:  # NaN is the only value unequal to itself
        return default
    return float(max(0.0, min(1.0, raw)))


def _filter_labels(raw: Any) -> List[str]:
    """Return the entries of *raw* that are valid labels, lower-cased and stripped."""
    if isinstance(raw, str):
        # A bare string is treated as a single label, not as characters.
        raw = [raw]
    if not isinstance(raw, (list, dict)):
        return []
    kept = []
    for item in raw:
        if isinstance(item, str):
            name = item.strip().lower()
            if name in VALID_LABELS:
                kept.append(name)
    return kept


def parse_llm_response(text: str) -> Dict:
    """Turn raw model output into a validated action dict.

    The first JSON object found in *text* is used. Each field is validated
    independently, so one malformed optional field (for example a confidence
    given as a string, or ``"labels": null``) no longer discards an otherwise
    valid decision.

    Args:
        text: Raw model reply; may contain prose around the JSON object.

    Returns:
        A dict with ``decision`` (one of ``VALID_DECISIONS``, defaulting to
        ``"escalate"``), ``reason`` (at most 200 characters), ``confidence``
        (float in [0.0, 1.0], defaulting to 0.5) and ``labels`` (entries of
        ``VALID_LABELS`` only). When no JSON object can be found, the result
        of ``_default_action()`` is returned instead. Never raises for
        malformed input.
    """
    parsed = _extract_json_object(text) if isinstance(text, str) else None
    if parsed is None:
        return _default_action()

    decision = parsed.get("decision", "escalate")
    # Normalize casing/whitespace so "Reject" or " approve " are accepted;
    # non-string values cannot be valid decisions.
    decision = decision.strip().lower() if isinstance(decision, str) else "escalate"
    if decision not in VALID_DECISIONS:
        decision = "escalate"

    reason = parsed.get("reason")
    return {
        "decision": decision,
        # A JSON null reason becomes "" rather than the text "None".
        "reason": "" if reason is None else str(reason)[:200],
        "confidence": _coerce_confidence(parsed.get("confidence", 0.5)),
        "labels": _filter_labels(parsed.get("labels")),
    }


def get_decision(prompt: str) -> Dict:
    """Obtain a moderation action for *prompt* from the configured backend.

    The local model is used when ``USE_LOCAL`` is true, otherwise the API
    model. Any exception from the backend is printed as a ``[DEBUG]`` line and
    replaced by the default action, so this function does not raise.
    """
    try:
        backend = call_local_model if USE_LOCAL else call_api_model
        return backend(prompt)
    except Exception as exc:
        print(f"[DEBUG] Model error: {exc}", flush=True)
    return _default_action()


def server_reset(task: str) -> Optional[Dict]:
    """POST ``{"task": task}`` to the server's ``/reset`` endpoint.

    Returns:
        The decoded JSON response, or None when the request, the HTTP status
        check or JSON decoding fails (the error is printed as ``[DEBUG]``).
    """
    try:
        response = requests.post(
            f"{SERVER_URL}/reset",
            json={"task": task},
            timeout=30,
        )
        response.raise_for_status()
        payload = response.json()
    except Exception as exc:
        print(f"[DEBUG] reset error: {exc}", flush=True)
        return None
    else:
        return payload


def server_step(action: Dict) -> Optional[Dict]:
    """POST *action* as JSON to the server's ``/step`` endpoint.

    Returns:
        The decoded JSON response, or None when the request, the HTTP status
        check or JSON decoding fails (the error is printed as ``[DEBUG]``).
    """
    try:
        response = requests.post(
            f"{SERVER_URL}/step",
            json=action,
            timeout=30,
        )
        response.raise_for_status()
        payload = response.json()
    except Exception as exc:
        print(f"[DEBUG] step error: {exc}", flush=True)
        return None
    else:
        return payload


def server_close():
    """Best-effort POST to the server's ``/close`` endpoint.

    Never raises: this runs during episode cleanup, where a failure to close
    must not mask the episode result. A failure is reported as a ``[DEBUG]``
    line, in the same way ``server_reset`` and ``server_step`` report theirs,
    instead of being dropped silently.
    """
    try:
        requests.post(f"{SERVER_URL}/close", timeout=10)
    except Exception as e:
        print(f"[DEBUG] close error: {e}", flush=True)


def run_episode(task: str):
    """Run one moderation episode against the environment server and log it.

    Resets the server for *task*, then for up to ``MAX_STEPS`` steps builds a
    prompt from the current observation, obtains a decision and submits it.
    The score is the sum of rewards divided by the task's total step count
    (taken from the last observation, else the number of rewards collected),
    clamped to [0.0, 1.0]; the episode succeeds when the score reaches
    ``SUCCESS_SCORE_THRESHOLD``.

    Exactly one ``[START]`` and one ``[END]`` line are printed per call, and
    the server is always asked to close, even when the episode fails.

    Args:
        task: Task identifier sent to the server's reset endpoint.
    """
    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False

    log_start(task=task, env=BENCHMARK, model=MODEL_NAME)

    try:
        reset_result = server_reset(task)
        if reset_result is None:
            # The finally block prints the single [END] line using the zeroed
            # defaults above; logging here as well would print [END] twice.
            return

        obs = reset_result.get("observation", {})
        done = False

        for step in range(1, MAX_STEPS + 1):
            if done or obs is None:
                break

            prompt = build_prompt(obs)
            action = get_decision(prompt)
            # The free-text reason is left out of the log line.
            action_str = json.dumps({k: v for k, v in action.items() if k != "reason"})

            result = server_step(action)
            if result is None:
                log_step(step, action_str, 0.0, True, "server_error")
                break

            reward = float(result.get("reward", 0.0))
            done = bool(result.get("done", False))
            # "info" may be absent, null or not a dict; none of those is an error.
            info = result.get("info")
            error = info.get("error") if isinstance(info, dict) else None

            rewards.append(reward)
            steps_taken = step

            log_step(step, action_str, reward, done, error)

            obs = result.get("observation")

        # Normalize by the task length when the last observation reports one
        # (a null value counts as "not reported").
        total_steps_in_task = obs.get("total_steps") if obs else None
        if total_steps_in_task is None:
            total_steps_in_task = len(rewards)
        max_possible = float(total_steps_in_task)
        score = sum(rewards) / max_possible if max_possible > 0 else 0.0
        score = min(max(score, 0.0), 1.0)
        success = score >= SUCCESS_SCORE_THRESHOLD

    finally:
        server_close()
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)


# Script entry point: run a single episode for the task named by TASK_NAME.
if __name__ == "__main__":
    run_episode(TASK_NAME)