Phase 2 complete: Fixed inference loop and added phase gates
Browse files- apps/audit_api.py +23 -0
- apps/crm_api.py +20 -0
- apps/regulatory_api.py +30 -0
- apps/start_all.bat +10 -0
- grpo_train.py +56 -0
- inference.py +54 -26
- server/__pycache__/__init__.cpython-313.pyc +0 -0
- server/__pycache__/app.cpython-313.pyc +0 -0
- src/__pycache__/environment.cpython-313.pyc +0 -0
- src/__pycache__/generator.cpython-313.pyc +0 -0
- src/environment.py +88 -42
- src/meta_ad_policy_sandbox.egg-info/PKG-INFO +10 -0
- src/meta_ad_policy_sandbox.egg-info/SOURCES.txt +12 -0
- src/meta_ad_policy_sandbox.egg-info/dependency_links.txt +1 -0
- src/meta_ad_policy_sandbox.egg-info/entry_points.txt +2 -0
- src/meta_ad_policy_sandbox.egg-info/requires.txt +6 -0
- src/meta_ad_policy_sandbox.egg-info/top_level.txt +4 -0
- uv.lock +0 -0
apps/audit_api.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
import uvicorn
|
| 4 |
+
|
| 5 |
+
app = FastAPI(title="Compliance Audit API")
|
| 6 |
+
logs = []
|
| 7 |
+
|
| 8 |
+
class AuditRecord(BaseModel):
|
| 9 |
+
ad_id: str
|
| 10 |
+
action_taken: str
|
| 11 |
+
reasoning: str
|
| 12 |
+
|
| 13 |
+
@app.post("/log")
|
| 14 |
+
def log_audit(record: AuditRecord):
|
| 15 |
+
logs.append(record.dict())
|
| 16 |
+
return {"status": "success", "audit_id": f"AUD-{len(logs)}"}
|
| 17 |
+
|
| 18 |
+
@app.get("/health")
|
| 19 |
+
def health():
|
| 20 |
+
return {"status": "ok", "service": "compliance-audit"}
|
| 21 |
+
|
| 22 |
+
if __name__ == "__main__":
|
| 23 |
+
uvicorn.run(app, host="0.0.0.0", port=8003)
|
apps/crm_api.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI
|
| 2 |
+
import uvicorn
|
| 3 |
+
|
| 4 |
+
app = FastAPI(title="Advertiser CRM API")
|
| 5 |
+
|
| 6 |
+
ADVERTISERS = {
|
| 7 |
+
"adv_001": {"prior_violations": 3, "summary": "High-risk repeat offender."},
|
| 8 |
+
"adv_002": {"prior_violations": 0, "summary": "Established clean record."}
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
@app.get("/advertiser/{advertiser_id}")
|
| 12 |
+
def get_advertiser(advertiser_id: str):
|
| 13 |
+
return ADVERTISERS.get(advertiser_id, {"prior_violations": 0, "summary": "New advertiser."})
|
| 14 |
+
|
| 15 |
+
@app.get("/health")
|
| 16 |
+
def health():
|
| 17 |
+
return {"status": "ok", "service": "advertiser-crm"}
|
| 18 |
+
|
| 19 |
+
if __name__ == "__main__":
|
| 20 |
+
uvicorn.run(app, host="0.0.0.0", port=8002)
|
apps/regulatory_api.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI
|
| 2 |
+
import uvicorn
|
| 3 |
+
|
| 4 |
+
app = FastAPI(title="Regulatory DB API")
|
| 5 |
+
|
| 6 |
+
REGULATIONS = {
|
| 7 |
+
"healthcare": {
|
| 8 |
+
"policy_summary": "Claims require FDA approval. No 'guaranteed results' allowed.",
|
| 9 |
+
"risk_level": "high"
|
| 10 |
+
},
|
| 11 |
+
"financial": {
|
| 12 |
+
"policy_summary": "Requires SEC registration. Prohibited: predatory APR > 36%.",
|
| 13 |
+
"risk_level": "high"
|
| 14 |
+
},
|
| 15 |
+
"general": {
|
| 16 |
+
"policy_summary": "Standard standards apply. No deceptive claims.",
|
| 17 |
+
"risk_level": "low"
|
| 18 |
+
}
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
@app.get("/regulations/{category}")
|
| 22 |
+
def get_regulations(category: str):
|
| 23 |
+
return REGULATIONS.get(category.lower(), REGULATIONS["general"])
|
| 24 |
+
|
| 25 |
+
@app.get("/health")
|
| 26 |
+
def health():
|
| 27 |
+
return {"status": "ok", "service": "regulatory-db"}
|
| 28 |
+
|
| 29 |
+
if __name__ == "__main__":
|
| 30 |
+
uvicorn.run(app, host="0.0.0.0", port=8001)
|
apps/start_all.bat
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@echo off
|
| 2 |
+
echo Launching Enterprise Ecosystem...
|
| 3 |
+
|
| 4 |
+
:: This line forces Windows to go to the project root before running anything
|
| 5 |
+
cd /d "%~dp0\.."
|
| 6 |
+
|
| 7 |
+
start "Regulatory API" cmd /k "uv run python apps\regulatory_api.py"
|
| 8 |
+
start "CRM API" cmd /k "uv run python apps\crm_api.py"
|
| 9 |
+
start "Audit API" cmd /k "uv run python apps\audit_api.py"
|
| 10 |
+
start "Environment Server" cmd /k "uv run uvicorn server.app:app --host 0.0.0.0 --port 8000"
|
grpo_train.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from unsloth import FastLanguageModel, PatchFastRL
|
| 3 |
+
from trl import GRPOTrainer, GRPOConfig
|
| 4 |
+
from src.environment import AdPolicyEnvironment
|
| 5 |
+
|
| 6 |
+
# 1. Load Model with Unsloth
|
| 7 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 8 |
+
model_name = "unsloth/Llama-3.1-8B-Instruct",
|
| 9 |
+
max_seq_length = 1024,
|
| 10 |
+
load_in_4bit = True,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
# 2. Define Reward Functions
|
| 14 |
+
def reward_compliance(prompts, completions, **kwargs):
|
| 15 |
+
rewards = []
|
| 16 |
+
for completion in completions:
|
| 17 |
+
# Check if the model called the necessary tools in order
|
| 18 |
+
if "query_regulations" in completion and "submit_audit" in completion:
|
| 19 |
+
rewards.append(2.0)
|
| 20 |
+
else:
|
| 21 |
+
rewards.append(0.0)
|
| 22 |
+
return rewards
|
| 23 |
+
|
| 24 |
+
def reward_json_format(prompts, completions, **kwargs):
|
| 25 |
+
rewards = []
|
| 26 |
+
for completion in completions:
|
| 27 |
+
try:
|
| 28 |
+
import json
|
| 29 |
+
json.loads(completion)
|
| 30 |
+
rewards.append(1.0)
|
| 31 |
+
except:
|
| 32 |
+
rewards.append(0.0)
|
| 33 |
+
return rewards
|
| 34 |
+
|
| 35 |
+
# 3. Configure Trainer
|
| 36 |
+
training_args = GRPOConfig(
|
| 37 |
+
output_dir = "outputs/meta-ad-agent",
|
| 38 |
+
learning_rate = 5e-6,
|
| 39 |
+
num_train_epochs = 1,
|
| 40 |
+
per_device_train_batch_size = 4,
|
| 41 |
+
gradient_accumulation_steps = 4,
|
| 42 |
+
max_prompt_length = 512,
|
| 43 |
+
max_completion_length = 512,
|
| 44 |
+
num_generations = 8, # Number of variations to compare
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
trainer = GRPOTrainer(
|
| 48 |
+
model = model,
|
| 49 |
+
reward_funcs = [reward_compliance, reward_json_format],
|
| 50 |
+
args = training_args,
|
| 51 |
+
train_dataset = [], # We will stream data from your AdGenerator here
|
| 52 |
+
tokenizer = tokenizer,
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
# 4. Start Training
|
| 56 |
+
# trainer.train()
|
inference.py
CHANGED
|
@@ -6,7 +6,7 @@ from openai import OpenAI
|
|
| 6 |
# 1. MANDATORY VARIABLES EXACTLY AS REQUESTED BY SCALAR
|
| 7 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 8 |
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY", "dummy_local_token")
|
| 9 |
-
MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3-
|
| 10 |
|
| 11 |
ENV_URL = "http://localhost:8000"
|
| 12 |
MAX_STEPS = 10
|
|
@@ -39,17 +39,28 @@ def log_end(success: bool, steps: int, score: float, rewards: list) -> None:
|
|
| 39 |
|
| 40 |
def get_llm_action(observation_data):
|
| 41 |
"""Asks the LLM what action to take based on the ad observation."""
|
| 42 |
-
system_prompt = """You are an
|
| 43 |
-
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
AVAILABLE ACTIONS:
|
|
|
|
| 46 |
- analyze_image
|
|
|
|
| 47 |
- request_landing_page
|
| 48 |
- request_id_verification
|
|
|
|
| 49 |
- approve
|
| 50 |
- reject
|
| 51 |
-
|
| 52 |
-
|
|
|
|
| 53 |
"""
|
| 54 |
|
| 55 |
user_prompt = f"Current Ad Observation:\n{json.dumps(observation_data, indent=2)}\n\nWhat is your next action?"
|
|
@@ -61,17 +72,25 @@ def get_llm_action(observation_data):
|
|
| 61 |
{"role": "system", "content": system_prompt},
|
| 62 |
{"role": "user", "content": user_prompt}
|
| 63 |
],
|
| 64 |
-
response_format={"type": "json_object"}
|
| 65 |
temperature=0.1
|
| 66 |
)
|
| 67 |
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
return {
|
| 70 |
-
"action_type": result.get("action_type", "
|
| 71 |
"reasoning": result.get("reasoning", "Fallback reasoning")
|
| 72 |
}
|
| 73 |
except Exception as e:
|
| 74 |
-
|
|
|
|
| 75 |
|
| 76 |
def main() -> None:
|
| 77 |
for task_id in TASKS:
|
|
@@ -82,29 +101,42 @@ def main() -> None:
|
|
| 82 |
success = False
|
| 83 |
|
| 84 |
try:
|
|
|
|
| 85 |
res = requests.post(f"{ENV_URL}/reset", json={"task_id": task_id})
|
| 86 |
if res.status_code != 200:
|
| 87 |
log_step(step=1, action="reset_failed", reward=0.0, done=True, error=f"HTTP {res.status_code}")
|
| 88 |
-
# Forced score to 0.01 instead of 0.0
|
| 89 |
log_end(success=False, steps=0, score=0.01, rewards=[])
|
| 90 |
continue
|
| 91 |
|
| 92 |
-
data
|
| 93 |
-
|
|
|
|
| 94 |
done = False
|
| 95 |
|
|
|
|
| 96 |
while not done and steps_taken < MAX_STEPS:
|
| 97 |
steps_taken += 1
|
| 98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
# Get action from LLM
|
| 100 |
-
action_payload = get_llm_action(
|
| 101 |
action_str = action_payload["action_type"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
-
#
|
| 104 |
-
step_res = requests.post(f"{ENV_URL}/step", json=action_payload)
|
| 105 |
-
step_data = step_res.json()
|
| 106 |
-
|
| 107 |
-
# Parse response perfectly
|
| 108 |
observation = step_data.get("observation", {})
|
| 109 |
done = step_data.get("done", False)
|
| 110 |
reward = step_data.get("reward", 0.0)
|
|
@@ -112,17 +144,13 @@ def main() -> None:
|
|
| 112 |
rewards.append(reward)
|
| 113 |
log_step(step=steps_taken, action=action_str, reward=reward, done=done, error=None)
|
| 114 |
|
| 115 |
-
#
|
| 116 |
-
# Calculate final score and forcefully clamp it strictly between 0.01 and 0.99
|
| 117 |
raw_score = sum(rewards)
|
| 118 |
-
|
| 119 |
-
success = score
|
| 120 |
-
|
| 121 |
-
log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
|
| 122 |
|
| 123 |
except Exception as e:
|
| 124 |
log_step(step=steps_taken+1, action="exception", reward=0.0, done=True, error=str(e).replace("\n", " "))
|
| 125 |
-
# Forced score to 0.01 instead of 0.0
|
| 126 |
log_end(success=False, steps=steps_taken, score=0.01, rewards=rewards)
|
| 127 |
|
| 128 |
if __name__ == "__main__":
|
|
|
|
| 6 |
# 1. MANDATORY VARIABLES EXACTLY AS REQUESTED BY SCALAR
|
| 7 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 8 |
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY", "dummy_local_token")
|
| 9 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Meta-Llama-3-8B-Instruct")
|
| 10 |
|
| 11 |
ENV_URL = "http://localhost:8000"
|
| 12 |
MAX_STEPS = 10
|
|
|
|
| 39 |
|
| 40 |
def get_llm_action(observation_data):
|
| 41 |
"""Asks the LLM what action to take based on the ad observation."""
|
| 42 |
+
system_prompt = """You are an enterprise Ad Policy Compliance Agent.
|
| 43 |
+
You navigate a multi-system compliance workflow. Always respond with ONLY valid JSON.
|
| 44 |
+
|
| 45 |
+
REQUIRED PHASE ORDER:
|
| 46 |
+
1. query_regulations β always first
|
| 47 |
+
2. analyze_image β required for visual/multimodal tasks
|
| 48 |
+
3. check_advertiser_history or request_landing_page β as needed
|
| 49 |
+
4. submit_audit β always before final decision
|
| 50 |
+
5. approve or reject β final decision only after audit
|
| 51 |
+
|
| 52 |
AVAILABLE ACTIONS:
|
| 53 |
+
- query_regulations
|
| 54 |
- analyze_image
|
| 55 |
+
- check_advertiser_history
|
| 56 |
- request_landing_page
|
| 57 |
- request_id_verification
|
| 58 |
+
- submit_audit
|
| 59 |
- approve
|
| 60 |
- reject
|
| 61 |
+
|
| 62 |
+
Response format:
|
| 63 |
+
{"action_type": "<action>", "reasoning": "<brief reason>"}
|
| 64 |
"""
|
| 65 |
|
| 66 |
user_prompt = f"Current Ad Observation:\n{json.dumps(observation_data, indent=2)}\n\nWhat is your next action?"
|
|
|
|
| 72 |
{"role": "system", "content": system_prompt},
|
| 73 |
{"role": "user", "content": user_prompt}
|
| 74 |
],
|
| 75 |
+
# Removed response_format={"type": "json_object"} as HF router often rejects it
|
| 76 |
temperature=0.1
|
| 77 |
)
|
| 78 |
|
| 79 |
+
# Clean the response in case the LLM wrapped it in markdown code blocks like ```json ... ```
|
| 80 |
+
content = response.choices[0].message.content.strip()
|
| 81 |
+
if content.startswith("```json"):
|
| 82 |
+
content = content[7:-3].strip()
|
| 83 |
+
elif content.startswith("```"):
|
| 84 |
+
content = content[3:-3].strip()
|
| 85 |
+
|
| 86 |
+
result = json.loads(content)
|
| 87 |
return {
|
| 88 |
+
"action_type": result.get("action_type", "query_regulations"),
|
| 89 |
"reasoning": result.get("reasoning", "Fallback reasoning")
|
| 90 |
}
|
| 91 |
except Exception as e:
|
| 92 |
+
print(f"\n[CRITICAL LLM ERROR]: {str(e)}\n", flush=True) # THIS WILL REVEAL THE BUG
|
| 93 |
+
return {"action_type": "query_regulations", "reasoning": f"Error recovery: {str(e)}"}
|
| 94 |
|
| 95 |
def main() -> None:
|
| 96 |
for task_id in TASKS:
|
|
|
|
| 101 |
success = False
|
| 102 |
|
| 103 |
try:
|
| 104 |
+
# 1. Reset the environment
|
| 105 |
res = requests.post(f"{ENV_URL}/reset", json={"task_id": task_id})
|
| 106 |
if res.status_code != 200:
|
| 107 |
log_step(step=1, action="reset_failed", reward=0.0, done=True, error=f"HTTP {res.status_code}")
|
|
|
|
| 108 |
log_end(success=False, steps=0, score=0.01, rewards=[])
|
| 109 |
continue
|
| 110 |
|
| 111 |
+
# 2. Initialize data from the reset
|
| 112 |
+
step_data = res.json()
|
| 113 |
+
observation = step_data.get("observation", step_data)
|
| 114 |
done = False
|
| 115 |
|
| 116 |
+
# 3. THE SINGLE LOOP (Fixed)
|
| 117 |
while not done and steps_taken < MAX_STEPS:
|
| 118 |
steps_taken += 1
|
| 119 |
|
| 120 |
+
# Feedback memory for the LLM
|
| 121 |
+
llm_observation = {
|
| 122 |
+
"task_id": task_id,
|
| 123 |
+
"last_feedback": step_data.get("status_message", "No feedback yet."),
|
| 124 |
+
"step_count": steps_taken,
|
| 125 |
+
"ad_details": observation
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
# Get action from LLM
|
| 129 |
+
action_payload = get_llm_action(llm_observation)
|
| 130 |
action_str = action_payload["action_type"]
|
| 131 |
+
if "Error code: 402" in action_payload.get("reasoning", ""):
|
| 132 |
+
done = True
|
| 133 |
+
log_step(step=steps_taken, action=action_str, reward=0.0, done=True, error="API credits depleted")
|
| 134 |
+
break
|
| 135 |
+
# Execute action in environment
|
| 136 |
+
step_res = requests.post(f"{ENV_URL}/step", json={"action": action_payload})
|
| 137 |
+
step_data = step_res.json()
|
| 138 |
|
| 139 |
+
# Update loop variables
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
observation = step_data.get("observation", {})
|
| 141 |
done = step_data.get("done", False)
|
| 142 |
reward = step_data.get("reward", 0.0)
|
|
|
|
| 144 |
rewards.append(reward)
|
| 145 |
log_step(step=steps_taken, action=action_str, reward=reward, done=done, error=None)
|
| 146 |
|
| 147 |
+
# 4. Final Scoring (Single Log)
|
|
|
|
| 148 |
raw_score = sum(rewards)
|
| 149 |
+
success = raw_score > 0
|
| 150 |
+
log_end(success=success, steps=steps_taken, score=raw_score, rewards=rewards)
|
|
|
|
|
|
|
| 151 |
|
| 152 |
except Exception as e:
|
| 153 |
log_step(step=steps_taken+1, action="exception", reward=0.0, done=True, error=str(e).replace("\n", " "))
|
|
|
|
| 154 |
log_end(success=False, steps=steps_taken, score=0.01, rewards=rewards)
|
| 155 |
|
| 156 |
if __name__ == "__main__":
|
server/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (150 Bytes). View file
|
|
|
server/__pycache__/app.cpython-313.pyc
ADDED
|
Binary file (876 Bytes). View file
|
|
|
src/__pycache__/environment.cpython-313.pyc
CHANGED
|
Binary files a/src/__pycache__/environment.cpython-313.pyc and b/src/__pycache__/environment.cpython-313.pyc differ
|
|
|
src/__pycache__/generator.cpython-313.pyc
CHANGED
|
Binary files a/src/__pycache__/generator.cpython-313.pyc and b/src/__pycache__/generator.cpython-313.pyc differ
|
|
|
src/environment.py
CHANGED
|
@@ -1,90 +1,136 @@
|
|
| 1 |
-
import
|
| 2 |
from openenv.core.env_server import Environment
|
| 3 |
from src.models import AdAction, AdObservation, AdState
|
| 4 |
from src.generator import AdGenerator
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
class AdPolicyEnvironment(Environment):
|
| 7 |
def __init__(self):
|
| 8 |
super().__init__()
|
| 9 |
self.generator = AdGenerator()
|
| 10 |
self.current_ad = None
|
| 11 |
self.image_analyzed = False
|
|
|
|
|
|
|
| 12 |
self.step_count = 0
|
| 13 |
self.total_reward = 0.0
|
| 14 |
|
| 15 |
-
def _ensure_ad(self):
|
| 16 |
if self.current_ad is None:
|
| 17 |
-
self.current_ad = self.generator.generate_random_ad()
|
|
|
|
| 18 |
|
| 19 |
def state(self) -> AdState:
|
| 20 |
self._ensure_ad()
|
| 21 |
return AdState(
|
| 22 |
step_count=self.step_count,
|
| 23 |
total_reward=self.total_reward,
|
| 24 |
-
current_ad_id=self.current_ad.get("ad_id")
|
| 25 |
)
|
| 26 |
|
| 27 |
-
# Add task_id as an optional parameter
|
| 28 |
def reset(self, task_id: str = None) -> AdObservation:
|
| 29 |
-
# Pass the task_id down to the generator
|
| 30 |
self.current_ad = self.generator.generate_random_ad(task_id)
|
|
|
|
| 31 |
self.image_analyzed = False
|
|
|
|
|
|
|
| 32 |
self.step_count = 0
|
| 33 |
self.total_reward = 0.0
|
| 34 |
-
|
| 35 |
-
# Add the task_id to the welcome message so the bot knows it worked
|
| 36 |
-
msg = f"Ad loaded for {task_id}. Awaiting review." if task_id else "Random ad loaded. Awaiting review."
|
| 37 |
-
return self._get_obs(msg)
|
| 38 |
|
| 39 |
-
def step(self, action: AdAction) -> AdObservation:
|
| 40 |
self._ensure_ad()
|
| 41 |
self.step_count += 1
|
| 42 |
-
|
| 43 |
reward = 0.0
|
| 44 |
done = False
|
| 45 |
-
message = "Action processed."
|
| 46 |
|
| 47 |
if not action or not hasattr(action, 'action_type'):
|
| 48 |
-
|
| 49 |
-
reward = -0.1
|
| 50 |
-
self.total_reward += reward
|
| 51 |
-
return self._get_obs("Invalid action.", reward, False)
|
| 52 |
|
| 53 |
act_type = str(action.action_type).lower()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
-
|
| 56 |
-
# We charge the agent -0.05 for using tools to force efficiency
|
| 57 |
-
if act_type in ["analyze_image", "request_landing_page", "request_id_verification"]:
|
| 58 |
reward = -0.05
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
message = vlm_text # Cleaned up the double "VLM Output:" prefix here!
|
| 64 |
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
elif act_type == "request_id_verification":
|
| 70 |
-
message = "ID matches advertiser profile."
|
| 71 |
-
|
| 72 |
-
# π― FINAL DECISION: Big Rewards / Big Penalties
|
| 73 |
-
elif act_type in ["approve", "reject"]:
|
| 74 |
done = True
|
| 75 |
is_violation = self.current_ad.get("ground_truth", False)
|
| 76 |
-
is_correct = (act_type == "reject" and is_violation) or
|
| 77 |
-
(act_type == "approve" and not is_violation)
|
| 78 |
-
|
| 79 |
reward = 1.0 if is_correct else -1.0
|
| 80 |
-
message = f"Decision: {act_type.upper()}.
|
| 81 |
|
| 82 |
-
|
| 83 |
-
|
|
|
|
| 84 |
|
|
|
|
| 85 |
return self._get_obs(message, reward, done)
|
| 86 |
|
| 87 |
-
def _get_obs(self, message
|
| 88 |
self._ensure_ad()
|
| 89 |
return AdObservation(
|
| 90 |
ad_id=str(self.current_ad.get("ad_id", "N/A")),
|
|
@@ -94,6 +140,6 @@ class AdPolicyEnvironment(Environment):
|
|
| 94 |
targeting_data=dict(self.current_ad.get("targeting_data", {})),
|
| 95 |
image_url=str(self.current_ad.get("image_url", "N/A")),
|
| 96 |
status_message=str(message),
|
| 97 |
-
reward=reward,
|
| 98 |
-
done=done
|
| 99 |
)
|
|
|
|
| 1 |
+
import requests
|
| 2 |
from openenv.core.env_server import Environment
|
| 3 |
from src.models import AdAction, AdObservation, AdState
|
| 4 |
from src.generator import AdGenerator
|
| 5 |
|
| 6 |
+
REGULATORY_API = "http://localhost:8001"
|
| 7 |
+
CRM_API = "http://localhost:8002"
|
| 8 |
+
AUDIT_API = "http://localhost:8003"
|
| 9 |
+
|
| 10 |
class AdPolicyEnvironment(Environment):
|
| 11 |
def __init__(self):
|
| 12 |
super().__init__()
|
| 13 |
self.generator = AdGenerator()
|
| 14 |
self.current_ad = None
|
| 15 |
self.image_analyzed = False
|
| 16 |
+
self.regulations_queried = False
|
| 17 |
+
self.audit_submitted = False
|
| 18 |
self.step_count = 0
|
| 19 |
self.total_reward = 0.0
|
| 20 |
|
| 21 |
+
def _ensure_ad(self, task_id=None):
|
| 22 |
if self.current_ad is None:
|
| 23 |
+
self.current_ad = self.generator.generate_random_ad(task_id)
|
| 24 |
+
self.current_ad["task_id"] = task_id or "task_1_healthcare"
|
| 25 |
|
| 26 |
def state(self) -> AdState:
|
| 27 |
self._ensure_ad()
|
| 28 |
return AdState(
|
| 29 |
step_count=self.step_count,
|
| 30 |
total_reward=self.total_reward,
|
| 31 |
+
current_ad_id=self.current_ad.get("ad_id", "N/A")
|
| 32 |
)
|
| 33 |
|
|
|
|
| 34 |
def reset(self, task_id: str = None) -> AdObservation:
|
|
|
|
| 35 |
self.current_ad = self.generator.generate_random_ad(task_id)
|
| 36 |
+
self.current_ad["task_id"] = task_id or "task_1_healthcare"
|
| 37 |
self.image_analyzed = False
|
| 38 |
+
self.regulations_queried = False
|
| 39 |
+
self.audit_submitted = False
|
| 40 |
self.step_count = 0
|
| 41 |
self.total_reward = 0.0
|
| 42 |
+
return self._get_obs(f"Ad loaded for {self.current_ad['task_id']}. Begin with query_regulations.")
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
+
def step(self, action: AdAction) -> AdObservation:
|
| 45 |
self._ensure_ad()
|
| 46 |
self.step_count += 1
|
|
|
|
| 47 |
reward = 0.0
|
| 48 |
done = False
|
|
|
|
| 49 |
|
| 50 |
if not action or not hasattr(action, 'action_type'):
|
| 51 |
+
return self._get_obs("Invalid action format.", -0.1, False)
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
act_type = str(action.action_type).lower()
|
| 54 |
+
task_id = self.current_ad.get("task_id", "")
|
| 55 |
+
|
| 56 |
+
# ββ TOOL ACTIONS ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 57 |
+
if act_type == "query_regulations":
|
| 58 |
+
self.regulations_queried = True
|
| 59 |
+
reward = -0.05
|
| 60 |
+
category = self.current_ad.get("category", "general")
|
| 61 |
+
try:
|
| 62 |
+
resp = requests.get(f"{REGULATORY_API}/regulations/{category}", timeout=2)
|
| 63 |
+
message = resp.json().get("policy_summary", "Standard policy applies.")
|
| 64 |
+
except Exception:
|
| 65 |
+
message = "API Error: Default standard policy applies."
|
| 66 |
+
|
| 67 |
+
elif act_type == "analyze_image":
|
| 68 |
+
self.image_analyzed = True
|
| 69 |
+
reward = -0.05
|
| 70 |
+
message = self.current_ad.get("vlm_desc", "No visual anomalies detected.")
|
| 71 |
+
|
| 72 |
+
elif act_type == "check_advertiser_history":
|
| 73 |
+
reward = -0.05
|
| 74 |
+
advertiser_id = self.current_ad.get("advertiser_id", "adv_003")
|
| 75 |
+
try:
|
| 76 |
+
resp = requests.get(f"{CRM_API}/advertiser/{advertiser_id}", timeout=2)
|
| 77 |
+
message = f"CRM Summary: {resp.json().get('summary', 'No data')}"
|
| 78 |
+
except Exception:
|
| 79 |
+
message = "CRM offline. Cannot verify history."
|
| 80 |
+
|
| 81 |
+
elif act_type == "request_landing_page":
|
| 82 |
+
reward = -0.05
|
| 83 |
+
domain_age = self.current_ad.get("domain_age_days", 365)
|
| 84 |
+
risk_keywords = self.current_ad.get("landing_risk_keywords", [])
|
| 85 |
+
message = f"Domain age: {domain_age} days. Flagged terms: {risk_keywords or 'none'}."
|
| 86 |
|
| 87 |
+
elif act_type == "request_id_verification":
|
|
|
|
|
|
|
| 88 |
reward = -0.05
|
| 89 |
+
age_min = self.current_ad.get("targeting_data", {}).get("age_min", 18)
|
| 90 |
+
message = f"Target age {age_min}+." if age_min >= 18 else f"ALERT: Minor targeting detected (Age {age_min}+)."
|
| 91 |
+
|
| 92 |
+
elif act_type == "submit_audit":
|
| 93 |
+
self.audit_submitted = True
|
| 94 |
+
reward = 0.0
|
| 95 |
+
try:
|
| 96 |
+
payload = {
|
| 97 |
+
"ad_id": self.current_ad.get("ad_id", "test"),
|
| 98 |
+
"action_taken": "pending",
|
| 99 |
+
"reasoning": "audit requested"
|
| 100 |
+
}
|
| 101 |
+
resp = requests.post(f"{AUDIT_API}/log", json=payload, timeout=2)
|
| 102 |
+
message = f"Audit logged: {resp.json().get('audit_id', 'Local-1')}"
|
| 103 |
+
except Exception:
|
| 104 |
+
message = "Audit recorded locally."
|
| 105 |
+
|
| 106 |
+
# ββ TERMINAL ACTIONS (Phase Gates) ββββββββββββββββββββββββββββββββββββ
|
| 107 |
+
elif act_type in ["approve", "reject"]:
|
| 108 |
+
# Gate 1: Must query rules
|
| 109 |
+
if not self.regulations_queried:
|
| 110 |
+
return self._get_obs("Policy Gate: Run query_regulations first.", -0.2, False)
|
| 111 |
|
| 112 |
+
# Gate 2: Multimodal tasks require image analysis
|
| 113 |
+
if "multimodal" in task_id and not self.image_analyzed:
|
| 114 |
+
return self._get_obs("Visual Gate: Image analysis required.", -0.3, False)
|
|
|
|
| 115 |
|
| 116 |
+
# Gate 3: Must audit
|
| 117 |
+
if not self.audit_submitted:
|
| 118 |
+
return self._get_obs("Compliance Gate: Run submit_audit before decision.", -0.2, False)
|
| 119 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
done = True
|
| 121 |
is_violation = self.current_ad.get("ground_truth", False)
|
| 122 |
+
is_correct = (act_type == "reject" and is_violation) or (act_type == "approve" and not is_violation)
|
|
|
|
|
|
|
| 123 |
reward = 1.0 if is_correct else -1.0
|
| 124 |
+
message = f"Decision: {act_type.upper()}. {'Correct!' if is_correct else 'Incorrect.'}"
|
| 125 |
|
| 126 |
+
else:
|
| 127 |
+
reward = -0.05
|
| 128 |
+
message = f"Unknown action: {act_type}."
|
| 129 |
|
| 130 |
+
self.total_reward += reward
|
| 131 |
return self._get_obs(message, reward, done)
|
| 132 |
|
| 133 |
+
def _get_obs(self, message, reward=0.0, done=False) -> AdObservation:
|
| 134 |
self._ensure_ad()
|
| 135 |
return AdObservation(
|
| 136 |
ad_id=str(self.current_ad.get("ad_id", "N/A")),
|
|
|
|
| 140 |
targeting_data=dict(self.current_ad.get("targeting_data", {})),
|
| 141 |
image_url=str(self.current_ad.get("image_url", "N/A")),
|
| 142 |
status_message=str(message),
|
| 143 |
+
reward=reward,
|
| 144 |
+
done=done
|
| 145 |
)
|
src/meta_ad_policy_sandbox.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: meta-ad-policy-sandbox
|
| 3 |
+
Version: 0.2.3
|
| 4 |
+
Summary: Meta Ad-Policy RL Sandbox
|
| 5 |
+
Requires-Dist: fastapi
|
| 6 |
+
Requires-Dist: uvicorn
|
| 7 |
+
Requires-Dist: pydantic
|
| 8 |
+
Requires-Dist: requests
|
| 9 |
+
Requires-Dist: openai
|
| 10 |
+
Requires-Dist: openenv-core>=0.2.0
|
src/meta_ad_policy_sandbox.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
README.md
|
| 2 |
+
pyproject.toml
|
| 3 |
+
src/__init__.py
|
| 4 |
+
src/environment.py
|
| 5 |
+
src/generator.py
|
| 6 |
+
src/models.py
|
| 7 |
+
src/meta_ad_policy_sandbox.egg-info/PKG-INFO
|
| 8 |
+
src/meta_ad_policy_sandbox.egg-info/SOURCES.txt
|
| 9 |
+
src/meta_ad_policy_sandbox.egg-info/dependency_links.txt
|
| 10 |
+
src/meta_ad_policy_sandbox.egg-info/entry_points.txt
|
| 11 |
+
src/meta_ad_policy_sandbox.egg-info/requires.txt
|
| 12 |
+
src/meta_ad_policy_sandbox.egg-info/top_level.txt
|
src/meta_ad_policy_sandbox.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
src/meta_ad_policy_sandbox.egg-info/entry_points.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[console_scripts]
|
| 2 |
+
server = server.app:main
|
src/meta_ad_policy_sandbox.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi
|
| 2 |
+
uvicorn
|
| 3 |
+
pydantic
|
| 4 |
+
requests
|
| 5 |
+
openai
|
| 6 |
+
openenv-core>=0.2.0
|
src/meta_ad_policy_sandbox.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__init__
|
| 2 |
+
environment
|
| 3 |
+
generator
|
| 4 |
+
models
|
uv.lock
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|