3v324v23 committed on
Commit
8a685c0
·
1 Parent(s): 475e2c6

Phase 2 complete: Fixed inference loop and added phase gates

Browse files
apps/audit_api.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ import uvicorn
4
+
5
# ASGI application object for the audit service.
app = FastAPI(title="Compliance Audit API")

# In-memory store of submitted audit records; contents are lost on restart.
logs = list()
7
+
8
class AuditRecord(BaseModel):
    """Request-body schema for the audit log endpoint: three string fields."""

    ad_id: str
    action_taken: str
    reasoning: str
12
+
13
@app.post("/log")
def log_audit(record: AuditRecord):
    """Append an audit record to the in-memory log and return its ID.

    Args:
        record: Validated request body.

    Returns:
        A dict with ``status`` set to ``"success"`` and ``audit_id`` of the
        form ``AUD-<n>``, where ``n`` is the log length after the append.
    """
    # Pydantic v2 renamed ``.dict()`` to ``.model_dump()`` and deprecated the
    # old name; prefer the new method but keep working on Pydantic v1.
    dump = getattr(record, "model_dump", None) or record.dict
    logs.append(dump())
    return {"status": "success", "audit_id": f"AUD-{len(logs)}"}
17
+
18
@app.get("/health")
def health():
    """Report that the compliance-audit service is up."""
    payload = {"status": "ok", "service": "compliance-audit"}
    return payload
21
+
22
if __name__ == "__main__":
    # Run directly as a script: serve on all interfaces, port 8003.
    uvicorn.run(app, host="0.0.0.0", port=8003)
apps/crm_api.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ import uvicorn
3
+
4
# ASGI application object for the CRM service.
app = FastAPI(title="Advertiser CRM API")

# Static advertiser records keyed by advertiser ID.
ADVERTISERS = {
    "adv_001": {
        "prior_violations": 3,
        "summary": "High-risk repeat offender.",
    },
    "adv_002": {
        "prior_violations": 0,
        "summary": "Established clean record.",
    },
}
10
+
11
@app.get("/advertiser/{advertiser_id}")
def get_advertiser(advertiser_id: str):
    """Return the stored record for *advertiser_id*.

    IDs that are not in ``ADVERTISERS`` get a default record with zero prior
    violations and the summary "New advertiser.".
    """
    fallback = {"prior_violations": 0, "summary": "New advertiser."}
    return ADVERTISERS.get(advertiser_id, fallback)
14
+
15
@app.get("/health")
def health():
    """Report that the advertiser-crm service is up."""
    payload = {"status": "ok", "service": "advertiser-crm"}
    return payload
18
+
19
if __name__ == "__main__":
    # Run directly as a script: serve on all interfaces, port 8002.
    uvicorn.run(app, host="0.0.0.0", port=8002)
apps/regulatory_api.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ import uvicorn
3
+
4
# ASGI application object for the regulatory lookup service.
app = FastAPI(title="Regulatory DB API")

# Static policy table keyed by lower-case ad category. The "general" entry is
# also what the lookup endpoint falls back to for unknown categories.
REGULATIONS = {
    "healthcare": {
        "policy_summary": "Claims require FDA approval. No 'guaranteed results' allowed.",
        "risk_level": "high",
    },
    "financial": {
        "policy_summary": "Requires SEC registration. Prohibited: predatory APR > 36%.",
        "risk_level": "high",
    },
    "general": {
        "policy_summary": "Standard standards apply. No deceptive claims.",
        "risk_level": "low",
    },
}
20
+
21
@app.get("/regulations/{category}")
def get_regulations(category: str):
    """Return the policy entry for *category*, matched case-insensitively.

    Categories without an entry in ``REGULATIONS`` receive the "general" policy.
    """
    key = category.lower()
    default_policy = REGULATIONS["general"]
    return REGULATIONS.get(key, default_policy)
24
+
25
@app.get("/health")
def health():
    """Report that the regulatory-db service is up."""
    payload = {"status": "ok", "service": "regulatory-db"}
    return payload
28
+
29
if __name__ == "__main__":
    # Run directly as a script: serve on all interfaces, port 8001.
    uvicorn.run(app, host="0.0.0.0", port=8001)
apps/start_all.bat ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
@echo off
echo Launching Enterprise Ecosystem...

