"""Print exactly what the env shows the model.

Walks one mini episode through the same client the smoke trainer uses,
dumping the full observation at reset and after every step so we can
verify nothing leaks and the labels match the values.
"""
from __future__ import annotations

import json
import time
from typing import Any

from transformers import AutoTokenizer

from dispatch_arena.client import DispatchArenaClient
from dispatch_arena.scripts.train_grpo_smoke import (
    MODEL_NAME,
    SYSTEM_PROMPT,
    DispatchToolEnv,
    _summarize,
)
from dispatch_arena.server.app import run_local_server_in_thread
def _dump(label: str, obs) -> None:
    """Print every field of *obs* under *label* so labels can be eyeballed
    against values.

    Also runs a cheap leak check: serializes the whole observation to JSON
    and reports whether the hidden ``prep_remaining`` key is present — it
    must never reach the model.
    """
    print(f"\n=== {label} ===")
    print(" summary_text:", obs.summary_text)
    print(" reward (this step):", obs.reward)
    print(" reward_breakdown:", obs.reward_breakdown.to_dict())
    print(" verifier_status:", obs.verifier_status.value)
    print(" done:", obs.done, "truncated:", obs.truncated)
    print(" legal_actions:", obs.legal_actions)
    print(" action_mask:", obs.action_mask)
    print(" info:", obs.info)
    print(" state.tick:", obs.state.tick, "/", obs.state.max_ticks)
    print(" state.total_reward (cumulative):", obs.state.total_reward)
    # Guard the rosters instead of indexing blindly: a terminal observation
    # may plausibly carry no couriers/orders, and a debug dump should never
    # crash with IndexError.
    if obs.state.couriers:
        print(" courier:", obs.state.couriers[0].to_dict())
    else:
        print(" courier: (none)")
    if obs.state.orders:
        print(" order:", obs.state.orders[0].to_dict())
    else:
        print(" order: (none)")
    # Leak check on the full serialized payload, not just the fields above.
    blob = json.dumps(obs.to_dict())
    leak = "prep_remaining" in blob
    print(" leak('prep_remaining' present):", leak)
def _walk_http_episode(client: DispatchArenaClient) -> None:
    """Reset the env over the raw HTTP client and replay a fixed plan,
    dumping the observation after reset and after every step.

    Stops early (with an explanatory line) if the episode ends or the next
    planned action is not legal.
    """
    print("### Raw HTTP-client view of the env ###")
    obs = client.reset(seed=7, config={"mode": "mini", "max_ticks": 12})
    _dump("RESET (seed=7)", obs)
    plan = ["go_pickup", "wait", "wait", "pickup", "go_dropoff", "dropoff"]
    for i, action in enumerate(plan, 1):
        if obs.done:
            print(f"\n(stop: episode ended before action {i})")
            break
        if action not in obs.legal_actions:
            print(f"\n(stop: '{action}' not legal at step {i}, legal={obs.legal_actions})")
            break
        obs = client.step(action)
        _dump(f"STEP {i}: {action}", obs)


def _probe_tool_env(client: DispatchArenaClient) -> str:
    """Show what ``DispatchToolEnv`` hands to TRL and what one tool call
    returns; return the reset string for prompt rendering."""
    print("\n### What DispatchToolEnv.reset returns to TRL ###")
    tool_env = DispatchToolEnv()
    tool_env.client = client  # reuse the same server
    initial = tool_env.reset(seed=7)
    print(initial)
    print("metrics after reset:", tool_env.metrics)
    print("\n### One tool-call's text return ###")
    out = tool_env.go_pickup()
    print(out)
    print("metrics after go_pickup:", tool_env.metrics)
    return initial


def _render_prompt(initial: str) -> None:
    """Render and print the exact chat-template prompt the model sees,
    including the tool schemas, given the env's reset string *initial*."""
    print("\n### Full prompt the model actually sees (after TRL appends reset string) ###")
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    # One OpenAI-style function schema per tool; all tools take no arguments.
    tools_schema = [
        {
            "type": "function",
            "function": {
                "name": name,
                "description": (getattr(DispatchToolEnv, name).__doc__ or "").strip(),
                "parameters": {"type": "object", "properties": {}},
            },
        }
        for name in ("wait", "go_pickup", "pickup", "go_dropoff", "dropoff")
    ]
    user_content = "Begin the shift. " + initial  # mirrors what TRL does after reset
    rendered = tok.apply_chat_template(
        [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ],
        tools=tools_schema,
        add_generation_prompt=True,
        tokenize=False,
    )
    print(rendered)


def main() -> None:
    """Spin up a throwaway in-process server and run all three probes:
    raw HTTP episode, tool-env view, and the rendered model prompt."""
    server, _thread = run_local_server_in_thread(port=0, max_concurrent_envs=4)
    host, port = server.server_address
    time.sleep(0.2)  # give the server thread a moment to start accepting
    client = DispatchArenaClient(base_url=f"http://{host}:{port}")
    _walk_http_episode(client)
    initial = _probe_tool_env(client)
    _render_prompt(initial)
if __name__ == "__main__":  # run the probe only when executed as a script
    main()