Kartik Goyal commited on
Commit
abcf7c3
·
1 Parent(s): cedcbb4

updated src/models.py · src/environment.py · src/generator.py · apps/crm_api.py · apps/regulatory_api.py · apps/audit_api.py

Browse files
apps/audit_api.py CHANGED
@@ -1,7 +1,8 @@
1
- #audit_api
2
  from fastapi import FastAPI
3
  from pydantic import BaseModel
4
  import uvicorn
 
 
5
 
6
  app = FastAPI(title="Compliance Audit API")
7
  logs = []
@@ -13,9 +14,11 @@ class AuditRecord(BaseModel):
13
 
14
  @app.post("/log")
15
  def log_audit(record: AuditRecord):
16
- logs.append(record.dict())
17
- return {"status": "success", "audit_id": f"AUD-{len(logs)}"}
18
-
 
 
19
  @app.get("/health")
20
  def health():
21
  return {"status": "ok", "service": "compliance-audit"}
 
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
  import uvicorn
4
+ import random
5
+ import uuid
6
 
7
  app = FastAPI(title="Compliance Audit API")
8
  logs = []
 
14
 
15
  @app.post("/log")
16
  def log_audit(record: AuditRecord):
17
+ if random.random() < 0.1:
18
+ return {"error": "service_unavailable", "retryable": True}
19
+ audit_id = f"AUD-{uuid.uuid4().hex[:8].upper()}"
20
+ logs.append({**record.dict(), "audit_id": audit_id})
21
+ return {"status": "logged", "audit_id": audit_id}
22
  @app.get("/health")
23
  def health():
24
  return {"status": "ok", "service": "compliance-audit"}
apps/crm_api.py CHANGED
@@ -1,5 +1,6 @@
1
  from fastapi import FastAPI
2
  import uvicorn
 
3
 
4
  app = FastAPI(title="Advertiser CRM API")
5
 