rem Move to the project root (the parent of this script's folder) so the
rem relative paths below resolve no matter where the script is invoked from.
cd /d "%~dp0\.."

rem Each service is started in its own titled console window; cmd /k keeps
rem the window open after the command exits.
start "Regulatory API" cmd /k "uv run python apps\regulatory_api.py"
start "CRM API" cmd /k "uv run python apps\crm_api.py"
start "Audit API" cmd /k "uv run python apps\audit_api.py"
start "Environment Server" cmd /k "uv run uvicorn server.app:app --host 0.0.0.0 --port 8000"
grpo_train.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from unsloth import FastLanguageModel, PatchFastRL
3
+ from trl import GRPOTrainer, GRPOConfig
4
+ from src.environment import AdPolicyEnvironment
5
+
6
+ # 1. Load Model with Unsloth
7
# Load the 4-bit base model and its tokenizer through Unsloth.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Llama-3.1-8B-Instruct",
    max_seq_length=1024,
    load_in_4bit=True,
)
12
+
13
+ # 2. Define Reward Functions
14
def reward_compliance(prompts, completions, **kwargs):
    """Reward completions that call the required tools in the required order.

    A completion earns 2.0 when it mentions ``query_regulations`` and, at or
    after that point, ``submit_audit``; otherwise it earns 0.0. The previous
    implementation only checked that both names were present, so a completion
    that audited before querying regulations was still rewarded.

    Args:
        prompts: Prompts for the batch (unused).
        completions: Generated completions; non-string entries score 0.0.
        **kwargs: Extra values passed by the trainer (unused).

    Returns:
        A list of floats, one per completion.
    """
    rewards = []
    for completion in completions:
        if not isinstance(completion, str):
            # Cannot search non-text completions for tool names.
            rewards.append(0.0)
            continue
        first_query = completion.find("query_regulations")
        # Only an audit that appears after the first regulations query counts.
        in_order = (
            first_query != -1
            and completion.find("submit_audit", first_query) != -1
        )
        rewards.append(2.0 if in_order else 0.0)
    return rewards
23
+
24
def reward_json_format(prompts, completions, **kwargs):
    """Reward completions that are valid JSON documents.

    Args:
        prompts: Prompts for the batch (unused).
        completions: Generated completions to validate.
        **kwargs: Extra values passed by the trainer (unused).

    Returns:
        A list of floats, one per completion: 1.0 if ``json.loads`` accepts
        the completion, 0.0 otherwise.
    """
    # Imported once per call rather than once per completion; kept local
    # because this block does not rely on a module-level ``json`` import.
    import json

    rewards = []
    for completion in completions:
        try:
            json.loads(completion)
        except (TypeError, ValueError, RecursionError):
            # TypeError: not str/bytes; ValueError: malformed JSON
            # (json.JSONDecodeError subclasses it); RecursionError: nesting
            # too deep. Unlike the former bare ``except:``, this lets
            # KeyboardInterrupt and SystemExit propagate.
            rewards.append(0.0)
        else:
            rewards.append(1.0)
    return rewards
34
+
35
+ # 3. Configure Trainer
36
# GRPO hyper-parameters for the run.
training_args = GRPOConfig(
    output_dir="outputs/meta-ad-agent",
    learning_rate=5e-6,
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    max_prompt_length=512,
    max_completion_length=512,
    # Number of variations to compare
    num_generations=8,
)

# Trainer wiring: both reward functions are handed to GRPO together.
trainer = GRPOTrainer(
    model=model,
    reward_funcs=[reward_compliance, reward_json_format],
    args=training_args,
    # Placeholder: data is meant to be streamed from the AdGenerator here.
    train_dataset=[],
    tokenizer=tokenizer,
)
54
+
55
+ # 4. Start Training
56
+ # trainer.train()
inference.py CHANGED
@@ -6,7 +6,7 @@ from openai import OpenAI
6
  # 1. MANDATORY VARIABLES EXACTLY AS REQUESTED BY SCALAR
7
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
8
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY", "dummy_local_token")
9
- MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3-70b-chat-hf")
10
 
11
  ENV_URL = "http://localhost:8000"
12
  MAX_STEPS = 10
@@ -39,17 +39,28 @@ def log_end(success: bool, steps: int, score: float, rewards: list) -> None:
39
 
40
  def get_llm_action(observation_data):
41
  """Asks the LLM what action to take based on the ad observation."""
