from __future__ import annotations import asyncio import json import os import re from typing import Any from openai import OpenAI from ethicsguard.env import EthicsGuardEnv from ethicsguard.generator import SEED_REGISTRY from ethicsguard.models import EthicsGuardAction, EthicsGuardObservation API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1") MODEL_NAME = os.getenv("MODEL_NAME", "google/gemma-4-31B-it") API_KEY = os.getenv("HF_TOKEN") BENCHMARK = "ethicsguard" TASKS = ["easy", "medium", "hard"] MAX_STEPS = 15 SUCCESS_SCORE_THRESHOLD = 0.1 SYSTEM_PROMPT = """ You are operating a queue-triage moderation environment. Choose exactly one item_id from the queue and one action_type from: approve, flag_remove, escalate, skip. Prioritize higher-tier items early, use the text when hints are missing, and do not over-escalate. Respond as compact JSON with exactly these keys: {"item_id":"...", "action_type":"..."} """.strip() def log_start(task: str, env: str, model: str) -> None: print(f"[START] task={task} env={env} model={model}", flush=True) def log_step(step: int, action: str, reward: float, done: bool, error: str | None) -> None: error_value = error if error else "null" print( f"[STEP] step={step} action={action} reward={reward:.2f} done={str(done).lower()} error={error_value}", flush=True, ) def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None: rewards_text = ",".join(f"{value:.2f}" for value in rewards) print( f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_text}", flush=True, ) def build_user_prompt(observation: EthicsGuardObservation) -> str: queue_payload = [ { "id": item.id, "source": item.source, "metadata": item.metadata, "risk_score_hint": item.risk_score_hint, "text_snippet": item.text_snippet, } for item in observation.remaining_queue ] return json.dumps( { "step_number": observation.step_number, "steps_remaining": observation.steps_remaining, "policy_summary": observation.policy_summary, "remaining_queue": queue_payload, }, ensure_ascii=True, ) def _heuristic_action(observation: EthicsGuardObservation) -> EthicsGuardAction: def sort_key(item: Any) -> tuple[float, str]: return (item.risk_score_hint if item.risk_score_hint is not None else 0.5, item.id) item = max(observation.remaining_queue, key=sort_key) hint = item.risk_score_hint if item.risk_score_hint is not None else 0.5 action_type = "approve" if hint >= 0.85: action_type = "flag_remove" elif hint >= 0.70: action_type = "escalate" elif hint >= 0.40: action_type = "flag_remove" return EthicsGuardAction(item_id=item.id, action_type=action_type) def _parse_action(raw_text: str, observation: EthicsGuardObservation) -> EthicsGuardAction: try: return EthicsGuardAction(**json.loads(raw_text)) except Exception: match = re.search(r'"item_id"\s*:\s*"([^"]+)".*"action_type"\s*:\s*"([^"]+)"', raw_text) if match: return EthicsGuardAction(item_id=match.group(1), action_type=match.group(2)) return _heuristic_action(observation) def get_model_action(client: OpenAI | None, observation: EthicsGuardObservation) -> EthicsGuardAction: if client is None: return _heuristic_action(observation) try: completion = client.chat.completions.create( model=MODEL_NAME, messages=[ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": build_user_prompt(observation)}, ], temperature=0.0, max_tokens=120, stream=False, ) return _parse_action((completion.choices[0].message.content or "").strip(), observation) except Exception: return _heuristic_action(observation) def _episode_action_text(action: EthicsGuardAction) -> str: return json.dumps(action.model_dump(), separators=(",", ":"), ensure_ascii=True) async def run_task(task_name: str, seed: int) -> float: client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY) if API_KEY else None env = EthicsGuardEnv(difficulty=task_name, seed=seed) result = await env.reset() rewards: list[float] = [] steps = 0 score = 0.0 log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME) try: while not result.done and steps < MAX_STEPS: steps += 1 action = get_model_action(client, result.observation) result = await env.step(action) rewards.append(result.reward) log_step( step=steps, action=_episode_action_text(action), reward=result.reward, done=result.done, error=result.last_action_error, ) score = float(result.score or 0.0) finally: await env.close() log_end(success=score >= SUCCESS_SCORE_THRESHOLD, steps=steps, score=score, rewards=rewards) return score async def main() -> None: for task_name in TASKS: eval_seeds = SEED_REGISTRY[task_name]["eval"] seed = eval_seeds[0] await run_task(task_name, seed) if __name__ == "__main__": asyncio.run(main())