@@ -38,14 +39,22 @@ ADVERTISERS = {
38
 
39
  @app.get("/advertiser/{advertiser_id}")
40
  def get_advertiser(advertiser_id: str):
41
- if advertiser_id in ADVERTISERS:
42
- return ADVERTISERS[advertiser_id]
43
- return {
 
44
  "name": "Unknown Advertiser",
45
  "prior_violations": 0,
46
  "account_age_days": 7,
47
  "summary": "New unverified advertiser. No history. Treat with caution."
48
- }
 
 
 
 
 
 
 
49
 
50
  @app.get("/health")
51
  def health():
 
1
  from fastapi import FastAPI
2
  import uvicorn
3
+ import random
4
 
5
  app = FastAPI(title="Advertiser CRM API")
6
 
 
39
 
40
  @app.get("/advertiser/{advertiser_id}")
41
  def get_advertiser(advertiser_id: str):
42
+ if random.random() < 0.1:
43
+ return {"error": "service_unavailable", "retryable": True}
44
+
45
+ data = ADVERTISERS.get(advertiser_id, {
46
  "name": "Unknown Advertiser",
47
  "prior_violations": 0,
48
  "account_age_days": 7,
49
  "summary": "New unverified advertiser. No history. Treat with caution."
50
+ })
51
+
52
+ risk_score = min(1.0,
53
+ 0.15 * data["prior_violations"] +
54
+ 0.5 * (1 / (1 + data["account_age_days"] / 30))
55
+ )
56
+
57
+ return {**data, "risk_score": round(risk_score, 2)}
58
 
59
  @app.get("/health")
60
  def health():
apps/regulatory_api.py CHANGED
@@ -1,45 +1,51 @@
1
- # regulatory_api.py
2
  from fastapi import FastAPI
3
  import uvicorn
 
4
 
5
  app = FastAPI(title="Regulatory DB API")
6
 
7
  REGULATIONS = {
8
  "healthcare": {
9
- "policy_summary": (
10
- "Health claims require FDA approval. "
11
- "Prohibited: unverified cure claims, 'guaranteed results', "
12
- "prescription drug sales without authorization."
13
- ),
14
- "risk_level": "high"
15
  },
16
  "financial": {
17
- "policy_summary": (
18
- "Financial ads require SEC registration. "
19
- "Prohibited: guaranteed returns, predatory APR above 36%, "
20
- "high-pressure investment tactics."
21
- ),
22
- "risk_level": "high"
23
  },
24
  "targeting": {
25
- "policy_summary": (
26
- "Age-restricted products cannot target minors. "
27
- "Financial and healthcare products require age_min >= 18."
28
- ),
29
- "risk_level": "high"
 
 
 
 
 
 
30
  },
31
  "general": {
32
- "policy_summary": "Standard advertising standards apply. No deceptive claims.",
33
- "risk_level": "low"
34
  },
35
  "none": {
36
- "policy_summary": "Standard advertising standards apply. No deceptive claims.",
37
- "risk_level": "low"
38
  }
39
  }
40
 
41
  @app.get("/regulations/{category}")
42
  def get_regulations(category: str):
 
 
43
  return REGULATIONS.get(category.lower(), REGULATIONS["general"])
44
 
45
  @app.get("/health")
 
 
1
  from fastapi import FastAPI
2
  import uvicorn
3
+ import random
4
 
5
  app = FastAPI(title="Regulatory DB API")
6
 
7
  REGULATIONS = {
8
  "healthcare": {
9
+ "policy_hint": "Health claims require FDA verification. No unverified cures or prescription bypasses.",
10
+ "violations": [
11
+ {"type": "unverified_cure_claim", "confidence": 0.9},
12
+ {"type": "prescription_bypass", "confidence": 0.4}
13
+ ]
 
14
  },
15
  "financial": {
16
+ "policy_hint": "SEC registration required. No guaranteed returns or predatory lending.",
17
+ "violations": [
18
+ {"type": "guaranteed_returns", "confidence": 0.85},
19
+ {"type": "predatory_lending", "confidence": 0.5}
20
+ ]
 
21
  },
22
  "targeting": {
23
+ "policy_hint": "Age-restricted products cannot target minors. age_min must be >= 18.",
24
+ "violations": [
25
+ {"type": "minor_targeting", "confidence": 0.95}
26
+ ]
27
+ },
28
+ "ambiguous": {
29
+ "policy_hint": "Policy applicability is unclear. Gather additional signals before deciding.",
30
+ "violations": [
31
+ {"type": "possible_misleading_claim", "confidence": 0.45},
32
+ {"type": "unverified_endorsement", "confidence": 0.5}
33
+ ]
34
  },
35
  "general": {
36
+ "policy_hint": "Standard advertising standards apply. No deceptive claims.",
37
+ "violations": []
38
  },
39
  "none": {
40
+ "policy_hint": "Standard advertising standards apply. No deceptive claims.",
41
+ "violations": []
42
  }
43
  }
44
 
45
  @app.get("/regulations/{category}")
46
  def get_regulations(category: str):
47
+ if random.random() < 0.1:
48
+ return {"error": "service_unavailable", "retryable": True}
49
  return REGULATIONS.get(category.lower(), REGULATIONS["general"])
50
 
51
  @app.get("/health")
src/environment.py CHANGED
@@ -7,16 +7,44 @@ REGULATORY_API = "http://localhost:8001"
7
  CRM_API = "http://localhost:8002"
8
  AUDIT_API = "http://localhost:8003"
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  class AdPolicyEnvironment(Environment):
11
  def __init__(self):
12
  super().__init__()
13
  self.generator = AdGenerator()
14
  self.current_ad = None
15
- self.image_analyzed = False
16
- self.regulations_queried = False
17
- self.audit_submitted = False
18
  self.step_count = 0
19
  self.total_reward = 0.0
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  def _ensure_ad(self, task_id=None):
22
  if self.current_ad is None:
@@ -33,106 +61,226 @@ class AdPolicyEnvironment(Environment):
33
 
34
  def reset(self, task_id: str = None) -> AdObservation:
35
  self.current_ad = self.generator.generate_random_ad(task_id)
36
- self.current_ad["task_id"] = task_id or "task_1_healthcare"
37
- self.image_analyzed = False
38
- self.regulations_queried = False
39
- self.audit_submitted = False
40
  self.step_count = 0
41
  self.total_reward = 0.0
 
 
 
 
 
 
 
 
 
 
 
 
42
  return self._get_obs(f"Ad loaded for {self.current_ad['task_id']}. Begin with query_regulations.")
43
 
44
  def step(self, action: AdAction) -> AdObservation:
45
  self._ensure_ad()
46
- self.step_count += 1
47
- reward = 0.0
48
- done = False
49
 
50
  if not action or not hasattr(action, 'action_type'):
51
- return self._get_obs("Invalid action format.", -0.1, False)
52
 
53
  act_type = str(action.action_type).lower()
54
- task_id = self.current_ad.get("task_id", "")
55
-
56
- # ── TOOL ACTIONS ──────────────────────────────────────────────────────
57
- if act_type == "query_regulations":
58
- self.regulations_queried = True
59
- reward = -0.05
60
- category = self.current_ad.get("category", "general")
61
- try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  resp = requests.get(f"{REGULATORY_API}/regulations/{category}", timeout=2)
63
- message = resp.json().get("policy_summary", "Standard policy applies.")
64
- except Exception:
65
- message = "API Error: Default standard policy applies."
66
-
67
- elif act_type == "analyze_image":
68
- self.image_analyzed = True
69
- reward = -0.05
70
- message = self.current_ad.get("vlm_desc", "No visual anomalies detected.")
71
-
72
- elif act_type == "check_advertiser_history":
73
- reward = -0.05
74
- advertiser_id = self.current_ad.get("advertiser_id", "adv_003")
75
- try:
76
  resp = requests.get(f"{CRM_API}/advertiser/{advertiser_id}", timeout=2)
77
- message = f"CRM Summary: {resp.json().get('summary', 'No data')}"
78
- except Exception:
79
- message = "CRM offline. Cannot verify history."
80
-
81
- elif act_type == "request_landing_page":
82
- reward = -0.05
83
- domain_age = self.current_ad.get("domain_age_days", 365)
84
- risk_keywords = self.current_ad.get("landing_risk_keywords", [])
85
- message = f"Domain age: {domain_age} days. Flagged terms: {risk_keywords or 'none'}."
86
-
87
- elif act_type == "request_id_verification":
88
- reward = -0.05
89
- age_min = self.current_ad.get("targeting_data", {}).get("age_min", 18)
90
- message = f"Target age {age_min}+." if age_min >= 18 else f"ALERT: Minor targeting detected (Age {age_min}+)."
91
-
92
- elif act_type == "submit_audit":
93
- self.audit_submitted = True
94
- reward = 0.0
95
- try:
96
  payload = {
97
  "ad_id": self.current_ad.get("ad_id", "test"),
98
  "action_taken": "pending",
99
  "reasoning": "audit requested"
100
  }
101
  resp = requests.post(f"{AUDIT_API}/log", json=payload, timeout=2)
102
- message = f"Audit logged: {resp.json().get('audit_id', 'Local-1')}"
103
- except Exception:
104
- message = "Audit recorded locally."
105
-
106
- # ── TERMINAL ACTIONS (Phase Gates) ────────────────────────────────────
107
- elif act_type in ["approve", "reject"]:
108
- # Gate 1: Must query rules
109
- if not self.regulations_queried:
110
- self.total_reward += -0.2
111
- return self._get_obs("Policy Gate: Run query_regulations first.", -0.2, False)
112
-
113
- # Gate 2: Multimodal tasks require image analysis
114
- if "multimodal" in task_id and not self.image_analyzed:
115
- self.total_reward += -0.3
116
- return self._get_obs("Visual Gate: Image analysis required.", -0.3, False)
117
-
118
- # Gate 3: Must audit
119
- if not self.audit_submitted:
120
- self.total_reward += -0.2
121
- return self._get_obs("Compliance Gate: Run submit_audit before decision.", -0.2, False)
122
 
123
- done = True
124
- is_violation = self.current_ad.get("ground_truth", False)
125
- is_correct = (act_type == "reject" and is_violation) or (act_type == "approve" and not is_violation)
126
- reward = 1.0 if is_correct else -1.0
127
- message = f"Decision: {act_type.upper()}. {'Correct!' if is_correct else 'Incorrect.'}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  else:
130
- reward = -0.05
131
- message = f"Unknown action: {act_type}."
132
 
133
- self.total_reward += reward
134
- return self._get_obs(message, reward, done)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
 
 
136
  def _get_obs(self, message, reward=0.0, done=False) -> AdObservation:
137
  self._ensure_ad()
138
  return AdObservation(
@@ -144,5 +292,10 @@ class AdPolicyEnvironment(Environment):
144
  image_url=str(self.current_ad.get("image_url", "N/A")),
145
  status_message=str(message),
146
  reward=reward,
147
- done=done
 
 
 
 
 
148
  )
 
7
  CRM_API = "http://localhost:8002"
8
  AUDIT_API = "http://localhost:8003"
9
 
10
+ VALID_ACTIONS = {
11
+ "query_regulations",
12
+ "analyze_image",
13
+ "check_advertiser_history",
14
+ "request_landing_page",
15
+ "request_id_verification",
16
+ "submit_audit",
17
+ "approve",
18
+ "reject"
19
+ }
20
+
21
+ TERMINAL_ACTIONS = {"approve", "reject"}
22
+
23
+ REQUIRED_BEFORE_TERMINAL = {
24
+ "query_regulations",
25
+ "submit_audit"
26
+ }
27
+
28
+ MAX_STEPS = 8
29
  class AdPolicyEnvironment(Environment):
30
  def __init__(self):
31
  super().__init__()
32
  self.generator = AdGenerator()
33
  self.current_ad = None
 
 
 
34
  self.step_count = 0
35
  self.total_reward = 0.0
36
+ self.actions_taken = set()
37
+ self.api_failed = False
38
+ self.api_recovered = False
39
+ self.last_failed_action = None
40
+ self.last_error = None
41
+ self.trace = []
42
+ self.signals = {
43
+ "risk_score": None,
44
+ "policy_confidence": None,
45
+ "image_flag": None,
46
+ "landing_flag": None
47
+ }
48
 
49
  def _ensure_ad(self, task_id=None):
50
  if self.current_ad is None:
 
61
 
62
  def reset(self, task_id: str = None) -> AdObservation:
63
  self.current_ad = self.generator.generate_random_ad(task_id)
64
+ self.current_ad["task_id"] = task_id or "task_1_healthcare"
 
 
 
65
  self.step_count = 0
66
  self.total_reward = 0.0
67
+ self.actions_taken = set()
68
+ self.api_failed = False
69
+ self.api_recovered = False
70
+ self.last_failed_action = None
71
+ self.last_error = None
72
+ self.trace = []
73
+ self.signals = {
74
+ "risk_score": None,
75
+ "policy_confidence": None,
76
+ "image_flag": None,
77
+ "landing_flag": None
78
+ }
79
  return self._get_obs(f"Ad loaded for {self.current_ad['task_id']}. Begin with query_regulations.")
80
 
81
  def step(self, action: AdAction) -> AdObservation:
82
  self._ensure_ad()
 
 
 
83
 
84
  if not action or not hasattr(action, 'action_type'):
85
+ return self._get_obs("Invalid action format.", -0.5, True)
86
 
87
  act_type = str(action.action_type).lower()
88
+
89
+ # 1. Validate action
90
+ if act_type not in VALID_ACTIONS:
91
+ return self._get_obs(f"Invalid action: {act_type}.", -0.5, True)
92
+
93
+ # 2. Start constraint — state based
94
+ if "query_regulations" not in self.actions_taken:
95
+ if act_type != "query_regulations":
96
+ return self._get_obs("Must call query_regulations first.", -0.2, False)
97
+
98
+ self.step_count += 1
99
+ self.actions_taken.add(act_type)
100
+
101
+ # 3. Execute action
102
+ response = self._execute_action(act_type)
103
+
104
+ # 4. Update state
105
+ if "error" in response:
106
+ self.api_failed = True
107
+ self.last_failed_action = act_type
108
+ self.last_error = response["error"]
109
+ else:
110
+ if act_type == self.last_failed_action:
111
+ self.api_recovered = True
112
+ self.last_error = None
113
+ self._extract_signals(act_type, response)
114
+
115
+ # 5. Append trace
116
+ self.trace.append({
117
+ "step": self.step_count,
118
+ "action": act_type,
119
+ "result": self._summarize_response(act_type, response)
120
+ })
121
+
122
+ # 6. Compute reward
123
+ reward = -0.05 # step penalty
124
+
125
+ # 7. Handle terminal
126
+ done = False
127
+ if act_type in TERMINAL_ACTIONS:
128
+ reward += self._terminal_reward(act_type)
129
+ done = True
130
+ elif self.step_count >= MAX_STEPS:
131
+ reward -= 0.5
132
+ done = True
133
+
134
+ self.total_reward += reward
135
+ summary = self._summarize_response(act_type, response)["summary"]
136
+ return self._get_obs(summary, reward, done)
137
+
138
+ def _execute_action(self, act_type: str) -> dict:
139
+ task_id = self.current_ad.get("task_id", "")
140
+
141
+ # Deterministic failure for task_10_failure on step 1
142
+ if task_id == "task_10_failure" and self.step_count == 1:
143
+ return {"error": "service_unavailable", "retryable": True}
144
+
145
+ try:
146
+ if act_type == "query_regulations":
147
+ category = self.current_ad.get("category", "general")
148
  resp = requests.get(f"{REGULATORY_API}/regulations/{category}", timeout=2)
149
+ return resp.json()
150
+
151
+ elif act_type == "analyze_image":
152
+ vlm_desc = self.current_ad.get("vlm_desc", "")
153
+ violation = any(kw in vlm_desc.lower() for kw in [
154
+ "violation", "banned", "prescription", "fake", "flagged",
155
+ "semaglutide", "adderall", "no rx", "no prescription"
156
+ ])
157
+ return {"violation_detected": violation, "description": vlm_desc}
158
+
159
+ elif act_type == "check_advertiser_history":
160
+ advertiser_id = self.current_ad.get("advertiser_id", "adv_003")
 
161
  resp = requests.get(f"{CRM_API}/advertiser/{advertiser_id}", timeout=2)
162
+ return resp.json()
163
+
164
+ elif act_type == "request_landing_page":
165
+ domain_age = self.current_ad.get("domain_age_days", 365)
166
+ risk_keywords = self.current_ad.get("landing_risk_keywords", [])
167
+ suspicious = domain_age < 30 or len(risk_keywords) > 0
168
+ return {"suspicious": suspicious, "domain_age": domain_age, "risk_keywords": risk_keywords}
169
+
170
+ elif act_type == "request_id_verification":
171
+ age_min = self.current_ad.get("targeting_data", {}).get("age_min", 18)
172
+ return {"age_min": age_min, "minor_targeted": age_min < 18}
173
+
174
+ elif act_type == "submit_audit":
 
 
 
 
 
 
175
  payload = {
176
  "ad_id": self.current_ad.get("ad_id", "test"),
177
  "action_taken": "pending",
178
  "reasoning": "audit requested"
179
  }
180
  resp = requests.post(f"{AUDIT_API}/log", json=payload, timeout=2)
181
+ return resp.json()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
+ else:
184
+ return {"status": "ok"}
185
+
186
+ except Exception as e:
187
+ return {"error": f"service_unavailable", "retryable": True}
188
+
189
+ def _extract_signals(self, action: str, response: dict):
190
+ if action == "check_advertiser_history":
191
+ self.signals["risk_score"] = response.get("risk_score")
192
+
193
+ elif action == "query_regulations":
194
+ violations = response.get("violations", [])
195
+ confs = [v["confidence"] for v in violations]
196
+ self.signals["policy_confidence"] = max(confs, default=0.0)
197
+
198
+ elif action == "analyze_image":
199
+ self.signals["image_flag"] = response.get("violation_detected", False)
200
+
201
+ elif action == "request_landing_page":
202
+ self.signals["landing_flag"] = response.get("suspicious", False)
203
+
204
+ def _summarize_response(self, action: str, response: dict) -> dict:
205
+ if "error" in response:
206
+ return {"summary": "API failure — retryable", "flag": False}
207
+
208
+ if action == "check_advertiser_history":
209
+ rs = response.get("risk_score", 0.0)
210
+ return {"summary": f"risk_score={rs:.2f}", "flag": rs > 0.7}
211
+
212
+ if action == "query_regulations":
213
+ violations = response.get("violations", [])
214
+ conf = max((v["confidence"] for v in violations), default=0.0)
215
+ return {"summary": f"policy_confidence={conf:.2f}", "flag": conf > 0.7}
216
 
217
+ if action == "analyze_image":
218
+ flagged = response.get("violation_detected", False)
219
+ return {
220
+ "summary": "image_violation_detected" if flagged else "image_clean",
221
+ "flag": flagged
222
+ }
223
+
224
+ if action == "request_landing_page":
225
+ sus = response.get("suspicious", False)
226
+ return {"summary": "landing_suspicious" if sus else "landing_clean", "flag": sus}
227
+
228
+ if action == "request_id_verification":
229
+ minor = response.get("minor_targeted", False)
230
+ age = response.get("age_min", 18)
231
+ return {
232
+ "summary": f"ALERT: minor targeting age={age}" if minor else f"age_min={age} OK",
233
+ "flag": minor
234
+ }
235
+
236
+ if action == "submit_audit":
237
+ audit_id = response.get("audit_id", "LOCAL")
238
+ return {"summary": f"audit_logged id={audit_id}", "flag": False}
239
+
240
+ return {"summary": "ok", "flag": False}
241
+
242
+ def _terminal_reward(self, act_type: str) -> float:
243
+ reward = 0.0
244
+ is_violation = self.current_ad.get("ground_truth", False)
245
+ is_correct = (act_type == "reject" and is_violation) or \
246
+ (act_type == "approve" and not is_violation)
247
+
248
+ # Dominant signal
249
+ reward += 1.0 if is_correct else -1.0
250
+
251
+ # Sequence correctness
252
+ if REQUIRED_BEFORE_TERMINAL.issubset(self.actions_taken):
253
+ reward += 0.2
254
  else:
255
+ reward -= 0.2
 
256
 
257
+ # API failure handling
258
+ if self.api_failed and self.api_recovered:
259
+ reward += 0.3
260
+ elif self.api_failed and not self.api_recovered:
261
+ reward -= 0.3
262
+
263
+ # Risk-aware penalty
264
+ if act_type == "approve":
265
+ rs = self.signals["risk_score"]
266
+ pc = self.signals["policy_confidence"]
267
+ if rs is not None and pc is not None:
268
+ if rs > 0.7 and pc > 0.6:
269
+ reward -= 0.5
270
+
271
+ # Ambiguity enforcement
272
+ pc = self.signals["policy_confidence"]
273
+ if pc is not None and pc < 0.6:
274
+ if "check_advertiser_history" not in self.actions_taken \
275
+ and "request_landing_page" not in self.actions_taken:
276
+ reward -= 0.4
277
+
278
+ # Landing page bonus under ambiguity
279
+ if self.signals["landing_flag"] and pc is not None and pc < 0.6:
280
+ reward += 0.2
281
 
282
+ return reward
283
+
284
  def _get_obs(self, message, reward=0.0, done=False) -> AdObservation:
285
  self._ensure_ad()
286
  return AdObservation(
 
292
  image_url=str(self.current_ad.get("image_url", "N/A")),
293
  status_message=str(message),
294
  reward=reward,
295
+ done=done,
296
+ risk_score=self.signals["risk_score"],
297
+ policy_confidence=self.signals["policy_confidence"],
298
+ image_flag=self.signals["image_flag"],
299
+ landing_flag=self.signals["landing_flag"],
300
+ last_error=self.last_error
301
  )
src/generator.py CHANGED
@@ -69,7 +69,40 @@ class AdGenerator:
69
  "A bottle of standard FDA-approved daily vitamins.",
70
  False, 400, []),