42
- system_prompt = """You are an expert Meta Ad-Policy Moderator AI.
43
- Evaluate the ad and output a decision. Using tools costs -0.05 points, so be efficient.
44
-
 
 
 
 
 
 
 
45
  AVAILABLE ACTIONS:
 
46
  - analyze_image
 
47
  - request_landing_page
48
  - request_id_verification
 
49
  - approve
50
  - reject
51
-
52
- You MUST respond in valid JSON format containing "action_type" and "reasoning".
 
53
  """
54
 
55
  user_prompt = f"Current Ad Observation:\n{json.dumps(observation_data, indent=2)}\n\nWhat is your next action?"
@@ -61,17 +72,25 @@ def get_llm_action(observation_data):
61
  {"role": "system", "content": system_prompt},
62
  {"role": "user", "content": user_prompt}
63
  ],
64
- response_format={"type": "json_object"},
65
  temperature=0.1
66
  )
67
 
68
- result = json.loads(response.choices[0].message.content)
 
 
 
 
 
 
 
69
  return {
70
- "action_type": result.get("action_type", "analyze_image"),
71
  "reasoning": result.get("reasoning", "Fallback reasoning")
72
  }
73
  except Exception as e:
74
- return {"action_type": "analyze_image", "reasoning": "Error recovery."}
 
75
 
76
  def main() -> None:
77
  for task_id in TASKS:
@@ -82,29 +101,42 @@ def main() -> None:
82
  success = False
83
 
84
  try:
 
85
  res = requests.post(f"{ENV_URL}/reset", json={"task_id": task_id})
86
  if res.status_code != 200:
87
  log_step(step=1, action="reset_failed", reward=0.0, done=True, error=f"HTTP {res.status_code}")
88
- # Forced score to 0.01 instead of 0.0
89
  log_end(success=False, steps=0, score=0.01, rewards=[])
90
  continue
91
 
92
- data = res.json()
93
- observation = data.get("observation", data)
 
94
  done = False
95
 
 
96
  while not done and steps_taken < MAX_STEPS:
97
  steps_taken += 1
98
 
 
 
 
 
 
 
 
 
99
  # Get action from LLM
100
- action_payload = get_llm_action(observation)
101
  action_str = action_payload["action_type"]
 
 
 
 
 
 
 
102
 
103
- # Execute action
104
- step_res = requests.post(f"{ENV_URL}/step", json=action_payload)
105
- step_data = step_res.json()
106
-
107
- # Parse response perfectly
108
  observation = step_data.get("observation", {})
109
  done = step_data.get("done", False)
110
  reward = step_data.get("reward", 0.0)
@@ -112,17 +144,13 @@ def main() -> None:
112
  rewards.append(reward)
113
  log_step(step=steps_taken, action=action_str, reward=reward, done=done, error=None)
114
 
115
- # --- THE FIX IS HERE ---
116
- # Calculate final score and forcefully clamp it strictly between 0.01 and 0.99
117
  raw_score = sum(rewards)
118
- score = min(max(raw_score, 0.01), 0.99)
119
- success = score > 0.01
120
-
121
- log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
122
 
123
  except Exception as e:
124
  log_step(step=steps_taken+1, action="exception", reward=0.0, done=True, error=str(e).replace("\n", " "))
125
- # Forced score to 0.01 instead of 0.0
126
  log_end(success=False, steps=steps_taken, score=0.01, rewards=rewards)
127
 
128
  if __name__ == "__main__":
 
6
  # 1. MANDATORY VARIABLES EXACTLY AS REQUESTED BY SCALAR
7
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
8
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY", "dummy_local_token")
9
+ MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Meta-Llama-3-8B-Instruct")
10
 
11
  ENV_URL = "http://localhost:8000"
12
  MAX_STEPS = 10
 
39
 
40
  def get_llm_action(observation_data):
41
  """Asks the LLM what action to take based on the ad observation."""
42
+ system_prompt = """You are an enterprise Ad Policy Compliance Agent.
43
+ You navigate a multi-system compliance workflow. Always respond with ONLY valid JSON.
44
+
45
+ REQUIRED PHASE ORDER:
46
+ 1. query_regulations β€” always first
47
+ 2. analyze_image β€” required for visual/multimodal tasks
48
+ 3. check_advertiser_history or request_landing_page β€” as needed
49
+ 4. submit_audit β€” always before final decision
50
+ 5. approve or reject β€” final decision only after audit
51
+
52
  AVAILABLE ACTIONS:
53
+ - query_regulations
54
  - analyze_image
55
+ - check_advertiser_history
56
  - request_landing_page
57
  - request_id_verification
58
+ - submit_audit
59
  - approve
60
  - reject
61
+
62
+ Response format:
63
+ {"action_type": "<action>", "reasoning": "<brief reason>"}
64
  """
