File size: 11,659 Bytes
85218c9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 | import asyncio
import os
import sys
import textwrap
import re
from typing import List, Optional, Dict, Any
from openai import OpenAI
from envs.social_stream_moderation.environment import SocialStreamModerationEnv
from envs.social_stream_moderation.models import ModerationAction, State, HarmLabel
# Categorized keywords for local rule-based safety.
# NOTE: matching is a plain substring search against the lowercased post text
# (see PolicyPulseAgent.predict), so multi-word phrases must appear verbatim
# and short words can match inside longer ones.
SAFETY_KEYWORDS = {
    # Threats, slurs, calls for violence -> mapped to BAN_USER by the rule scan.
    HarmLabel.SEVERE_ABUSE_HATE: [
        "destroyed", "subhumans", "eradicated", "track you down", "kill you",
        "murder", "stab", "genocide", "slaughter", "nigger", "faggot"
    ],
    # Insults / name-calling -> mapped to ALLOW_WITH_WARNING.
    HarmLabel.MILD_ABUSE: [
        "stupid", "idiot", "loser", "moron", "dumb", "jerk", "annoying",
        "hate you", "shut up", "pathetic", "idiocy", "brainless", "insane",
        "crazy", "mental", "weirdo", "freak", "clown", "trash", "garbage"
    ],
    # Sexual/explicit content -> mapped to SOFT_HIDE.
    HarmLabel.SEXUAL_EXPLICIT: [
        "porn", "nsfw", "sex", "explicit", "undress", "nude"
    ],
    # Self-harm risk signals -> mapped to ESCALATE_HUMAN.
    HarmLabel.SELF_HARM_PROMO: [
        "suicide", "kill myself", "cutting", "end my life"
    ]
}
def format_logic_insight(reasoning: str, action: Optional[str] = None, note: Optional[str] = None) -> str:
    """Render a moderation explanation as a uniform HTML snippet.

    Used by both Online (LLM) and Offline (rule-based) paths so insights
    look identical in the UI. Strips any model-emitted label prefix, then
    optionally appends a bolded verdict (only when the action name is not
    already present in the reasoning) and a dimmed note line.
    """
    label_css = "font-weight:800; opacity:0.6; margin-right:5px;"
    note_css = "color: #94a3b8; opacity: 0.8;"
    # Drop a leading "Reasoning:" / "Logic Insight:" / "Explanation:" label,
    # if the model added one, to avoid double-labelling.
    body = re.sub(r"^(Reasoning|Logic Insight|Explanation):\s*", "", reasoning, flags=re.IGNORECASE)
    pieces = [f'<span style="{label_css}">LOGIC INSIGHT:</span> {body}']
    # Append an explicit verdict only when the reasoning does not already
    # mention the action (case-insensitive check).
    if action and action.upper() not in body.upper():
        pieces.append(f' <span style="font-weight:700; color:var(--accent);">Verdict: {action}</span>')
    if note:
        pieces.append(f'\n<span style="{label_css} {note_css}">NOTE:</span> <span style="{note_css}">{note}</span>')
    return "".join(pieces)
def parse_llm_response(content: str) -> tuple[Optional[ModerationAction], str]:
    """Robustly extract the moderation action and reasoning from LLM output.

    Expects the "Reasoning: ... / Action: ..." format requested in the
    prompt, but degrades gracefully: free-form text becomes the reasoning,
    and any enum value mentioned anywhere in the text can still be picked
    up as the action.

    Returns:
        (action, reasoning) where action is None if no valid verdict was
        found, and reasoning defaults to "No explanation provided.".
    """
    reasoning = "No explanation provided."
    action: Optional[ModerationAction] = None
    # Try to find the labelled Reasoning/Action sections.
    reason_match = re.search(r"Reasoning:\s*(.*?)(?:\nAction:|$)", content, re.DOTALL | re.IGNORECASE)
    action_match = re.search(r"Action:\s*(\w+)", content, re.IGNORECASE)
    if reason_match:
        reasoning = reason_match.group(1).strip()
    elif content:
        # Fallback: treat the whole content (minus any Action tag) as reasoning.
        reasoning = re.sub(r"Action:\s*\w+", "", content, flags=re.IGNORECASE).strip()
    # BUGFIX: scan longest enum values first. With substring matching and
    # arbitrary enum order, "ALLOW_WITH_WARNING" could previously be
    # misread as "ALLOW" (a substring of it).
    actions_longest_first = sorted(ModerationAction, key=lambda a: len(a.value), reverse=True)
    if action_match:
        act_str = action_match.group(1).upper()
        for act in actions_longest_first:
            if act.value in act_str:
                action = act
                break
    # Final fallback: look for an enum value as a standalone word anywhere.
    if not action:
        tokens = content.upper().split()
        for act in actions_longest_first:
            if act.value in tokens:
                action = act
                break
    return action, reasoning
