| import os |
| import json |
| import requests |
| from openai import OpenAI |
|
|
| |
| API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1") |
| API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY", "dummy_local_token") |
| MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Meta-Llama-3-8B-Instruct") |
|
|
| ENV_URL = os.getenv("ENV_URL", "http://localhost:8000") |
| MAX_STEPS = 10 |
|
|
| |
| client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY) |
|
|
| |
| TASKS = [ |
| "task_1_healthcare", |
| "task_2_financial", |
| "task_3_multimodal", |
| "task_4_targeting" |
| ] |
|
|
| |
| def log_start(task: str, env: str, model: str) -> None: |
| print(f"[START] task={task} env={env} model={model}", flush=True) |
|
|
| def log_step(step: int, action: str, reward: float, done: bool, error: str = None) -> None: |
| error_val = error if error else "null" |
| done_val = str(done).lower() |
| print(f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}", flush=True) |
|
|
| def log_end(success: bool, steps: int, score: float, rewards: list) -> None: |
| rewards_str = ",".join(f"{r:.2f}" for r in rewards) |
| success_val = str(success).lower() |
| print(f"[END] success={success_val} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True) |
| |
|
|
| def get_llm_action(observation_data): |
| """Asks the LLM what action to take based on the ad observation.""" |
| system_prompt = """You are an enterprise Ad Policy Compliance Agent. |
| You navigate a multi-system compliance workflow. Always respond with ONLY valid JSON. |
| |
| REQUIRED PHASE ORDER: |
| 1. query_regulations β always first |
| 2. analyze_image β required for visual/multimodal tasks |
| 3. check_advertiser_history or request_landing_page β as needed |
| 4. submit_audit β always before final decision |
| 5. approve or reject β final decision only after audit |
| |
| AVAILABLE ACTIONS: |
| - query_regulations |
| - analyze_image |
| - check_advertiser_history |
| - request_landing_page |
| - request_id_verification |
| - submit_audit |
| - approve |
| - reject |
| |
| HARD RULES: |
| - NEVER repeat an action listed in `actions_already_taken`. |
| - You MUST progress through the phase order. Do NOT call submit_audit or approve/reject |
| before the prerequisite phases are complete. |
| - Choose your action_type ONLY from the AVAILABLE ACTIONS list above. Any other value is invalid. |
| |
| Response format: |
| {"action_type": "<action>", "reasoning": "<brief reason>"} |
| """ |
|
|
| user_prompt = f"Current Ad Observation:\n{json.dumps(observation_data, indent=2)}\n\nWhat is your next action?" |
|
|
| try: |
| response = client.chat.completions.create( |
| model=MODEL_NAME, |
| messages=[ |
| {"role": "system", "content": system_prompt}, |
| {"role": "user", "content": user_prompt} |
| ], |
| |
| temperature=0.1 |
| ) |
| |
| |
| content = response.choices[0].message.content.strip() |
| if content.startswith("```json"): |
| content = content[7:-3].strip() |
| elif content.startswith("```"): |
| content = content[3:-3].strip() |
| |
| result = json.loads(content) |
| return { |
| "action_type": result.get("action_type", "query_regulations"), |
| "reasoning": result.get("reasoning", "Fallback reasoning") |
| } |
| except Exception as e: |
| print(f"\n[CRITICAL LLM ERROR]: {str(e)}\n", flush=True) |
| return {"action_type": "query_regulations", "reasoning": f"Error recovery: {str(e)}"} |
|
|
| def main() -> None: |
| for task_id in TASKS: |
| log_start(task=task_id, env="meta_ad_policy_sandbox", model=MODEL_NAME) |
| |
| rewards = [] |
| steps_taken = 0 |
| success = False |
| actions_taken_list: list = [] |
|
|
| try: |
| |
| res = requests.post(f"{ENV_URL}/reset", json={"task_id": task_id}) |
| if res.status_code != 200: |
| log_step(step=1, action="reset_failed", reward=0.0, done=True, error=f"HTTP {res.status_code}") |
| log_end(success=False, steps=0, score=0.01, rewards=[]) |
| continue |
| |
| |
| step_data = res.json() |
| observation = step_data.get("observation", step_data) |
| done = False |
| |
| |
| while not done and steps_taken < MAX_STEPS: |
| steps_taken += 1 |
| |
| |
| llm_observation = { |
| "task_id": task_id, |
| "last_feedback": step_data.get("status_message", "No feedback yet."), |
| "step_count": steps_taken, |
| "actions_already_taken": actions_taken_list, |
| "ad_details": observation |
| } |
| |
| |
| action_payload = get_llm_action(llm_observation) |
| action_str = action_payload["action_type"] |
| if "Error code: 402" in action_payload.get("reasoning", ""): |
| done = True |
| log_step(step=steps_taken, action=action_str, reward=0.0, done=True, error="API credits depleted") |
| break |
| |
| step_res = requests.post(f"{ENV_URL}/step", json={"action": action_payload}) |
| step_data = step_res.json() |
| |
| |
| observation = step_data.get("observation", {}) |
| done = step_data.get("done", False) |
| reward = step_data.get("reward", 0.0) |
|
|
| rewards.append(reward) |
|
|
| |
| |
| status_msg = (step_data.get("status_message") or "").lower() |
| action_failed = ( |
| "api failure" in status_msg |
| or "retryable" in status_msg |
| or "invalid action" in status_msg |
| or "must call" in status_msg |
| ) |
| if not action_failed and action_str not in actions_taken_list: |
| actions_taken_list.append(action_str) |
|
|
| log_step(step=steps_taken, action=action_str, reward=reward, done=done, error=None) |
| |
| |
| raw_score = sum(rewards) |
| success = raw_score > 0 |
| log_end(success=success, steps=steps_taken, score=raw_score, rewards=rewards) |
|
|
| except Exception as e: |
| log_step(step=steps_taken+1, action="exception", reward=0.0, done=True, error=str(e).replace("\n", " ")) |
| log_end(success=False, steps=steps_taken, score=0.01, rewards=rewards) |
|
|
| if __name__ == "__main__": |
| main() |