# ethicsguard / inference.py
# GodreignElgin
# test
# 37b04ce
from __future__ import annotations

import asyncio
import json
import os
import re
import sys
from typing import Any

from openai import OpenAI

from ethicsguard.env import EthicsGuardEnv
from ethicsguard.generator import SEED_REGISTRY
from ethicsguard.models import EthicsGuardAction, EthicsGuardObservation
# OpenAI-compatible endpoint and model name; both can be overridden via environment variables.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "google/gemma-4-31B-it")
# API key is read from HF_TOKEN. When it is unset or empty, run_task creates no
# client and every action comes from the local heuristic policy.
API_KEY = os.getenv("HF_TOKEN")
# Benchmark name reported in the [START] log line.
BENCHMARK = "ethicsguard"
# Difficulty levels run by main(); each is passed to EthicsGuardEnv and used as
# a key into SEED_REGISTRY.
TASKS = ["easy", "medium", "hard"]
# Hard cap on env.step calls per episode, applied in addition to the env's done flag.
MAX_STEPS = 15
# Minimum final score for an episode to be reported as success=true in [END].
SUCCESS_SCORE_THRESHOLD = 0.1
# System message sent with every chat completion; asks the model to reply with a
# single JSON object holding item_id and action_type.
SYSTEM_PROMPT = """
You are operating a queue-triage moderation environment.
Choose exactly one item_id from the queue and one action_type from:
approve, flag_remove, escalate, skip.
Prioritize higher-tier items early, use the text when hints are missing, and do not over-escalate.
Respond as compact JSON with exactly these keys:
{"item_id":"...", "action_type":"..."}
""".strip()
def log_start(task: str, env: str, model: str) -> None:
    """Print the [START] line that opens an episode record on stdout."""
    line = f"[START] task={task} env={env} model={model}"
    print(line, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: str | None) -> None:
    """Print one [STEP] line on stdout.

    ``reward`` is shown with two decimals, ``done`` as lowercase text, and a
    missing or empty ``error`` as the literal ``null``.
    """
    error_value = error or "null"
    line = f"[STEP] step={step} action={action} reward={reward:.2f} done={str(done).lower()} error={error_value}"
    print(line, flush=True)
def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
    """Print the [END] line that closes an episode record on stdout.

    Per-step rewards are joined with commas, each shown with two decimals;
    the score is shown with three decimals.
    """
    rewards_text = ",".join(map("{:.2f}".format, rewards))
    line = f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_text}"
    print(line, flush=True)
def build_user_prompt(observation: EthicsGuardObservation) -> str:
    """Serialize the observation into the JSON user message for the model.

    Each queue item is reduced to its id, source, metadata, risk_score_hint and
    text_snippet; the step counters and policy summary are included alongside.
    The result is ASCII-only JSON (non-ASCII characters are escaped).
    """
    item_fields = ("id", "source", "metadata", "risk_score_hint", "text_snippet")
    queue = []
    for entry in observation.remaining_queue:
        queue.append({field: getattr(entry, field) for field in item_fields})
    prompt = {
        "step_number": observation.step_number,
        "steps_remaining": observation.steps_remaining,
        "policy_summary": observation.policy_summary,
        "remaining_queue": queue,
    }
    return json.dumps(prompt, ensure_ascii=True)
def _heuristic_action(observation: EthicsGuardObservation) -> EthicsGuardAction:
    """Pick an action without the LLM, using only ``risk_score_hint``.

    Chooses the queue item with the highest hint; missing hints count as 0.5,
    and ties on the hint are broken by the larger id. The chosen hint is then
    mapped to an action: >= 0.85 flag_remove, >= 0.70 escalate,
    >= 0.40 flag_remove, otherwise approve.

    NOTE(review): the 0.40-0.70 band removes while the higher 0.70-0.85 band
    escalates; confirm this non-monotonic mapping is intended.

    Raises:
        ValueError: if ``observation.remaining_queue`` is empty (from ``max``).
    """
    def hint_of(entry: Any) -> float:
        hint = entry.risk_score_hint
        return 0.5 if hint is None else hint

    chosen = max(observation.remaining_queue, key=lambda entry: (hint_of(entry), entry.id))
    level = hint_of(chosen)
    # Thresholds are checked from highest to lowest; the first floor reached wins.
    for floor, verdict in ((0.85, "flag_remove"), (0.70, "escalate"), (0.40, "flag_remove")):
        if level >= floor:
            break
    else:
        verdict = "approve"
    return EthicsGuardAction(item_id=chosen.id, action_type=verdict)