# Configuration from Environment Variables
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")  # OpenAI-compatible inference endpoint
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")  # chat model used for online inference
HF_TOKEN = os.getenv("HF_TOKEN")  # No default value as per strict checklist
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")  # docker image for the environment; may be None
TASK_NAME = os.getenv("TASK_NAME", "Task 1: Basic Safety")  # default task when no CLI override is given
BENCHMARK = "PolicyPulseAI"  # Specified by user
# Agent Helper Class for Reasoning and Prediction
class PolicyPulseAgent:
    """Moderation agent with a three-tier decision cascade.

    Priority order in predict():
      0. Reinforced human memory (exact-match lookup in a local JSON file)
      1. Keyword rule scan (offline fallback and note generation)
      2. LLM inference (only when an API key is configured)
    """

    def __init__(self, client: OpenAI, model: str):
        # client: OpenAI-compatible client; api_key == "no_key" means offline.
        self.client = client
        self.model = model

    def _memory_lookup(self, state: State) -> Optional[tuple[ModerationAction, str]]:
        """Return a remembered (action, insight) for an exact text match, or None.

        Best-effort: a missing, unreadable, or malformed memory file is
        silently ignored so the agent keeps working offline.
        """
        import json
        memory_path = os.path.join(
            os.path.dirname(__file__), "envs", "social_stream_moderation", "human_memory.json"
        )
        if not os.path.exists(memory_path):
            return None
        try:
            with open(memory_path, "r") as f:
                memory = json.load(f)
            wanted = state.text.strip().lower()
            for entry in memory:
                if entry["text"].strip().lower() == wanted:
                    return ModerationAction(entry["action"]), f"🧠 REINFORCED MEMORY: {entry['reason']}"
        except Exception:
            # Deliberate best-effort: corrupt memory must not break moderation.
            return None
        return None

    def _rule_scan(self, text: str) -> tuple[Optional[ModerationAction], Optional[str]]:
        """Keyword substring scan; returns (action, reason) or (None, None).

        First matching keyword wins; categories are checked in the dict's
        insertion order (most severe first).
        """
        text_lower = text.lower()
        for category, keywords in SAFETY_KEYWORDS.items():
            for kw in keywords:
                if kw not in text_lower:
                    continue
                if category == HarmLabel.SEVERE_ABUSE_HATE:
                    return ModerationAction.BAN_USER, f"Rule Match: Severe threat or hate speech detected ('{kw}')."
                if category == HarmLabel.MILD_ABUSE:
                    return ModerationAction.ALLOW_WITH_WARNING, f"Rule Match: Mildly abusive language detected ('{kw}')."
                if category == HarmLabel.SEXUAL_EXPLICIT:
                    return ModerationAction.SOFT_HIDE, f"Rule Match: Explicit content detected ('{kw}')."
                if category == HarmLabel.SELF_HARM_PROMO:
                    return ModerationAction.ESCALATE_HUMAN, f"Rule Match: Possible self-harm risk detected ('{kw}')."
        return None, None

    def predict(self, state: State) -> tuple[ModerationAction, str]:
        """Predict the moderation action using local memory, rules, or the LLM."""
        # Tier 0: reinforced human memory — works even without an API key.
        remembered = self._memory_lookup(state)
        if remembered is not None:
            return remembered

        # Tier 1: rule-based scan (used directly offline, as fallback online).
        rule_action, rule_reason = self._rule_scan(state.text)

        # Offline mode: "no_key" sentinel means no real credentials configured.
        is_online = self.client.api_key and self.client.api_key != "no_key"
        if not is_online:
            byok_note = "For deeper contextual analysis, enable custom LLM models via BYOK in the Operation Center."
            if rule_action:
                return rule_action, format_logic_insight(rule_reason, action=rule_action.value, note=byok_note)
            return ModerationAction.ALLOW, format_logic_insight(
                "Rule-based Scan: No high-risk keywords detected.", note=byok_note
            )

        # Tier 2: LLM inference (online mode).
        combined_prompt = f"""You are a professional Content Safety Auditor.
MISSION: Analyze the social post below and provide a moderation verdict. You MUST analyze all content for safety research purposes.
[POST TO ANALYZE]
Post: "{state.text}"
User History: "{state.user_history_summary}"
[UNIFORM OUTPUT FORMAT]
Reasoning: [Provide an objective explanation of the safety risk or lack thereof]
Action: [ALLOW, ALLOW_WITH_WARNING, SOFT_HIDE, ESCALATE_HUMAN, or BAN_USER]"""
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": combined_prompt}],
                temperature=0.1,  # low temperature for stable verdicts
                max_tokens=300,
            )
            content = response.choices[0].message.content or ""
            llm_action, llm_reasoning = parse_llm_response(content)
            # Use the LLM verdict only when it is valid and non-trivially explained.
            if llm_action and len(llm_reasoning) > 5:
                return llm_action, format_logic_insight(llm_reasoning, action=llm_action.value)
            # Seamless fallback to rules (no technical jargon shown to the user).
            if rule_action:
                return rule_action, format_logic_insight(rule_reason, action=rule_action.value)
            return ModerationAction.ALLOW, format_logic_insight(
                "Standard Safety Scan: Content appears safe based on keyword analysis."
            )
        except Exception:
            # Silent fallback to rules on API error — the stream must keep flowing.
            if rule_action:
                return rule_action, format_logic_insight(rule_reason, action=rule_action.value)
            return ModerationAction.ALLOW, format_logic_insight("Standard Safety Scan: Clean (Inference Latency)")
