File size: 7,434 Bytes
350500c
 
 
 
 
929006e
350500c
 
8a685c0
350500c
47fa380
350500c
 
 
 
 
 
 
 
 
 
 
 
 
929006e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350500c
 
8a685c0
 
 
 
 
 
 
 
 
 
350500c
8a685c0
350500c
8a685c0
350500c
 
8a685c0
350500c
 
8a685c0
47fa380
 
 
 
 
 
8a685c0
 
350500c
 
 
 
 
 
 
 
 
 
 
8a685c0
350500c
 
 
8a685c0
 
 
 
 
 
 
 
350500c
8a685c0
350500c
 
 
8a685c0
 
350500c
 
 
929006e
 
 
 
 
47fa380
 
350500c
8a685c0
350500c
 
929006e
030cdd8
929006e
 
8a685c0
 
 
929006e
350500c
8a685c0
929006e
 
 
8a685c0
 
 
 
 
47fa380
 
8a685c0
 
929006e
8a685c0
929006e
8a685c0
47fa380
 
 
8a685c0
 
 
929006e
8a685c0
929006e
 
 
47fa380
929006e
47fa380
 
 
 
 
 
 
 
 
 
 
 
 
929006e
 
8a685c0
030cdd8
8a685c0
 
350500c
929006e
 
030cdd8
350500c
 
030cdd8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import os
import json
import requests
from openai import OpenAI

# 1. MANDATORY VARIABLES EXACTLY AS REQUESTED BY SCALAR
# OpenAI-compatible chat-completions endpoint (defaults to the HF router).
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
# Prefer a real HF token; otherwise API_KEY, with a dummy fallback for local runs.
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY", "dummy_local_token")
MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Meta-Llama-3-8B-Instruct")

# Base URL of the environment's HTTP API (serves /reset and /step).
ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")
# Hard cap on agent steps per task episode.
MAX_STEPS = 10

# 2. MANDATORY: Use OpenAI Client pointed at the HF Router
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

# The exact tasks defined in your openenv.yaml
TASKS = [
    "task_1_healthcare",
    "task_2_financial",
    "task_3_multimodal",
    "task_4_targeting"
]

# --- STRICT GRADING LOGGERS ---
def log_start(task: str, env: str, model: str) -> None:
    """Emit the grader's [START] marker for one task run."""
    fields = f"task={task} env={env} model={model}"
    print(f"[START] {fields}", flush=True)

def log_step(step: int, action: str, reward: float, done: bool, error: str = None) -> None:
    error_val = error if error else "null"
    done_val = str(done).lower()
    print(f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}", flush=True)

def log_end(success: bool, steps: int, score: float, rewards: list) -> None:
    """Emit the grader's [END] summary line for a finished task."""
    joined_rewards = ",".join(format(r, ".2f") for r in rewards)
    flag = "true" if success else "false"
    print(f"[END] success={flag} steps={steps} score={score:.3f} rewards={joined_rewards}", flush=True)
# ------------------------------

def get_llm_action(observation_data):
    """Ask the LLM for the next compliance action given the current observation.

    Args:
        observation_data: JSON-serializable dict describing the current ad,
            prior feedback, and the actions already taken.

    Returns:
        dict with keys "action_type" and "reasoning". On any failure
        (API error, unparseable response) falls back to the safe default
        action "query_regulations" and embeds the error in "reasoning".
    """
    system_prompt = """You are an enterprise Ad Policy Compliance Agent.
    You navigate a multi-system compliance workflow. Always respond with ONLY valid JSON.

    REQUIRED PHASE ORDER:
    1. query_regulations   — always first
    2. analyze_image       — required for visual/multimodal tasks
    3. check_advertiser_history or request_landing_page — as needed
    4. submit_audit        — always before final decision
    5. approve or reject   — final decision only after audit

    AVAILABLE ACTIONS:
    - query_regulations
    - analyze_image
    - check_advertiser_history
    - request_landing_page
    - request_id_verification
    - submit_audit
    - approve
    - reject

    HARD RULES:
    - NEVER repeat an action listed in `actions_already_taken`.
    - You MUST progress through the phase order. Do NOT call submit_audit or approve/reject
      before the prerequisite phases are complete.
    - Choose your action_type ONLY from the AVAILABLE ACTIONS list above. Any other value is invalid.

    Response format:
    {"action_type": "<action>", "reasoning": "<brief reason>"}
    """

    user_prompt = f"Current Ad Observation:\n{json.dumps(observation_data, indent=2)}\n\nWhat is your next action?"

    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            # Removed response_format={"type": "json_object"} as HF router often rejects it
            temperature=0.1
        )

        # Clean the response in case the LLM wrapped it in markdown code blocks.
        # Fix: the old slicing (content[7:-3] / content[3:-3]) assumed a closing
        # ``` fence and silently chopped 3 characters of real JSON when the
        # model omitted it. Strip prefix and suffix independently instead.
        content = response.choices[0].message.content.strip()
        if content.startswith("```json"):
            content = content[len("```json"):]
        elif content.startswith("```"):
            content = content[3:]
        if content.endswith("```"):
            content = content[:-3]
        content = content.strip()

        result = json.loads(content)
        return {
            "action_type": result.get("action_type", "query_regulations"),
            "reasoning": result.get("reasoning", "Fallback reasoning")
        }
    except Exception as e:
        # Surface the failure loudly, then recover with the safe first-phase action.
        print(f"\n[CRITICAL LLM ERROR]: {str(e)}\n", flush=True) # THIS WILL REVEAL THE BUG
        return {"action_type": "query_regulations", "reasoning": f"Error recovery: {str(e)}"}