71
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
 
 
 
 
 
 
73
  @staticmethod
74
  def generate_random_ad(task_id: str = None):
75
  # Task overrides
@@ -105,6 +138,31 @@ class AdGenerator:
105
  domain_age = 30
106
  risk_kw = []
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  else:
109
  scenario_type = random.choices(
110
  ["obvious", "multimodal", "subtle", "safe"],
@@ -112,23 +170,23 @@ class AdGenerator:
112
  k=1
113
  )[0]
114
  pool_map = {
115
- "obvious": AdGenerator.OBVIOUS_VIOLATIONS,
116
  "multimodal": AdGenerator.MULTIMODAL_TRAPS,
117
- "subtle": AdGenerator.SUBTLE_VIOLATIONS,
118
- "safe": AdGenerator.SAFE_ADS,
119
  }
120
  category, adv_id, headline, body, vlm_desc, is_violation, domain_age, risk_kw = random.choice(
121
  pool_map[scenario_type]
122
  )
123
 
124
  # Trust score — KEY FIX: multimodal traps get HIGH trust to defeat shortcuts
125
- if task_id == "task_3_multimodal":
126
- trust_score = round(random.uniform(0.82, 0.97), 2) # high trust = stealth trap
 
127
  elif is_violation:
128
  trust_score = round(random.uniform(0.10, 0.50), 2)
