Kartik Goyal commited on
Commit ·
abcf7c3
1
Parent(s): cedcbb4
updated src/models.py · src/environment.py · src/generator.py · apps/crm_api.py · apps/regulatory_api.py · apps/audit_api.py
Browse files- apps/audit_api.py +7 -4
- apps/crm_api.py +13 -4
- apps/regulatory_api.py +28 -22
- src/environment.py +234 -81
- src/generator.py +64 -6
- src/models.py +11 -5
apps/audit_api.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
| 1 |
-
#audit_api
|
| 2 |
from fastapi import FastAPI
|
| 3 |
from pydantic import BaseModel
|
| 4 |
import uvicorn
|
|
|
|
|
|
|
| 5 |
|
| 6 |
app = FastAPI(title="Compliance Audit API")
|
| 7 |
logs = []
|
|
@@ -13,9 +14,11 @@ class AuditRecord(BaseModel):
|
|
| 13 |
|
| 14 |
@app.post("/log")
|
| 15 |
def log_audit(record: AuditRecord):
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
| 19 |
@app.get("/health")
|
| 20 |
def health():
|
| 21 |
return {"status": "ok", "service": "compliance-audit"}
|
|
|
|
|
|
|
| 1 |
from fastapi import FastAPI
|
| 2 |
from pydantic import BaseModel
|
| 3 |
import uvicorn
|
| 4 |
+
import random
|
| 5 |
+
import uuid
|
| 6 |
|
| 7 |
app = FastAPI(title="Compliance Audit API")
|
| 8 |
logs = []
|
|
|
|
| 14 |
|
| 15 |
@app.post("/log")
|
| 16 |
def log_audit(record: AuditRecord):
|
| 17 |
+
if random.random() < 0.1:
|
| 18 |
+
return {"error": "service_unavailable", "retryable": True}
|
| 19 |
+
audit_id = f"AUD-{uuid.uuid4().hex[:8].upper()}"
|
| 20 |
+
logs.append({**record.dict(), "audit_id": audit_id})
|
| 21 |
+
return {"status": "logged", "audit_id": audit_id}
|
| 22 |
@app.get("/health")
|
| 23 |
def health():
|
| 24 |
return {"status": "ok", "service": "compliance-audit"}
|
apps/crm_api.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
from fastapi import FastAPI
|
| 2 |
import uvicorn
|
|
|
|
| 3 |
|
| 4 |
app = FastAPI(title="Advertiser CRM API")
|
| 5 |
|
|
@@ -38,14 +39,22 @@ ADVERTISERS = {
|
|
| 38 |
|
| 39 |
@app.get("/advertiser/{advertiser_id}")
|
| 40 |
def get_advertiser(advertiser_id: str):
|
| 41 |
-
if
|
| 42 |
-
return
|
| 43 |
-
|
|
|
|
| 44 |
"name": "Unknown Advertiser",
|
| 45 |
"prior_violations": 0,
|
| 46 |
"account_age_days": 7,
|
| 47 |
"summary": "New unverified advertiser. No history. Treat with caution."
|
| 48 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
@app.get("/health")
|
| 51 |
def health():
|
|
|
|
| 1 |
from fastapi import FastAPI
|
| 2 |
import uvicorn
|
| 3 |
+
import random
|
| 4 |
|
| 5 |
app = FastAPI(title="Advertiser CRM API")
|
| 6 |
|
|
|
|
| 39 |
|
| 40 |
@app.get("/advertiser/{advertiser_id}")
|
| 41 |
def get_advertiser(advertiser_id: str):
|
| 42 |
+
if random.random() < 0.1:
|
| 43 |
+
return {"error": "service_unavailable", "retryable": True}
|
| 44 |
+
|
| 45 |
+
data = ADVERTISERS.get(advertiser_id, {
|
| 46 |
"name": "Unknown Advertiser",
|
| 47 |
"prior_violations": 0,
|
| 48 |
"account_age_days": 7,
|
| 49 |
"summary": "New unverified advertiser. No history. Treat with caution."
|
| 50 |
+
})
|
| 51 |
+
|
| 52 |
+
risk_score = min(1.0,
|
| 53 |
+
0.15 * data["prior_violations"] +
|
| 54 |
+
0.5 * (1 / (1 + data["account_age_days"] / 30))
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
return {**data, "risk_score": round(risk_score, 2)}
|
| 58 |
|
| 59 |
@app.get("/health")
|
| 60 |
def health():
|
apps/regulatory_api.py
CHANGED
|
@@ -1,45 +1,51 @@
|
|
| 1 |
-
# regulatory_api.py
|
| 2 |
from fastapi import FastAPI
|
| 3 |
import uvicorn
|
|
|
|
| 4 |
|
| 5 |
app = FastAPI(title="Regulatory DB API")
|
| 6 |
|
| 7 |
REGULATIONS = {
|
| 8 |
"healthcare": {
|
| 9 |
-
"
|
| 10 |
-
|
| 11 |
-
"
|
| 12 |
-
"
|
| 13 |
-
|
| 14 |
-
"risk_level": "high"
|
| 15 |
},
|
| 16 |
"financial": {
|
| 17 |
-
"
|
| 18 |
-
|
| 19 |
-
"
|
| 20 |
-
"
|
| 21 |
-
|
| 22 |
-
"risk_level": "high"
|
| 23 |
},
|
| 24 |
"targeting": {
|
| 25 |
-
"
|
| 26 |
-
|
| 27 |
-
"
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
},
|
| 31 |
"general": {
|
| 32 |
-
"
|
| 33 |
-
"
|
| 34 |
},
|
| 35 |
"none": {
|
| 36 |
-
"
|
| 37 |
-
"
|
| 38 |
}
|
| 39 |
}
|
| 40 |
|
| 41 |
@app.get("/regulations/{category}")
|
| 42 |
def get_regulations(category: str):
|
|
|
|
|
|
|
| 43 |
return REGULATIONS.get(category.lower(), REGULATIONS["general"])
|
| 44 |
|
| 45 |
@app.get("/health")
|
|
|
|
|
|
|
| 1 |
from fastapi import FastAPI
|
| 2 |
import uvicorn
|
| 3 |
+
import random
|
| 4 |
|
| 5 |
app = FastAPI(title="Regulatory DB API")
|
| 6 |
|
| 7 |
REGULATIONS = {
|
| 8 |
"healthcare": {
|
| 9 |
+
"policy_hint": "Health claims require FDA verification. No unverified cures or prescription bypasses.",
|
| 10 |
+
"violations": [
|
| 11 |
+
{"type": "unverified_cure_claim", "confidence": 0.9},
|
| 12 |
+
{"type": "prescription_bypass", "confidence": 0.4}
|
| 13 |
+
]
|
|
|
|
| 14 |
},
|
| 15 |
"financial": {
|
| 16 |
+
"policy_hint": "SEC registration required. No guaranteed returns or predatory lending.",
|
| 17 |
+
"violations": [
|
| 18 |
+
{"type": "guaranteed_returns", "confidence": 0.85},
|
| 19 |
+
{"type": "predatory_lending", "confidence": 0.5}
|
| 20 |
+
]
|
|
|
|
| 21 |
},
|
| 22 |
"targeting": {
|
| 23 |
+
"policy_hint": "Age-restricted products cannot target minors. age_min must be >= 18.",
|
| 24 |
+
"violations": [
|
| 25 |
+
{"type": "minor_targeting", "confidence": 0.95}
|
| 26 |
+
]
|
| 27 |
+
},
|
| 28 |
+
"ambiguous": {
|
| 29 |
+
"policy_hint": "Policy applicability is unclear. Gather additional signals before deciding.",
|
| 30 |
+
"violations": [
|
| 31 |
+
{"type": "possible_misleading_claim", "confidence": 0.45},
|
| 32 |
+
{"type": "unverified_endorsement", "confidence": 0.5}
|
| 33 |
+
]
|
| 34 |
},
|
| 35 |
"general": {
|
| 36 |
+
"policy_hint": "Standard advertising standards apply. No deceptive claims.",
|
| 37 |
+
"violations": []
|
| 38 |
},
|
| 39 |
"none": {
|
| 40 |
+
"policy_hint": "Standard advertising standards apply. No deceptive claims.",
|
| 41 |
+
"violations": []
|
| 42 |
}
|
| 43 |
}
|
| 44 |
|
| 45 |
@app.get("/regulations/{category}")
|
| 46 |
def get_regulations(category: str):
|
| 47 |
+
if random.random() < 0.1:
|
| 48 |
+
return {"error": "service_unavailable", "retryable": True}
|
| 49 |
return REGULATIONS.get(category.lower(), REGULATIONS["general"])
|
| 50 |
|
| 51 |
@app.get("/health")
|
src/environment.py
CHANGED
|
@@ -7,16 +7,44 @@ REGULATORY_API = "http://localhost:8001"
|
|
| 7 |
CRM_API = "http://localhost:8002"
|
| 8 |
AUDIT_API = "http://localhost:8003"
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
class AdPolicyEnvironment(Environment):
|
| 11 |
def __init__(self):
|
| 12 |
super().__init__()
|
| 13 |
self.generator = AdGenerator()
|
| 14 |
self.current_ad = None
|
| 15 |
-
self.image_analyzed = False
|
| 16 |
-
self.regulations_queried = False
|
| 17 |
-
self.audit_submitted = False
|
| 18 |
self.step_count = 0
|
| 19 |
self.total_reward = 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
def _ensure_ad(self, task_id=None):
|
| 22 |
if self.current_ad is None:
|
|
@@ -33,106 +61,226 @@ class AdPolicyEnvironment(Environment):
|
|
| 33 |
|
| 34 |
def reset(self, task_id: str = None) -> AdObservation:
|
| 35 |
self.current_ad = self.generator.generate_random_ad(task_id)
|
| 36 |
-
self.current_ad["task_id"] = task_id or "task_1_healthcare"
|
| 37 |
-
self.image_analyzed = False
|
| 38 |
-
self.regulations_queried = False
|
| 39 |
-
self.audit_submitted = False
|
| 40 |
self.step_count = 0
|
| 41 |
self.total_reward = 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
return self._get_obs(f"Ad loaded for {self.current_ad['task_id']}. Begin with query_regulations.")
|
| 43 |
|
| 44 |
def step(self, action: AdAction) -> AdObservation:
|
| 45 |
self._ensure_ad()
|
| 46 |
-
self.step_count += 1
|
| 47 |
-
reward = 0.0
|
| 48 |
-
done = False
|
| 49 |
|
| 50 |
if not action or not hasattr(action, 'action_type'):
|
| 51 |
-
return self._get_obs("Invalid action format.", -0.
|
| 52 |
|
| 53 |
act_type = str(action.action_type).lower()
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
resp = requests.get(f"{REGULATORY_API}/regulations/{category}", timeout=2)
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
try:
|
| 76 |
resp = requests.get(f"{CRM_API}/advertiser/{advertiser_id}", timeout=2)
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
message = f"Target age {age_min}+." if age_min >= 18 else f"ALERT: Minor targeting detected (Age {age_min}+)."
|
| 91 |
-
|
| 92 |
-
elif act_type == "submit_audit":
|
| 93 |
-
self.audit_submitted = True
|
| 94 |
-
reward = 0.0
|
| 95 |
-
try:
|
| 96 |
payload = {
|
| 97 |
"ad_id": self.current_ad.get("ad_id", "test"),
|
| 98 |
"action_taken": "pending",
|
| 99 |
"reasoning": "audit requested"
|
| 100 |
}
|
| 101 |
resp = requests.post(f"{AUDIT_API}/log", json=payload, timeout=2)
|
| 102 |
-
|
| 103 |
-
except Exception:
|
| 104 |
-
message = "Audit recorded locally."
|
| 105 |
-
|
| 106 |
-
# ── TERMINAL ACTIONS (Phase Gates) ────────────────────────────────────
|
| 107 |
-
elif act_type in ["approve", "reject"]:
|
| 108 |
-
# Gate 1: Must query rules
|
| 109 |
-
if not self.regulations_queried:
|
| 110 |
-
self.total_reward += -0.2
|
| 111 |
-
return self._get_obs("Policy Gate: Run query_regulations first.", -0.2, False)
|
| 112 |
-
|
| 113 |
-
# Gate 2: Multimodal tasks require image analysis
|
| 114 |
-
if "multimodal" in task_id and not self.image_analyzed:
|
| 115 |
-
self.total_reward += -0.3
|
| 116 |
-
return self._get_obs("Visual Gate: Image analysis required.", -0.3, False)
|
| 117 |
-
|
| 118 |
-
# Gate 3: Must audit
|
| 119 |
-
if not self.audit_submitted:
|
| 120 |
-
self.total_reward += -0.2
|
| 121 |
-
return self._get_obs("Compliance Gate: Run submit_audit before decision.", -0.2, False)
|
| 122 |
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
else:
|
| 130 |
-
reward =
|
| 131 |
-
message = f"Unknown action: {act_type}."
|
| 132 |
|
| 133 |
-
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
|
|
|
|
|
|
|
| 136 |
def _get_obs(self, message, reward=0.0, done=False) -> AdObservation:
|
| 137 |
self._ensure_ad()
|
| 138 |
return AdObservation(
|
|
@@ -144,5 +292,10 @@ class AdPolicyEnvironment(Environment):
|
|
| 144 |
image_url=str(self.current_ad.get("image_url", "N/A")),
|
| 145 |
status_message=str(message),
|
| 146 |
reward=reward,
|
| 147 |
-
done=done
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
)
|
|
|
|
| 7 |
CRM_API = "http://localhost:8002"
|
| 8 |
AUDIT_API = "http://localhost:8003"
|
| 9 |
|
| 10 |
+
VALID_ACTIONS = {
|
| 11 |
+
"query_regulations",
|
| 12 |
+
"analyze_image",
|
| 13 |
+
"check_advertiser_history",
|
| 14 |
+
"request_landing_page",
|
| 15 |
+
"request_id_verification",
|
| 16 |
+
"submit_audit",
|
| 17 |
+
"approve",
|
| 18 |
+
"reject"
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
TERMINAL_ACTIONS = {"approve", "reject"}
|
| 22 |
+
|
| 23 |
+
REQUIRED_BEFORE_TERMINAL = {
|
| 24 |
+
"query_regulations",
|
| 25 |
+
"submit_audit"
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
MAX_STEPS = 8
|
| 29 |
class AdPolicyEnvironment(Environment):
|
| 30 |
def __init__(self):
|
| 31 |
super().__init__()
|
| 32 |
self.generator = AdGenerator()
|
| 33 |
self.current_ad = None
|
|
|
|
|
|
|
|
|
|
| 34 |
self.step_count = 0
|
| 35 |
self.total_reward = 0.0
|
| 36 |
+
self.actions_taken = set()
|
| 37 |
+
self.api_failed = False
|
| 38 |
+
self.api_recovered = False
|
| 39 |
+
self.last_failed_action = None
|
| 40 |
+
self.last_error = None
|
| 41 |
+
self.trace = []
|
| 42 |
+
self.signals = {
|
| 43 |
+
"risk_score": None,
|
| 44 |
+
"policy_confidence": None,
|
| 45 |
+
"image_flag": None,
|
| 46 |
+
"landing_flag": None
|
| 47 |
+
}
|
| 48 |
|
| 49 |
def _ensure_ad(self, task_id=None):
|
| 50 |
if self.current_ad is None:
|
|
|
|
| 61 |
|
| 62 |
def reset(self, task_id: str = None) -> AdObservation:
|
| 63 |
self.current_ad = self.generator.generate_random_ad(task_id)
|
| 64 |
+
self.current_ad["task_id"] = task_id or "task_1_healthcare"
|
|
|
|
|
|
|
|
|
|
| 65 |
self.step_count = 0
|
| 66 |
self.total_reward = 0.0
|
| 67 |
+
self.actions_taken = set()
|
| 68 |
+
self.api_failed = False
|
| 69 |
+
self.api_recovered = False
|
| 70 |
+
self.last_failed_action = None
|
| 71 |
+
self.last_error = None
|
| 72 |
+
self.trace = []
|
| 73 |
+
self.signals = {
|
| 74 |
+
"risk_score": None,
|
| 75 |
+
"policy_confidence": None,
|
| 76 |
+
"image_flag": None,
|
| 77 |
+
"landing_flag": None
|
| 78 |
+
}
|
| 79 |
return self._get_obs(f"Ad loaded for {self.current_ad['task_id']}. Begin with query_regulations.")
|
| 80 |
|
| 81 |
def step(self, action: AdAction) -> AdObservation:
|
| 82 |
self._ensure_ad()
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
if not action or not hasattr(action, 'action_type'):
|
| 85 |
+
return self._get_obs("Invalid action format.", -0.5, True)
|
| 86 |
|
| 87 |
act_type = str(action.action_type).lower()
|
| 88 |
+
|
| 89 |
+
# 1. Validate action
|
| 90 |
+
if act_type not in VALID_ACTIONS:
|
| 91 |
+
return self._get_obs(f"Invalid action: {act_type}.", -0.5, True)
|
| 92 |
+
|
| 93 |
+
# 2. Start constraint — state based
|
| 94 |
+
if "query_regulations" not in self.actions_taken:
|
| 95 |
+
if act_type != "query_regulations":
|
| 96 |
+
return self._get_obs("Must call query_regulations first.", -0.2, False)
|
| 97 |
+
|
| 98 |
+
self.step_count += 1
|
| 99 |
+
self.actions_taken.add(act_type)
|
| 100 |
+
|
| 101 |
+
# 3. Execute action
|
| 102 |
+
response = self._execute_action(act_type)
|
| 103 |
+
|
| 104 |
+
# 4. Update state
|
| 105 |
+
if "error" in response:
|
| 106 |
+
self.api_failed = True
|
| 107 |
+
self.last_failed_action = act_type
|
| 108 |
+
self.last_error = response["error"]
|
| 109 |
+
else:
|
| 110 |
+
if act_type == self.last_failed_action:
|
| 111 |
+
self.api_recovered = True
|
| 112 |
+
self.last_error = None
|
| 113 |
+
self._extract_signals(act_type, response)
|
| 114 |
+
|
| 115 |
+
# 5. Append trace
|
| 116 |
+
self.trace.append({
|
| 117 |
+
"step": self.step_count,
|
| 118 |
+
"action": act_type,
|
| 119 |
+
"result": self._summarize_response(act_type, response)
|
| 120 |
+
})
|
| 121 |
+
|
| 122 |
+
# 6. Compute reward
|
| 123 |
+
reward = -0.05 # step penalty
|
| 124 |
+
|
| 125 |
+
# 7. Handle terminal
|
| 126 |
+
done = False
|
| 127 |
+
if act_type in TERMINAL_ACTIONS:
|
| 128 |
+
reward += self._terminal_reward(act_type)
|
| 129 |
+
done = True
|
| 130 |
+
elif self.step_count >= MAX_STEPS:
|
| 131 |
+
reward -= 0.5
|
| 132 |
+
done = True
|
| 133 |
+
|
| 134 |
+
self.total_reward += reward
|
| 135 |
+
summary = self._summarize_response(act_type, response)["summary"]
|
| 136 |
+
return self._get_obs(summary, reward, done)
|
| 137 |
+
|
| 138 |
+
def _execute_action(self, act_type: str) -> dict:
|
| 139 |
+
task_id = self.current_ad.get("task_id", "")
|
| 140 |
+
|
| 141 |
+
# Deterministic failure for task_10_failure on step 1
|
| 142 |
+
if task_id == "task_10_failure" and self.step_count == 1:
|
| 143 |
+
return {"error": "service_unavailable", "retryable": True}
|
| 144 |
+
|
| 145 |
+
try:
|
| 146 |
+
if act_type == "query_regulations":
|
| 147 |
+
category = self.current_ad.get("category", "general")
|
| 148 |
resp = requests.get(f"{REGULATORY_API}/regulations/{category}", timeout=2)
|
| 149 |
+
return resp.json()
|
| 150 |
+
|
| 151 |
+
elif act_type == "analyze_image":
|
| 152 |
+
vlm_desc = self.current_ad.get("vlm_desc", "")
|
| 153 |
+
violation = any(kw in vlm_desc.lower() for kw in [
|
| 154 |
+
"violation", "banned", "prescription", "fake", "flagged",
|
| 155 |
+
"semaglutide", "adderall", "no rx", "no prescription"
|
| 156 |
+
])
|
| 157 |
+
return {"violation_detected": violation, "description": vlm_desc}
|
| 158 |
+
|
| 159 |
+
elif act_type == "check_advertiser_history":
|
| 160 |
+
advertiser_id = self.current_ad.get("advertiser_id", "adv_003")
|
|
|
|
| 161 |
resp = requests.get(f"{CRM_API}/advertiser/{advertiser_id}", timeout=2)
|
| 162 |
+
return resp.json()
|
| 163 |
+
|
| 164 |
+
elif act_type == "request_landing_page":
|
| 165 |
+
domain_age = self.current_ad.get("domain_age_days", 365)
|
| 166 |
+
risk_keywords = self.current_ad.get("landing_risk_keywords", [])
|
| 167 |
+
suspicious = domain_age < 30 or len(risk_keywords) > 0
|
| 168 |
+
return {"suspicious": suspicious, "domain_age": domain_age, "risk_keywords": risk_keywords}
|
| 169 |
+
|
| 170 |
+
elif act_type == "request_id_verification":
|
| 171 |
+
age_min = self.current_ad.get("targeting_data", {}).get("age_min", 18)
|
| 172 |
+
return {"age_min": age_min, "minor_targeted": age_min < 18}
|
| 173 |
+
|
| 174 |
+
elif act_type == "submit_audit":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
payload = {
|
| 176 |
"ad_id": self.current_ad.get("ad_id", "test"),
|
| 177 |
"action_taken": "pending",
|
| 178 |
"reasoning": "audit requested"
|
| 179 |
}
|
| 180 |
resp = requests.post(f"{AUDIT_API}/log", json=payload, timeout=2)
|
| 181 |
+
return resp.json()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
+
else:
|
| 184 |
+
return {"status": "ok"}
|
| 185 |
+
|
| 186 |
+
except Exception as e:
|
| 187 |
+
return {"error": f"service_unavailable", "retryable": True}
|
| 188 |
+
|
| 189 |
+
def _extract_signals(self, action: str, response: dict):
|
| 190 |
+
if action == "check_advertiser_history":
|
| 191 |
+
self.signals["risk_score"] = response.get("risk_score")
|
| 192 |
+
|
| 193 |
+
elif action == "query_regulations":
|
| 194 |
+
violations = response.get("violations", [])
|
| 195 |
+
confs = [v["confidence"] for v in violations]
|
| 196 |
+
self.signals["policy_confidence"] = max(confs, default=0.0)
|
| 197 |
+
|
| 198 |
+
elif action == "analyze_image":
|
| 199 |
+
self.signals["image_flag"] = response.get("violation_detected", False)
|
| 200 |
+
|
| 201 |
+
elif action == "request_landing_page":
|
| 202 |
+
self.signals["landing_flag"] = response.get("suspicious", False)
|
| 203 |
+
|
| 204 |
+
def _summarize_response(self, action: str, response: dict) -> dict:
|
| 205 |
+
if "error" in response:
|
| 206 |
+
return {"summary": "API failure — retryable", "flag": False}
|
| 207 |
+
|
| 208 |
+
if action == "check_advertiser_history":
|
| 209 |
+
rs = response.get("risk_score", 0.0)
|
| 210 |
+
return {"summary": f"risk_score={rs:.2f}", "flag": rs > 0.7}
|
| 211 |
+
|
| 212 |
+
if action == "query_regulations":
|
| 213 |
+
violations = response.get("violations", [])
|
| 214 |
+
conf = max((v["confidence"] for v in violations), default=0.0)
|
| 215 |
+
return {"summary": f"policy_confidence={conf:.2f}", "flag": conf > 0.7}
|
| 216 |
|
| 217 |
+
if action == "analyze_image":
|
| 218 |
+
flagged = response.get("violation_detected", False)
|
| 219 |
+
return {
|
| 220 |
+
"summary": "image_violation_detected" if flagged else "image_clean",
|
| 221 |
+
"flag": flagged
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
if action == "request_landing_page":
|
| 225 |
+
sus = response.get("suspicious", False)
|
| 226 |
+
return {"summary": "landing_suspicious" if sus else "landing_clean", "flag": sus}
|
| 227 |
+
|
| 228 |
+
if action == "request_id_verification":
|
| 229 |
+
minor = response.get("minor_targeted", False)
|
| 230 |
+
age = response.get("age_min", 18)
|
| 231 |
+
return {
|
| 232 |
+
"summary": f"ALERT: minor targeting age={age}" if minor else f"age_min={age} OK",
|
| 233 |
+
"flag": minor
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
if action == "submit_audit":
|
| 237 |
+
audit_id = response.get("audit_id", "LOCAL")
|
| 238 |
+
return {"summary": f"audit_logged id={audit_id}", "flag": False}
|
| 239 |
+
|
| 240 |
+
return {"summary": "ok", "flag": False}
|
| 241 |
+
|
| 242 |
+
def _terminal_reward(self, act_type: str) -> float:
|
| 243 |
+
reward = 0.0
|
| 244 |
+
is_violation = self.current_ad.get("ground_truth", False)
|
| 245 |
+
is_correct = (act_type == "reject" and is_violation) or \
|
| 246 |
+
(act_type == "approve" and not is_violation)
|
| 247 |
+
|
| 248 |
+
# Dominant signal
|
| 249 |
+
reward += 1.0 if is_correct else -1.0
|
| 250 |
+
|
| 251 |
+
# Sequence correctness
|
| 252 |
+
if REQUIRED_BEFORE_TERMINAL.issubset(self.actions_taken):
|
| 253 |
+
reward += 0.2
|
| 254 |
else:
|
| 255 |
+
reward -= 0.2
|
|
|
|
| 256 |
|
| 257 |
+
# API failure handling
|
| 258 |
+
if self.api_failed and self.api_recovered:
|
| 259 |
+
reward += 0.3
|
| 260 |
+
elif self.api_failed and not self.api_recovered:
|
| 261 |
+
reward -= 0.3
|
| 262 |
+
|
| 263 |
+
# Risk-aware penalty
|
| 264 |
+
if act_type == "approve":
|
| 265 |
+
rs = self.signals["risk_score"]
|
| 266 |
+
pc = self.signals["policy_confidence"]
|
| 267 |
+
if rs is not None and pc is not None:
|
| 268 |
+
if rs > 0.7 and pc > 0.6:
|
| 269 |
+
reward -= 0.5
|
| 270 |
+
|
| 271 |
+
# Ambiguity enforcement
|
| 272 |
+
pc = self.signals["policy_confidence"]
|
| 273 |
+
if pc is not None and pc < 0.6:
|
| 274 |
+
if "check_advertiser_history" not in self.actions_taken \
|
| 275 |
+
and "request_landing_page" not in self.actions_taken:
|
| 276 |
+
reward -= 0.4
|
| 277 |
+
|
| 278 |
+
# Landing page bonus under ambiguity
|
| 279 |
+
if self.signals["landing_flag"] and pc is not None and pc < 0.6:
|
| 280 |
+
reward += 0.2
|
| 281 |
|
| 282 |
+
return reward
|
| 283 |
+
|
| 284 |
def _get_obs(self, message, reward=0.0, done=False) -> AdObservation:
|
| 285 |
self._ensure_ad()
|
| 286 |
return AdObservation(
|
|
|
|
| 292 |
image_url=str(self.current_ad.get("image_url", "N/A")),
|
| 293 |
status_message=str(message),
|
| 294 |
reward=reward,
|
| 295 |
+
done=done,
|
| 296 |
+
risk_score=self.signals["risk_score"],
|
| 297 |
+
policy_confidence=self.signals["policy_confidence"],
|
| 298 |
+
image_flag=self.signals["image_flag"],
|
| 299 |
+
landing_flag=self.signals["landing_flag"],
|
| 300 |
+
last_error=self.last_error
|
| 301 |
)
|
src/generator.py
CHANGED
|
@@ -69,7 +69,40 @@ class AdGenerator:
|
|
| 69 |
"A bottle of standard FDA-approved daily vitamins.",
|
| 70 |
False, 400, []),
|
| 71 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
@staticmethod
|
| 74 |
def generate_random_ad(task_id: str = None):
|
| 75 |
# Task overrides
|
|
@@ -105,6 +138,31 @@ class AdGenerator:
|
|
| 105 |
domain_age = 30
|
| 106 |
risk_kw = []
|
| 107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
else:
|
| 109 |
scenario_type = random.choices(
|
| 110 |
["obvious", "multimodal", "subtle", "safe"],
|
|
@@ -112,23 +170,23 @@ class AdGenerator:
|
|
| 112 |
k=1
|
| 113 |
)[0]
|
| 114 |
pool_map = {
|
| 115 |
-
"obvious":
|
| 116 |
"multimodal": AdGenerator.MULTIMODAL_TRAPS,
|
| 117 |
-
"subtle":
|
| 118 |
-
"safe":
|
| 119 |
}
|
| 120 |
category, adv_id, headline, body, vlm_desc, is_violation, domain_age, risk_kw = random.choice(
|
| 121 |
pool_map[scenario_type]
|
| 122 |
)
|
| 123 |
|
| 124 |
# Trust score — KEY FIX: multimodal traps get HIGH trust to defeat shortcuts
|
| 125 |
-
|
| 126 |
-
|
|
|
|
| 127 |
elif is_violation:
|
| 128 |
trust_score = round(random.uniform(0.10, 0.50), 2)
|
| 129 |
else:
|
| 130 |
trust_score = round(random.uniform(0.70, 0.99), 2)
|
| 131 |
-
|
| 132 |
# task_4 targeting age — FIX: use "age_min" consistently
|
| 133 |
age_min = 15 if task_id == "task_4_targeting" else random.randint(18, 55)
|
| 134 |
|
|
|
|
| 69 |
"A bottle of standard FDA-approved daily vitamins.",
|
| 70 |
False, 400, []),
|
| 71 |
]
|
| 72 |
+
CONFLICT_ADS = [
|
| 73 |
+
("FINANCIAL", "adv_002", "Smart Investment Tools",
|
| 74 |
+
"Our platform helps you manage your portfolio with ease.",
|
| 75 |
+
"A clean dashboard showing portfolio analytics.",
|
| 76 |
+
True, 20, ["guaranteed-edge"]),
|
| 77 |
+
]
|
| 78 |
+
|
| 79 |
+
AMBIGUOUS_ADS = [
|
| 80 |
+
("ambiguous", "adv_004", "Feel Better Naturally",
|
| 81 |
+
"Our wellness program supports your health journey.",
|
| 82 |
+
"A person meditating in a park.",
|
| 83 |
+
True, 45, ["ships-direct"]),
|
| 84 |
+
]
|
| 85 |
+
|
| 86 |
+
ADVERSARIAL_ADS = [
|
| 87 |
+
("HEALTHCARE", "adv_004", "Natural Energy Boost",
|
| 88 |
+
"Feel amazing every day with our daily supplement.",
|
| 89 |
+
"Fine print reads: contains semaglutide analog, ships without Rx needed.",
|
| 90 |
+
True, 60, []),
|
| 91 |
+
]
|
| 92 |
+
|
| 93 |
+
DEPENDENCY_TRAP_ADS = [
|
| 94 |
+
("HEALTHCARE", "adv_003", "Wellness Support Pack",
|
| 95 |
+
"Certified vitamins to support your daily routine.",
|
| 96 |
+
"Unlabeled pill bottle with small text: contains unapproved compound.",
|
| 97 |
+
True, 400, []),
|
| 98 |
+
]
|
| 99 |
|
| 100 |
+
FAILURE_RECOVERY_ADS = [
|
| 101 |
+
("HEALTHCARE", "adv_001", "Miracle Weight Loss Patch",
|
| 102 |
+
"Lose 10kg in 2 weeks. Guaranteed results or money back.",
|
| 103 |
+
"A before/after image with extreme weight loss claims.",
|
| 104 |
+
True, 30, ["guaranteed", "no-prescription"]),
|
| 105 |
+
]
|
| 106 |
@staticmethod
|
| 107 |
def generate_random_ad(task_id: str = None):
|
| 108 |
# Task overrides
|
|
|
|
| 138 |
domain_age = 30
|
| 139 |
risk_kw = []
|
| 140 |
|
| 141 |
+
elif task_id == "task_6_conflict":
|
| 142 |
+
category, adv_id, headline, body, vlm_desc, is_violation, domain_age, risk_kw = random.choice(
|
| 143 |
+
AdGenerator.CONFLICT_ADS
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
elif task_id == "task_7_ambiguous":
|
| 147 |
+
category, adv_id, headline, body, vlm_desc, is_violation, domain_age, risk_kw = random.choice(
|
| 148 |
+
AdGenerator.AMBIGUOUS_ADS
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
elif task_id == "task_8_adversarial":
|
| 152 |
+
category, adv_id, headline, body, vlm_desc, is_violation, domain_age, risk_kw = random.choice(
|
| 153 |
+
AdGenerator.ADVERSARIAL_ADS
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
elif task_id == "task_9_dependency_trap":
|
| 157 |
+
category, adv_id, headline, body, vlm_desc, is_violation, domain_age, risk_kw = random.choice(
|
| 158 |
+
AdGenerator.DEPENDENCY_TRAP_ADS
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
elif task_id == "task_10_failure":
|
| 162 |
+
category, adv_id, headline, body, vlm_desc, is_violation, domain_age, risk_kw = random.choice(
|
| 163 |
+
AdGenerator.FAILURE_RECOVERY_ADS
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
else:
|
| 167 |
scenario_type = random.choices(
|
| 168 |
["obvious", "multimodal", "subtle", "safe"],
|
|
|
|
| 170 |
k=1
|
| 171 |
)[0]
|
| 172 |
pool_map = {
|
| 173 |
+
"obvious": AdGenerator.OBVIOUS_VIOLATIONS,
|
| 174 |
"multimodal": AdGenerator.MULTIMODAL_TRAPS,
|
| 175 |
+
"subtle": AdGenerator.SUBTLE_VIOLATIONS,
|
| 176 |
+
"safe": AdGenerator.SAFE_ADS,
|
| 177 |
}
|
| 178 |
category, adv_id, headline, body, vlm_desc, is_violation, domain_age, risk_kw = random.choice(
|
| 179 |
pool_map[scenario_type]
|
| 180 |
)
|
| 181 |
|
| 182 |
# Trust score — KEY FIX: multimodal traps get HIGH trust to defeat shortcuts
|
| 183 |
+
# task_3 + task_6: high trust = stealth trap, forces CRM check
|
| 184 |
+
if task_id in ("task_3_multimodal", "task_6_conflict"):
|
| 185 |
+
trust_score = round(random.uniform(0.82, 0.97), 2)
|
| 186 |
elif is_violation:
|
| 187 |
trust_score = round(random.uniform(0.10, 0.50), 2)
|
| 188 |
else:
|
| 189 |
trust_score = round(random.uniform(0.70, 0.99), 2)
|
|
|
|
| 190 |
# task_4 targeting age — FIX: use "age_min" consistently
|
| 191 |
age_min = 15 if task_id == "task_4_targeting" else random.randint(18, 55)
|
| 192 |
|
src/models.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from typing import Literal, Optional, Dict, Any
|
| 2 |
from openenv.core.env_server import Action, Observation, State
|
| 3 |
|
| 4 |
class AdObservation(Observation):
|
|
@@ -9,15 +9,21 @@ class AdObservation(Observation):
|
|
| 9 |
targeting_data: Dict[str, Any]
|
| 10 |
image_url: str
|
| 11 |
status_message: str
|
| 12 |
-
|
| 13 |
-
# 🚨 NEW: OpenEnv requires these to be part of the Observation!
|
| 14 |
reward: float = 0.0
|
| 15 |
done: bool = False
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
class AdAction(Action):
|
| 18 |
action_type: Literal[
|
| 19 |
-
"
|
| 20 |
-
"request_landing_page", "request_id_verification"
|
|
|
|
| 21 |
]
|
| 22 |
reasoning: str
|
| 23 |
violation_category: Optional[Literal["HEALTHCARE", "FINANCIAL", "NONE"]] = None
|
|
|
|
| 1 |
+
from typing import Literal, Optional, Dict, Any, List
|
| 2 |
from openenv.core.env_server import Action, Observation, State
|
| 3 |
|
| 4 |
class AdObservation(Observation):
|
|
|
|
| 9 |
targeting_data: Dict[str, Any]
|
| 10 |
image_url: str
|
| 11 |
status_message: str
|
|
|
|
|
|
|
| 12 |
reward: float = 0.0
|
| 13 |
done: bool = False
|
| 14 |
|
| 15 |
+
# signals exposed to agent
|
| 16 |
+
risk_score: Optional[float] = None
|
| 17 |
+
policy_confidence: Optional[float] = None
|
| 18 |
+
image_flag: Optional[bool] = None
|
| 19 |
+
landing_flag: Optional[bool] = None
|
| 20 |
+
last_error: Optional[str] = None
|
| 21 |
+
|
| 22 |
class AdAction(Action):
|
| 23 |
action_type: Literal[
|
| 24 |
+
"query_regulations", "analyze_image", "check_advertiser_history",
|
| 25 |
+
"request_landing_page", "request_id_verification",
|
| 26 |
+
"submit_audit", "approve", "reject"
|
| 27 |
]
|
| 28 |
reasoning: str
|
| 29 |
violation_category: Optional[Literal["HEALTHCARE", "FINANCIAL", "NONE"]] = None
|