def main() -> None:
    """Run every task in TASKS against the compliance environment.

    For each task: reset the env over HTTP, then loop up to MAX_STEPS asking
    the LLM for an action and forwarding it to the env's /step endpoint,
    emitting the grader's [START]/[STEP]/[END] log lines throughout.
    """
    for task_id in TASKS:
        log_start(task=task_id, env="meta_ad_policy_sandbox", model=MODEL_NAME)

        rewards = []
        steps_taken = 0
        success = False
        actions_taken_list: list = []

        try:
            # 1. Reset the environment. Fix: added a timeout so an unreachable
            # env server cannot hang the whole run indefinitely; a Timeout is
            # caught by the outer except and still yields well-formed logs.
            res = requests.post(f"{ENV_URL}/reset", json={"task_id": task_id}, timeout=60)
            if res.status_code != 200:
                log_step(step=1, action="reset_failed", reward=0.0, done=True, error=f"HTTP {res.status_code}")
                log_end(success=False, steps=0, score=0.01, rewards=[])
                continue

            # 2. Initialize data from the reset response.
            step_data = res.json()
            observation = step_data.get("observation", step_data)
            done = False

            # 3. THE SINGLE LOOP (Fixed): observe -> ask LLM -> act.
            while not done and steps_taken < MAX_STEPS:
                steps_taken += 1

                # Feedback memory for the LLM (what it has already tried).
                llm_observation = {
                    "task_id": task_id,
                    "last_feedback": step_data.get("status_message", "No feedback yet."),
                    "step_count": steps_taken,
                    "actions_already_taken": actions_taken_list,
                    "ad_details": observation
                }

                # Get action from LLM.
                action_payload = get_llm_action(llm_observation)
                action_str = action_payload["action_type"]
                # A 402 from the router means credits are exhausted; abort the
                # task early instead of burning the remaining steps on errors.
                if "Error code: 402" in action_payload.get("reasoning", ""):
                    done = True
                    log_step(step=steps_taken, action=action_str, reward=0.0, done=True, error="API credits depleted")
                    break

                # Execute action in environment (timeout added, see above).
                step_res = requests.post(f"{ENV_URL}/step", json={"action": action_payload}, timeout=60)
                step_data = step_res.json()

                # Update loop variables from the env response.
                observation = step_data.get("observation", {})
                done = step_data.get("done", False)
                reward = step_data.get("reward", 0.0)

                rewards.append(reward)

                # Track only actions that actually advanced state. Skip API-failure
                # / invalid-action / wrong-order cases so the agent is free to retry.
                status_msg = (step_data.get("status_message") or "").lower()
                action_failed = (
                    "api failure" in status_msg
                    or "retryable" in status_msg
                    or "invalid action" in status_msg
                    or "must call" in status_msg
                )
                if not action_failed and action_str not in actions_taken_list:
                    actions_taken_list.append(action_str)

                log_step(step=steps_taken, action=action_str, reward=reward, done=done, error=None)

            # 4. Final Scoring (Single Log): success means any positive reward.
            raw_score = sum(rewards)
            success = raw_score > 0
            log_end(success=success, steps=steps_taken, score=raw_score, rewards=rewards)

        except Exception as e:
            # Any unexpected failure still produces well-formed grader output.
            log_step(step=steps_taken + 1, action="exception", reward=0.0, done=True, error=str(e).replace("\n", " "))
            log_end(success=False, steps=steps_taken, score=0.01, rewards=rewards)

# Script entry point: run the full task suite when executed directly.
if __name__ == "__main__":
    main()