65
 
66
  user_prompt = f"Current Ad Observation:\n{json.dumps(observation_data, indent=2)}\n\nWhat is your next action?"
 
72
  {"role": "system", "content": system_prompt},
73
  {"role": "user", "content": user_prompt}
74
  ],
75
+ # Removed response_format={"type": "json_object"} as HF router often rejects it
76
  temperature=0.1
77
  )
78
 
79
+ # Clean the response in case the LLM wrapped it in markdown code blocks like ```json ... ```
80
+ content = response.choices[0].message.content.strip()
81
+ if content.startswith("```json"):
82
+ content = content[7:-3].strip()
83
+ elif content.startswith("```"):
84
+ content = content[3:-3].strip()
85
+
86
+ result = json.loads(content)
87
  return {
88
+ "action_type": result.get("action_type", "query_regulations"),
89
  "reasoning": result.get("reasoning", "Fallback reasoning")
90
  }
91
  except Exception as e:
92
+ print(f"\n[CRITICAL LLM ERROR]: {str(e)}\n", flush=True) # THIS WILL REVEAL THE BUG
93
+ return {"action_type": "query_regulations", "reasoning": f"Error recovery: {str(e)}"}
94
 
95
  def main() -> None:
96
  for task_id in TASKS:
 
101
  success = False
102
 
103
  try:
104
+ # 1. Reset the environment
105
  res = requests.post(f"{ENV_URL}/reset", json={"task_id": task_id})
106
  if res.status_code != 200:
107
  log_step(step=1, action="reset_failed", reward=0.0, done=True, error=f"HTTP {res.status_code}")
 
108
  log_end(success=False, steps=0, score=0.01, rewards=[])
109
  continue
110
 
111
+ # 2. Initialize data from the reset
112
+ step_data = res.json()
113
+ observation = step_data.get("observation", step_data)
114
  done = False
115
 
116
+ # 3. THE SINGLE LOOP (Fixed)
117
  while not done and steps_taken < MAX_STEPS:
118
  steps_taken += 1
119
 
120
+ # Feedback memory for the LLM
121
+ llm_observation = {
122
+ "task_id": task_id,
123
+ "last_feedback": step_data.get("status_message", "No feedback yet."),
124
+ "step_count": steps_taken,
125
+ "ad_details": observation
126
+ }
127
+
128
  # Get action from LLM
129
+ action_payload = get_llm_action(llm_observation)
130
  action_str = action_payload["action_type"]
131
+ if "Error code: 402" in action_payload.get("reasoning", ""):
132
+ done = True
133
+ log_step(step=steps_taken, action=action_str, reward=0.0, done=True, error="API credits depleted")
134
+ break
135
+ # Execute action in environment
136
+ step_res = requests.post(f"{ENV_URL}/step", json={"action": action_payload})
137
+ step_data = step_res.json()
138
 
139
+ # Update loop variables
 
 
 
 
140
  observation = step_data.get("observation", {})
141
  done = step_data.get("done", False)
142
  reward = step_data.get("reward", 0.0)
 
144
  rewards.append(reward)
145
  log_step(step=steps_taken, action=action_str, reward=reward, done=done, error=None)
146
 
147
+ # 4. Final Scoring (Single Log)
 
148
  raw_score = sum(rewards)
149
+ success = raw_score > 0
150
+ log_end(success=success, steps=steps_taken, score=raw_score, rewards=rewards)
 
 
151
 
152
  except Exception as e:
153
  log_step(step=steps_taken+1, action="exception", reward=0.0, done=True, error=str(e).replace("\n", " "))
 
154
  log_end(success=False, steps=steps_taken, score=0.01, rewards=rewards)
155
 
156
  if __name__ == "__main__":
server/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (150 Bytes). View file
 
server/__pycache__/app.cpython-313.pyc ADDED
Binary file (876 Bytes). View file
 
src/__pycache__/environment.cpython-313.pyc CHANGED
Binary files a/src/__pycache__/environment.cpython-313.pyc and b/src/__pycache__/environment.cpython-313.pyc differ
 