129
  else:
130
  trust_score = round(random.uniform(0.70, 0.99), 2)
131
-
132
  # task_4 targeting age — FIX: use "age_min" consistently
133
  age_min = 15 if task_id == "task_4_targeting" else random.randint(18, 55)
134
 
 
69
  "A bottle of standard FDA-approved daily vitamins.",
70
  False, 400, []),
71
  ]
72
+ CONFLICT_ADS = [
73
+ ("FINANCIAL", "adv_002", "Smart Investment Tools",
74
+ "Our platform helps you manage your portfolio with ease.",
75
+ "A clean dashboard showing portfolio analytics.",
76
+ True, 20, ["guaranteed-edge"]),
77
+ ]
78
+
79
+ AMBIGUOUS_ADS = [
80
+ ("ambiguous", "adv_004", "Feel Better Naturally",
81
+ "Our wellness program supports your health journey.",
82
+ "A person meditating in a park.",
83
+ True, 45, ["ships-direct"]),
84
+ ]
85
+
86
+ ADVERSARIAL_ADS = [
87
+ ("HEALTHCARE", "adv_004", "Natural Energy Boost",
88
+ "Feel amazing every day with our daily supplement.",
89
+ "Fine print reads: contains semaglutide analog, ships without Rx needed.",
90
+ True, 60, []),
91
+ ]
92
+
93
+ DEPENDENCY_TRAP_ADS = [
94
+ ("HEALTHCARE", "adv_003", "Wellness Support Pack",
95
+ "Certified vitamins to support your daily routine.",
96
+ "Unlabeled pill bottle with small text: contains unapproved compound.",
97
+ True, 400, []),
98
+ ]
99
 