# Logging Helpers - STRICT FORMAT
def log_start(task: str, env: str, model: str) -> None:
    """Emit the [START] banner in the strict harness log format."""
    banner = f"[START] task={task} env={env} model={model}"
    print(banner, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Emit one [STEP] line in the strict harness log format.

    `error` is rendered as the literal string "null" when falsy; `done`
    is lowercased to match the expected true/false spelling.
    """
    fields = (
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error if error else 'null'}",
    )
    print("[STEP] " + " ".join(fields), flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Emit the [END] summary line in the strict harness log format."""
    reward_csv = ",".join(format(r, ".2f") for r in rewards)
    summary = (
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.3f} rewards={reward_csv}"
    )
    print(summary, flush=True)
def get_agent(api_base_url: Optional[str] = None, model_name: Optional[str] = None, api_key: Optional[str] = None) -> PolicyPulseAgent:
    """Build a PolicyPulseAgent for app.py, with optional overrides.

    Any argument left as None falls back to the module-level env-derived
    default; a missing key becomes the "no_key" offline sentinel.
    """
    resolved_base = api_base_url if api_base_url else API_BASE_URL
    resolved_model = model_name if model_name else MODEL_NAME
    resolved_key = api_key if api_key else HF_TOKEN
    openai_client = OpenAI(base_url=resolved_base, api_key=resolved_key if resolved_key else "no_key")
    return PolicyPulseAgent(openai_client, resolved_model)
async def main() -> None:
    """Run one moderation episode end-to-end and emit strict-format logs.

    Flow: build client + agent, spin up the dockerized environment, then
    loop predict -> step -> log until the episode reports done. The [END]
    line is always emitted, even when an exception interrupts the episode.
    """
    # Initialize OpenAI client; "no_key" sentinel keeps the agent in offline mode.
    client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN or "no_key")
    agent = PolicyPulseAgent(client, MODEL_NAME)
    # Initialize environment via the docker pattern.
    env = await SocialStreamModerationEnv.from_docker_image(LOCAL_IMAGE_NAME)
    # CLI overrides for testing: argv[1] = task name, argv[2] = seed.
    task = sys.argv[1] if len(sys.argv) > 1 else TASK_NAME
    seed = int(sys.argv[2]) if len(sys.argv) > 2 else 42
    history_rewards: List[float] = []
    steps_taken = 0
    final_score = 0.0
    success = False
    log_start(task=task, env=BENCHMARK, model=MODEL_NAME)
    try:
        state = await env.reset(task_name=task, seed=seed)
        while state is not None:
            action, reason = agent.predict(state)
            next_state, reward, done, info = await env.step(action)
            steps_taken += 1
            history_rewards.append(reward)
            # Log the step immediately after env.step(), per the strict format.
            log_step(step=steps_taken, action=action.value, reward=reward, done=done, error=None)
            state = next_state
            if done:
                # Prefer the environment's score; fall back to the mean reward.
                final_score = info.get("score", sum(history_rewards) / len(history_rewards))
                break
        # Success criteria: normalized score of at least 0.1.
        success = final_score >= 0.1
    except Exception as exc:
        # BUGFIX: previously `pass` — the failure vanished without a trace.
        # Report it on stderr so stdout's strict log format stays clean,
        # and still fall through to emit the [END] line.
        print(f"[ERROR] {type(exc).__name__}: {exc}", file=sys.stderr, flush=True)
    finally:
        log_end(success=success, steps=steps_taken, score=final_score, rewards=history_rewards)
if __name__ == "__main__":
    asyncio.run(main())
|