src/__pycache__/generator.cpython-313.pyc CHANGED
Binary files a/src/__pycache__/generator.cpython-313.pyc and b/src/__pycache__/generator.cpython-313.pyc differ
 
src/environment.py CHANGED
@@ -1,90 +1,136 @@
1
- import uuid
2
  from openenv.core.env_server import Environment
3
  from src.models import AdAction, AdObservation, AdState
4
  from src.generator import AdGenerator
5
 
 
 
 
 
6
  class AdPolicyEnvironment(Environment):
7
  def __init__(self):
8
  super().__init__()
9
  self.generator = AdGenerator()
10
  self.current_ad = None
11
  self.image_analyzed = False
 
 
12
  self.step_count = 0
13
  self.total_reward = 0.0
14
 
15
- def _ensure_ad(self):
16
  if self.current_ad is None:
17
- self.current_ad = self.generator.generate_random_ad()
 
18
 
19
  def state(self) -> AdState:
20
  self._ensure_ad()
21
  return AdState(
22
  step_count=self.step_count,
23
  total_reward=self.total_reward,
24
- current_ad_id=self.current_ad.get("ad_id")
25
  )
26
 
27
- # Add task_id as an optional parameter
28
  def reset(self, task_id: str = None) -> AdObservation:
29
- # Pass the task_id down to the generator
30
  self.current_ad = self.generator.generate_random_ad(task_id)
 
31
  self.image_analyzed = False
 
 
32
  self.step_count = 0
33
  self.total_reward = 0.0
34
-
35
- # Add the task_id to the welcome message so the bot knows it worked
36
- msg = f"Ad loaded for {task_id}. Awaiting review." if task_id else "Random ad loaded. Awaiting review."
37
- return self._get_obs(msg)
38
 
39
- def step(self, action: AdAction) -> AdObservation:
40
  self._ensure_ad()
41
  self.step_count += 1
42
-
43
  reward = 0.0
44
  done = False
45
- message = "Action processed."
46
 
47
  if not action or not hasattr(action, 'action_type'):
48
- # Heavy penalty for invalid formatting to train the agent faster
49
- reward = -0.1
50
- self.total_reward += reward
51
- return self._get_obs("Invalid action.", reward, False)
52
 
53
  act_type = str(action.action_type).lower()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- # 🧠 REWARD SHAPING: Tool Usage Penalties
56
- # We charge the agent -0.05 for using tools to force efficiency
57
- if act_type in ["analyze_image", "request_landing_page", "request_id_verification"]:
58
  reward = -0.05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- if act_type == "analyze_image":
61
- self.image_analyzed = True
62
- vlm_text = self.current_ad.get('vlm_desc', 'No visual description.')
63
- message = vlm_text # Cleaned up the double "VLM Output:" prefix here!
64
 
65
- elif act_type == "request_landing_page":
66
- is_bad = self.current_ad.get("ground_truth", False)
67
- message = "Landing page is high-risk." if is_bad else "Landing page is clean."
68
-
69
- elif act_type == "request_id_verification":
70
- message = "ID matches advertiser profile."
71
-
72
- # 🎯 FINAL DECISION: Big Rewards / Big Penalties
73
- elif act_type in ["approve", "reject"]:
74
  done = True
75
  is_violation = self.current_ad.get("ground_truth", False)
76
- is_correct = (act_type == "reject" and is_violation) or \
77
- (act_type == "approve" and not is_violation)
78
-
79
  reward = 1.0 if is_correct else -1.0
80
- message = f"Decision: {act_type.upper()}. " + ("Correct!" if is_correct else "Incorrect.")
81
 
82
- # Accumulate the total reward for the episode state
83
- self.total_reward += reward
 
84
 
 
85
  return self._get_obs(message, reward, done)
86
 
87
- def _get_obs(self, message: str, reward: float = 0.0, done: bool = False) -> AdObservation:
88
  self._ensure_ad()
89
  return AdObservation(
90
  ad_id=str(self.current_ad.get("ad_id", "N/A")),
@@ -94,6 +140,6 @@ class AdPolicyEnvironment(Environment):
94
  targeting_data=dict(self.current_ad.get("targeting_data", {})),
95
  image_url=str(self.current_ad.get("image_url", "N/A")),
96
  status_message=str(message),
97
- reward=reward,
98
- done=done
99
  )
 