100
+ FAILURE_RECOVERY_ADS = [
101
+ ("HEALTHCARE", "adv_001", "Miracle Weight Loss Patch",
102
+ "Lose 10kg in 2 weeks. Guaranteed results or money back.",
103
+ "A before/after image with extreme weight loss claims.",
104
+ True, 30, ["guaranteed", "no-prescription"]),
105
+ ]
106
  @staticmethod
107
  def generate_random_ad(task_id: str = None):
108
  # Task overrides
 
138
  domain_age = 30
139
  risk_kw = []
140
 
141
+ elif task_id == "task_6_conflict":
142
+ category, adv_id, headline, body, vlm_desc, is_violation, domain_age, risk_kw = random.choice(
143
+ AdGenerator.CONFLICT_ADS
144
+ )
145
+
146
+ elif task_id == "task_7_ambiguous":
147
+ category, adv_id, headline, body, vlm_desc, is_violation, domain_age, risk_kw = random.choice(
148
+ AdGenerator.AMBIGUOUS_ADS
149
+ )
150
+
151
+ elif task_id == "task_8_adversarial":
152
+ category, adv_id, headline, body, vlm_desc, is_violation, domain_age, risk_kw = random.choice(
153
+ AdGenerator.ADVERSARIAL_ADS
154
+ )
155
+
156
+ elif task_id == "task_9_dependency_trap":
157
+ category, adv_id, headline, body, vlm_desc, is_violation, domain_age, risk_kw = random.choice(
158
+ AdGenerator.DEPENDENCY_TRAP_ADS
159
+ )
160
+
161
+ elif task_id == "task_10_failure":
162
+ category, adv_id, headline, body, vlm_desc, is_violation, domain_age, risk_kw = random.choice(
163
+ AdGenerator.FAILURE_RECOVERY_ADS
164
+ )
165
+
166
  else:
167
  scenario_type = random.choices(
168
  ["obvious", "multimodal", "subtle", "safe"],
 
170
  k=1
171
  )[0]
172
  pool_map = {
173
+ "obvious": AdGenerator.OBVIOUS_VIOLATIONS,
174
  "multimodal": AdGenerator.MULTIMODAL_TRAPS,
175
+ "subtle": AdGenerator.SUBTLE_VIOLATIONS,
176
+ "safe": AdGenerator.SAFE_ADS,
177
  }
178
  category, adv_id, headline, body, vlm_desc, is_violation, domain_age, risk_kw = random.choice(
179
  pool_map[scenario_type]
180
  )
181
 
182
  # Trust score — KEY FIX: multimodal traps get HIGH trust to defeat shortcuts
183
+ # task_3 + task_6: high trust = stealth trap, forces CRM check
184
+ if task_id in ("task_3_multimodal", "task_6_conflict"):
185
+ trust_score = round(random.uniform(0.82, 0.97), 2)
186
  elif is_violation:
187
  trust_score = round(random.uniform(0.10, 0.50), 2)
188
  else:
189
  trust_score = round(random.uniform(0.70, 0.99), 2)
 
190
  # task_4 targeting age — FIX: use "age_min" consistently
191
  age_min = 15 if task_id == "task_4_targeting" else random.randint(18, 55)
