Jash commited on
Commit ·
53effed
1
Parent(s): 0ad1e71
fix: use strict stdout log format for inference script
Browse files- inference.py +7 -37
inference.py
CHANGED
|
@@ -934,17 +934,7 @@ def run_episode(
|
|
| 934 |
print(f"Required specialization: {observation['required_specialization']}")
|
| 935 |
print("Objective: admit patient successfully (no fixed deadline window)")
|
| 936 |
print("=" * 72)
|
| 937 |
-
|
| 938 |
-
"START",
|
| 939 |
-
{
|
| 940 |
-
"task_id": task_id,
|
| 941 |
-
"seed": seed,
|
| 942 |
-
"difficulty": observation.get("scenario_difficulty"),
|
| 943 |
-
"scenario": observation.get("scenario_name"),
|
| 944 |
-
"patient_condition": observation.get("patient_condition"),
|
| 945 |
-
"required_specialization": observation.get("required_specialization"),
|
| 946 |
-
},
|
| 947 |
-
)
|
| 948 |
|
| 949 |
if learning_profile:
|
| 950 |
print(
|
|
@@ -955,6 +945,7 @@ def run_episode(
|
|
| 955 |
print(f"Best known route: {' -> '.join(learning_profile['best_actions'])}")
|
| 956 |
|
| 957 |
total_reward = 0.0
|
|
|
|
| 958 |
steps = 0
|
| 959 |
done = False
|
| 960 |
previous_policy_hospital_id: str | None = None
|
|
@@ -991,6 +982,7 @@ def run_episode(
|
|
| 991 |
)
|
| 992 |
next_obs_model = step_result["observation"]
|
| 993 |
reward = float(step_result["reward"])
|
|
|
|
| 994 |
done = bool(step_result["done"])
|
| 995 |
info = step_result.get("info", {}) or {}
|
| 996 |
next_observation = next_obs_model.model_dump()
|
|
@@ -1005,20 +997,8 @@ def run_episode(
|
|
| 1005 |
print(f"Outcome: {status}")
|
| 1006 |
print(f"Reason: {reason}")
|
| 1007 |
print(f"Reward: {reward:.3f}")
|
| 1008 |
-
|
| 1009 |
-
|
| 1010 |
-
{
|
| 1011 |
-
"task_id": task_id,
|
| 1012 |
-
"seed": seed,
|
| 1013 |
-
"step": observation.get("step"),
|
| 1014 |
-
"phase": observation.get("ambulance_status"),
|
| 1015 |
-
"hospital_id": chosen["hospital_id"],
|
| 1016 |
-
"strategy": strategy,
|
| 1017 |
-
"status": status,
|
| 1018 |
-
"reward": round(reward, 4),
|
| 1019 |
-
"done": done,
|
| 1020 |
-
},
|
| 1021 |
-
)
|
| 1022 |
|
| 1023 |
append_trajectory_log(
|
| 1024 |
{
|
|
@@ -1068,18 +1048,8 @@ def run_episode(
|
|
| 1068 |
print(f" Total steps: {steps}")
|
| 1069 |
print(f" Final score: {final_score:.3f}")
|
| 1070 |
print(f" Average reward: {total_reward / max(1, steps):.3f}")
|
| 1071 |
-
|
| 1072 |
-
|
| 1073 |
-
{
|
| 1074 |
-
"task_id": task_id,
|
| 1075 |
-
"seed": seed,
|
| 1076 |
-
"result": final_result,
|
| 1077 |
-
"success": final_result == "SUCCESS",
|
| 1078 |
-
"score": round(final_score, 4),
|
| 1079 |
-
"steps": steps,
|
| 1080 |
-
"average_reward": round(total_reward / max(1, steps), 4),
|
| 1081 |
-
},
|
| 1082 |
-
)
|
| 1083 |
|
| 1084 |
return {
|
| 1085 |
"success": final_result == "SUCCESS",
|
|
|
|
| 934 |
print(f"Required specialization: {observation['required_specialization']}")
|
| 935 |
print("Objective: admit patient successfully (no fixed deadline window)")
|
| 936 |
print("=" * 72)
|
| 937 |
+
print(f"[START] task={task_id} env=acde-openenv model={model_name or 'none'}", flush=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 938 |
|
| 939 |
if learning_profile:
|
| 940 |
print(
|
|
|
|
| 945 |
print(f"Best known route: {' -> '.join(learning_profile['best_actions'])}")
|
| 946 |
|
| 947 |
total_reward = 0.0
|
| 948 |
+
all_rewards = []
|
| 949 |
steps = 0
|
| 950 |
done = False
|
| 951 |
previous_policy_hospital_id: str | None = None
|
|
|
|
| 982 |
)
|
| 983 |
next_obs_model = step_result["observation"]
|
| 984 |
reward = float(step_result["reward"])
|
| 985 |
+
all_rewards.append(reward)
|
| 986 |
done = bool(step_result["done"])
|
| 987 |
info = step_result.get("info", {}) or {}
|
| 988 |
next_observation = next_obs_model.model_dump()
|
|
|
|
| 997 |
print(f"Outcome: {status}")
|
| 998 |
print(f"Reason: {reason}")
|
| 999 |
print(f"Reward: {reward:.3f}")
|
| 1000 |
+
error_val = str(info.get("last_action_error")) if info.get("last_action_error") else "null"
|
| 1001 |
+
print(f"[STEP] step={observation.get('step')} action={chosen['hospital_id']} reward={reward:.2f} done={str(done).lower()} error={error_val}", flush=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1002 |
|
| 1003 |
append_trajectory_log(
|
| 1004 |
{
|
|
|
|
| 1048 |
print(f" Total steps: {steps}")
|
| 1049 |
print(f" Final score: {final_score:.3f}")
|
| 1050 |
print(f" Average reward: {total_reward / max(1, steps):.3f}")
|
| 1051 |
+
rewards_str = ",".join(f"{r:.2f}" for r in all_rewards)
|
| 1052 |
+
print(f"[END] success={str(final_result == 'SUCCESS').lower()} steps={steps} score={final_score:.2f} rewards={rewards_str}", flush=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1053 |
|
| 1054 |
return {
|
| 1055 |
"success": final_result == "SUCCESS",
|