1
+ import requests
2
  from openenv.core.env_server import Environment
3
  from src.models import AdAction, AdObservation, AdState
4
  from src.generator import AdGenerator
5
 
6
# Base URLs of the three local companion services this environment calls.
REGULATORY_API = "http://localhost:8001"
CRM_API = "http://localhost:8002"
AUDIT_API = "http://localhost:8003"
9
+
10
  class AdPolicyEnvironment(Environment):
11
  def __init__(self):
12
  super().__init__()
13
  self.generator = AdGenerator()
14
  self.current_ad = None
15
  self.image_analyzed = False
16
+ self.regulations_queried = False
17
+ self.audit_submitted = False
18
  self.step_count = 0
19
  self.total_reward = 0.0
20
 
21
def _ensure_ad(self, task_id=None):
    """Generate an ad on demand when none is currently loaded.

    The new ad is tagged with *task_id*, or "task_1_healthcare" when no task
    ID is given. Does nothing if an ad already exists.
    """
    if self.current_ad is not None:
        return
    self.current_ad = self.generator.generate_random_ad(task_id)
    self.current_ad["task_id"] = task_id or "task_1_healthcare"
25
 
26
  def state(self) -> AdState:
27
  self._ensure_ad()
28
  return AdState(
29
  step_count=self.step_count,
30
  total_reward=self.total_reward,
31
+ current_ad_id=self.current_ad.get("ad_id", "N/A")
32
  )
33
 
 
34
  def reset(self, task_id: str = None) -> AdObservation:
 
35
  self.current_ad = self.generator.generate_random_ad(task_id)
36
+ self.current_ad["task_id"] = task_id or "task_1_healthcare"
37
  self.image_analyzed = False
38
+ self.regulations_queried = False
39
+ self.audit_submitted = False
40
  self.step_count = 0
41
  self.total_reward = 0.0
42
+ return self._get_obs(f"Ad loaded for {self.current_ad['task_id']}. Begin with query_regulations.")
 
 
 
43
 
44
def step(self, action: AdAction) -> AdObservation:
    """Apply one agent action and return the resulting observation.

    Tool actions (``query_regulations``, ``analyze_image``,
    ``check_advertiser_history``, ``request_landing_page``,
    ``request_id_verification``) cost -0.05 each; ``submit_audit`` is free.
    ``approve`` / ``reject`` end the episode with +1.0 or -1.0, but only
    after the phase gates pass: regulations queried, image analysed for
    tasks whose ID contains "multimodal", and an audit submitted. A failed
    gate returns a penalty and leaves the episode open.

    Every reward returned here, including the invalid-action and gate
    penalties, is added to ``self.total_reward``. Previously those early
    returns skipped the accumulation, so the running total disagreed with
    the rewards handed back to the caller.

    Args:
        action: Object exposing an ``action_type`` attribute.

    Returns:
        Observation built by ``_get_obs`` with the status message, the
        reward for this step and the ``done`` flag.
    """
    self._ensure_ad()
    self.step_count += 1
    reward = 0.0
    done = False

    if not action or not hasattr(action, "action_type"):
        # Malformed input is penalised and counted like any other reward.
        reward = -0.1
        self.total_reward += reward
        return self._get_obs("Invalid action format.", reward, False)

    act_type = str(action.action_type).lower()
    task_id = self.current_ad.get("task_id", "")

    # -- Tool actions -----------------------------------------------------
    if act_type == "query_regulations":
        self.regulations_queried = True
        reward = -0.05
        category = self.current_ad.get("category", "general")
        try:
            resp = requests.get(f"{REGULATORY_API}/regulations/{category}", timeout=2)
            message = resp.json().get("policy_summary", "Standard policy applies.")
        except Exception:
            # Best effort: the episode continues if the service is unreachable.
            message = "API Error: Default standard policy applies."

    elif act_type == "analyze_image":
        self.image_analyzed = True
        reward = -0.05
        message = self.current_ad.get("vlm_desc", "No visual anomalies detected.")

    elif act_type == "check_advertiser_history":
        reward = -0.05
        advertiser_id = self.current_ad.get("advertiser_id", "adv_003")
        try:
            resp = requests.get(f"{CRM_API}/advertiser/{advertiser_id}", timeout=2)
            message = f"CRM Summary: {resp.json().get('summary', 'No data')}"
        except Exception:
            message = "CRM offline. Cannot verify history."

    elif act_type == "request_landing_page":
        reward = -0.05
        domain_age = self.current_ad.get("domain_age_days", 365)
        risk_keywords = self.current_ad.get("landing_risk_keywords", [])
        message = f"Domain age: {domain_age} days. Flagged terms: {risk_keywords or 'none'}."

    elif act_type == "request_id_verification":
        reward = -0.05
        age_min = self.current_ad.get("targeting_data", {}).get("age_min", 18)
        if age_min >= 18:
            message = f"Target age {age_min}+."
        else:
            message = f"ALERT: Minor targeting detected (Age {age_min}+)."

    elif act_type == "submit_audit":
        self.audit_submitted = True
        reward = 0.0
        try:
            payload = {
                "ad_id": self.current_ad.get("ad_id", "test"),
                "action_taken": "pending",
                "reasoning": "audit requested",
            }
            resp = requests.post(f"{AUDIT_API}/log", json=payload, timeout=2)
            message = f"Audit logged: {resp.json().get('audit_id', 'Local-1')}"
        except Exception:
            message = "Audit recorded locally."

    # -- Terminal actions (phase gates) -------------------------------------
    elif act_type in ["approve", "reject"]:
        # The first failing gate wins; its penalty falls through to the
        # shared accumulation below instead of returning early.
        gate = None
        if not self.regulations_queried:
            gate = ("Policy Gate: Run query_regulations first.", -0.2)
        elif "multimodal" in task_id and not self.image_analyzed:
            gate = ("Visual Gate: Image analysis required.", -0.3)
        elif not self.audit_submitted:
            gate = ("Compliance Gate: Run submit_audit before decision.", -0.2)

        if gate is not None:
            message, reward = gate
        else:
            done = True
            is_violation = self.current_ad.get("ground_truth", False)
            is_correct = (act_type == "reject" and is_violation) or (
                act_type == "approve" and not is_violation
            )
            reward = 1.0 if is_correct else -1.0
            message = f"Decision: {act_type.upper()}. {'Correct!' if is_correct else 'Incorrect.'}"

    else:
        reward = -0.05
        message = f"Unknown action: {act_type}."

    self.total_reward += reward
    return self._get_obs(message, reward, done)