192
 
src/models.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Literal, Optional, Dict, Any
2
  from openenv.core.env_server import Action, Observation, State
3
 
4
  class AdObservation(Observation):
@@ -9,15 +9,21 @@ class AdObservation(Observation):
9
  targeting_data: Dict[str, Any]
10
  image_url: str
11
  status_message: str
12
-
13
- # 🚨 NEW: OpenEnv requires these to be part of the Observation!
14
  reward: float = 0.0
15
  done: bool = False
16
 
 
 
 
 
 
 
 
17
  class AdAction(Action):
18
  action_type: Literal[
19
- "approve", "reject", "analyze_image",
20
- "request_landing_page", "request_id_verification"
 
21
  ]
22
  reasoning: str
23
  violation_category: Optional[Literal["HEALTHCARE", "FINANCIAL", "NONE"]] = None
 
1
+ from typing import Literal, Optional, Dict, Any, List
2
  from openenv.core.env_server import Action, Observation, State
3
 
4
  class AdObservation(Observation):
 
9
  targeting_data: Dict[str, Any]
10
  image_url: str
11
  status_message: str
 
 
12
  reward: float = 0.0
13
  done: bool = False
14
 
15
+ # signals exposed to agent
16
+ risk_score: Optional[float] = None
17
+ policy_confidence: Optional[float] = None
18
+ image_flag: Optional[bool] = None
19
+ landing_flag: Optional[bool] = None
20
+ last_error: Optional[str] = None
21
+
22
  class AdAction(Action):
23
  action_type: Literal[
24
+ "query_regulations", "analyze_image", "check_advertiser_history",
25
+ "request_landing_page", "request_id_verification",
26
+ "submit_audit", "approve", "reject"
27
  ]
28
  reasoning: str
29
  violation_category: Optional[Literal["HEALTHCARE", "FINANCIAL", "NONE"]] = None