def _parse_action(raw_text: str, observation: EthicsGuardObservation) -> EthicsGuardAction:
    """Turn the model's reply into an action, falling back to the heuristic.

    Parsing is attempted in this order:

    1. the whole reply as a JSON object;
    2. the span from the first ``{`` to the last ``}`` as a JSON object, which
       recovers replies wrapped in Markdown code fences or surrounded by prose;
    3. independent regex searches for the ``"item_id"`` and ``"action_type"``
       string fields, in either order and across line breaks.

    If none of these yields a valid ``EthicsGuardAction``, the result of
    ``_heuristic_action(observation)`` is returned instead of raising on
    malformed or invalid model output.
    """
    candidates = [raw_text]
    start, end = raw_text.find("{"), raw_text.rfind("}")
    if 0 <= start < end:
        candidates.append(raw_text[start : end + 1])

    for candidate in candidates:
        try:
            payload = json.loads(candidate)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            continue
        if not isinstance(payload, dict):
            continue
        try:
            return EthicsGuardAction(**payload)
        except (TypeError, ValueError):  # pydantic's ValidationError subclasses ValueError
            continue

    # Search each field separately so key order and newlines between them don't matter.
    item_match = re.search(r'"item_id"\s*:\s*"([^"]+)"', raw_text)
    action_match = re.search(r'"action_type"\s*:\s*"([^"]+)"', raw_text)
    if item_match and action_match:
        try:
            return EthicsGuardAction(item_id=item_match.group(1), action_type=action_match.group(1))
        except (TypeError, ValueError):
            pass
    return _heuristic_action(observation)
def get_model_action(client: OpenAI | None, observation: EthicsGuardObservation) -> EthicsGuardAction:
    """Ask the LLM for the next action, falling back to the local heuristic.

    Args:
        client: OpenAI-compatible client, or None when no API key is configured.
        observation: current environment observation, serialized into the user prompt.

    Returns:
        The action parsed from the model reply. Returns ``_heuristic_action(observation)``
        when there is no client or when the request or parsing raises.
    """
    if client is None:
        return _heuristic_action(observation)
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": build_user_prompt(observation)},
            ],
            temperature=0.0,
            max_tokens=120,
            stream=False,
        )
        content = completion.choices[0].message.content or ""
        return _parse_action(content.strip(), observation)
    except Exception as exc:  # best effort: network/auth/rate-limit/empty-choices all degrade
        # Report on stderr so the failure is visible without corrupting the
        # [START]/[STEP]/[END] protocol written to stdout.
        print(
            f"[WARN] model request failed, using heuristic: {type(exc).__name__}: {exc}",
            file=sys.stderr,
        )
        return _heuristic_action(observation)
def _episode_action_text(action: EthicsGuardAction) -> str:
return json.dumps(action.model_dump(), separators=(",", ":"), ensure_ascii=True)
async def run_task(task_name: str, seed: int) -> float:
    """Run one episode of ``task_name`` with ``seed`` and log it to stdout.

    Emits one [START] line, one [STEP] line per env.step call (at most
    MAX_STEPS), and one [END] line. The environment is always closed and the
    [END] line is always emitted, even when reset or step raises. In that case
    [END] reports success=false and the exception propagates to the caller.

    Returns:
        The environment's score after the last completed step (0.0 when no
        step completed or the score is missing).
    """
    # No API key -> no client; get_model_action then always uses the heuristic.
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY) if API_KEY else None
    env = EthicsGuardEnv(difficulty=task_name, seed=seed)
    rewards: list[float] = []
    steps = 0
    score = 0.0
    success = False
    log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
    try:
        # reset() runs inside the try so a failing reset still closes env and logs [END].
        result = await env.reset()
        while not result.done and steps < MAX_STEPS:
            steps += 1
            action = get_model_action(client, result.observation)
            result = await env.step(action)
            rewards.append(result.reward)
            log_step(
                step=steps,
                action=_episode_action_text(action),
                reward=result.reward,
                done=result.done,
                error=result.last_action_error,
            )
            score = float(result.score or 0.0)
        success = score >= SUCCESS_SCORE_THRESHOLD
    finally:
        try:
            await env.close()
        finally:
            # Every [START] gets a matching [END], even for an aborted episode.
            log_end(success=success, steps=steps, score=score, rewards=rewards)
    return score
async def main() -> None:
    """Run one logged episode per entry in TASKS.

    Each task uses the first seed of its ``"eval"`` list in SEED_REGISTRY.
    """
    for task_name in TASKS:
        first_eval_seed = SEED_REGISTRY[task_name]["eval"][0]
        await run_task(task_name, first_eval_seed)


if __name__ == "__main__":
    # Script entry point: drive the async main() on a fresh event loop.
    asyncio.run(main())