132
 
133
+ def _get_obs(self, message, reward=0.0, done=False) -> AdObservation:
134
  self._ensure_ad()
135
  return AdObservation(
136
  ad_id=str(self.current_ad.get("ad_id", "N/A")),
 
140
  targeting_data=dict(self.current_ad.get("targeting_data", {})),
141
  image_url=str(self.current_ad.get("image_url", "N/A")),
142
  status_message=str(message),
143
+ reward=reward,
144
+ done=done
145
  )
src/meta_ad_policy_sandbox.egg-info/PKG-INFO ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.4
2
+ Name: meta-ad-policy-sandbox
3
+ Version: 0.2.3
4
+ Summary: Meta Ad-Policy RL Sandbox
5
+ Requires-Dist: fastapi
6
+ Requires-Dist: uvicorn
7
+ Requires-Dist: pydantic
8
+ Requires-Dist: requests
9
+ Requires-Dist: openai
10
+ Requires-Dist: openenv-core>=0.2.0
src/meta_ad_policy_sandbox.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ README.md
2
+ pyproject.toml
3
+ src/__init__.py
4
+ src/environment.py
5
+ src/generator.py
6
+ src/models.py
7
+ src/meta_ad_policy_sandbox.egg-info/PKG-INFO
8
+ src/meta_ad_policy_sandbox.egg-info/SOURCES.txt
9
+ src/meta_ad_policy_sandbox.egg-info/dependency_links.txt
10
+ src/meta_ad_policy_sandbox.egg-info/entry_points.txt
11
+ src/meta_ad_policy_sandbox.egg-info/requires.txt
12
+ src/meta_ad_policy_sandbox.egg-info/top_level.txt
src/meta_ad_policy_sandbox.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
src/meta_ad_policy_sandbox.egg-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [console_scripts]
2
+ server = server.app:main
src/meta_ad_policy_sandbox.egg-info/requires.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ pydantic
4
+ requests
5
+ openai
6
+ openenv-core>=0.2.0
src/meta_ad_policy_sandbox.egg-info/top_level.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ __init__
2
+ environment
3
+ generator
4
+ models
uv.lock CHANGED
The diff for this file is too large to render. See raw diff