| from src.environment import AdPolicyEnvironment |
| from src.models import AdAction |
|
|
|
|
def run_episode(task_id, actions):
    """Play a scripted list of actions against a fresh environment.

    A new ``AdPolicyEnvironment`` is created and reset with ``task_id``.
    Each entry of ``actions`` is then submitted, in order, as an
    ``AdAction`` whose ``reasoning`` is ``"smoke test"`` and whose
    ``violation_category`` is ``"NONE"``.

    Args:
        task_id: Forwarded to ``env.reset(task_id=...)``.
        actions: Iterable of values used as ``action_type`` for each step.

    Returns:
        A ``(env, obs)`` tuple: the environment and the most recent
        observation. That is the observation from ``reset`` when
        ``actions`` is empty, otherwise the one from the last step taken.
    """
    environment = AdPolicyEnvironment()
    observation = environment.reset(task_id=task_id)

    for action_type in actions:
        scripted_action = AdAction(
            action_type=action_type,
            reasoning="smoke test",
            violation_category="NONE",
        )
        observation = environment.step(scripted_action)
        # Once the environment reports the episode finished, any remaining
        # scripted actions are skipped.
        if observation.done:
            return environment, observation

    return environment, observation
|
|
|
|
def _require(condition, message):
    """Raise ``AssertionError(message)`` unless ``condition`` is truthy.

    Used instead of ``assert`` statements: Python removes those when run
    with ``-O``, which would let this smoke test report success without
    having checked anything.
    """
    if not condition:
        raise AssertionError(message)


def _is_number(value):
    """Return True when ``value`` is an int or float (bool excluded)."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def main():
    """Run two scripted episodes, check their traces, and print a report.

    Raises:
        AssertionError: If any check on the recorded traces or on the
            first episode's total reward fails.
    """
    env1, _ = run_episode(
        "task_1_healthcare",
        [
            "query_regulations",
            "analyze_image",
            "check_advertiser_history",
            "submit_audit",
            "reject",
        ],
    )

    _require(len(env1.trace) >= 4, f"Trace too short: {len(env1.trace)}")
    # Any int or float counts as numeric, matching the failure message; an
    # integer reward should not be reported as "not numeric".
    _require(_is_number(env1.total_reward), "Reward is not numeric")
    _require(
        all("summary" in row["result"] for row in env1.trace),
        "Bad trace format",
    )

    env2, _ = run_episode(
        "task_10_failure",
        [
            "query_regulations",
            "query_regulations",
            "check_advertiser_history",
            "submit_audit",
            "reject",
        ],
    )

    _require(
        len(env2.trace) >= 2,
        f"Failure trace too short: {len(env2.trace)}",
    )
    # .get() keeps a row without a summary from raising KeyError, so a
    # missing failure marker surfaces as the intended check failure.
    _require(
        any(
            "API failure" in row["result"].get("summary", "")
            for row in env2.trace
        ),
        "Failure case did not trigger",
    )

    print("STEP 7 SMOKE TEST PASSED")
    print("\nTRACE 1:")
    for row in env1.trace:
        print(row)

    print("\nTRACE 2:")
    for row in env2.trace:
        print(row)

    print("\nTOTAL REWARD 1:", env1.total_reward)
    print("TOTAL REWARD 2:", env2.total_reward)


if __name__ == "__main__":
    main()