Upload folder using huggingface_hub
- README.md +5 -5
- inference.py +51 -42
README.md CHANGED

@@ -49,7 +49,7 @@ Task definitions live in `tasks.py`.
 - `medium`: diagnosis plus early treatment initiation after iterative lab requests
 - `hard`: full sepsis management across longer unstable trajectories with stabilization and outcome pressure
 
-Each task has a deterministic grader in `graders.py` that returns a score
+Each task has a deterministic grader in `graders.py` that returns a score strictly inside `(0.0, 1.0)`.
 
 ## Action Space
 
@@ -163,10 +163,10 @@ The script:
 
 Current deterministic baseline scores from the local run:
 
-- `easy`: `
-- `medium`: `
-- `hard`: `0.
-- mean score: `0.
+- `easy`: `0.999`
+- `medium`: `0.999`
+- `hard`: `0.9592`
+- mean score: `0.9857`
 
 ## Docker
 
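Note: the README now claims graders return a score strictly inside `(0.0, 1.0)`, and the inference.py changes below route every score-like field through a `normalize_task_score` helper whose definition is not part of this diff. A minimal sketch of a helper consistent with that claim, assuming a plain epsilon clamp (the epsilon value and the implementation are assumptions, not code from this repo):

```python
EPSILON = 1e-3  # assumed margin; the real value is not shown in this diff


def normalize_task_score(value: float) -> float:
    # Clamp a raw score into the open interval (0.0, 1.0).
    return min(1.0 - EPSILON, max(EPSILON, float(value)))
```

If the clamp works this way, the `easy`/`medium` baselines of `0.999` would sit exactly at the assumed ceiling `1.0 - 1e-3`, consistent with the clamp firing on perfect raw scores.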
inference.py CHANGED
@@ -572,22 +572,25 @@ def compute_dense_reward_metrics(
     action_history: list[str],
 ) -> dict[str, float | int]:
     nonzero_rewards = [reward for reward in reward_trace if reward != 0]
+    reward_density = (
+        sum(1 for reward in reward_trace if reward > 0) / len(reward_trace)
+        if reward_trace
+        else 0.5
+    )
+    episode_length_efficiency = step_count / max_steps if max_steps else 0.5
+    positive_reward_ratio = sum(1 for reward in reward_trace if reward > 0) / max(1, len(nonzero_rewards))
 
     return {
         "steps_taken": step_count,
         "total_reward": float(sum(reward_trace)),
         "reward_count": len(reward_trace),
         "positive_rewards_count": sum(1 for reward in reward_trace if reward > 0),
-        "reward_density": sum(1 for reward in reward_trace if reward > 0) / len(reward_trace)
-        if reward_trace
-        else 0.0,
+        "reward_density": normalize_task_score(reward_density),
         "avg_reward_per_step": float(np.mean(reward_trace)) if reward_trace else 0.0,
         "reward_variance": float(np.var(reward_trace)) if reward_trace else 0.0,
         "max_single_reward": float(max(reward_trace)) if reward_trace else 0.0,
-        "episode_length_efficiency": step_count / max_steps if max_steps else 0.0,
-        "positive_reward_ratio": (
-            sum(1 for reward in reward_trace if reward > 0) / max(1, len(nonzero_rewards))
-        ),
+        "episode_length_efficiency": normalize_task_score(episode_length_efficiency),
+        "positive_reward_ratio": normalize_task_score(positive_reward_ratio),
         "unique_actions": len(set(action_history)),
         "action_entropy": compute_action_entropy(action_history),
     }
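`compute_dense_reward_metrics` also reports `action_entropy` via `compute_action_entropy`, which this diff never shows. A plausible sketch, assuming Shannon entropy over the empirical action distribution (the definition is an assumption):

```python
import math
from collections import Counter


def compute_action_entropy(action_history: list[str]) -> float:
    # Shannon entropy (in bits) of the empirical distribution of actions.
    if not action_history:
        return 0.0
    total = len(action_history)
    return -sum(
        (count / total) * math.log2(count / total)
        for count in Counter(action_history).values()
    )
```

A trace that repeats one action yields 0.0 bits; a trace spread evenly over four distinct actions yields 2.0 bits.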
@@ -682,7 +685,7 @@ def run_task(
 
     if not final_info:
         final_info = {}
-    score = final_info.get("metrics", {}).get("score", 0.0)
+    score = normalize_task_score(final_info.get("metrics", {}).get("score", 0.5))
     log_end(success=success, steps=step_count, score=score, rewards=reward_trace)
 
     try:
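The new `score` line chains two `dict.get` lookups with a neutral `0.5` fallback before clamping. A quick illustration with toy payloads (values invented):

```python
final_info = {"metrics": {"score": 0.91}}
final_info.get("metrics", {}).get("score", 0.5)  # -> 0.91

final_info = {}  # env returned no metrics at all
final_info.get("metrics", {}).get("score", 0.5)  # -> 0.5, the neutral fallback
```

The empty-dict default on the first `get` keeps the second lookup from raising `AttributeError` when `metrics` is absent.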
@@ -698,16 +701,16 @@ def run_task(
             "task_id": task_id,
             "episode_id": state.episode_id,
             "score": normalized_score,
-            "avg_reward": metrics.get("avg_reward", 0.0),
-            "detection": metrics.get("detection", 0.0),
-            "lab_workup": metrics.get("lab_workup", 0.0),
-            "treatment": metrics.get("treatment", 0.0),
-            "timeliness": metrics.get("timeliness", 0.0),
-            "stability": metrics.get("stability", 0.0),
-            "safety": metrics.get("safety", 0.0),
-            "safety_violation_rate": metrics.get("safety_violation_rate", 0.0),
+            "avg_reward": normalize_task_score(metrics.get("avg_reward", 0.5)),
+            "detection": normalize_task_score(metrics.get("detection", 0.5)),
+            "lab_workup": normalize_task_score(metrics.get("lab_workup", 0.5)),
+            "treatment": normalize_task_score(metrics.get("treatment", 0.5)),
+            "timeliness": normalize_task_score(metrics.get("timeliness", 0.5)),
+            "stability": normalize_task_score(metrics.get("stability", 0.5)),
+            "safety": normalize_task_score(metrics.get("safety", 0.5)),
+            "safety_violation_rate": normalize_task_score(metrics.get("safety_violation_rate", 0.5)),
             "safety_violations": metrics.get("safety_violations", 0),
-            "outcome": metrics.get("outcome", 0.0),
+            "outcome": normalize_task_score(metrics.get("outcome", 0.5)),
             "steps": metrics.get("steps", state.step_count),
             "episode_index": episode_index,
             "policy_mode": policy_mode,
@@ -723,16 +726,16 @@ def run_task(
             "task_id": task_id,
             "episode_id": getattr(state, 'episode_id', 'unknown'),
             "score": normalize_task_score(0.0),
-            "avg_reward": 0.0,
-            "detection": 0.0,
-            "lab_workup": 0.0,
-            "treatment": 0.0,
-            "timeliness": 0.0,
-            "stability": 0.0,
-            "safety": 0.0,
-            "safety_violation_rate": 0.0,
+            "avg_reward": normalize_task_score(0.0),
+            "detection": normalize_task_score(0.0),
+            "lab_workup": normalize_task_score(0.0),
+            "treatment": normalize_task_score(0.0),
+            "timeliness": normalize_task_score(0.0),
+            "stability": normalize_task_score(0.0),
+            "safety": normalize_task_score(0.0),
+            "safety_violation_rate": normalize_task_score(0.0),
             "safety_violations": 0,
-            "outcome": 0.0,
+            "outcome": normalize_task_score(0.0),
             "steps": step_count,
             "episode_index": episode_index,
             "policy_mode": policy_mode,
@@ -743,12 +746,12 @@ def run_task(
             "total_reward": 0.0,
             "reward_count": 0,
             "positive_rewards_count": 0,
-            "reward_density": 0.0,
+            "reward_density": normalize_task_score(0.0),
             "avg_reward_per_step": 0.0,
             "reward_variance": 0.0,
             "max_single_reward": 0.0,
-            "episode_length_efficiency": 0.0,
-            "positive_reward_ratio": 0.0,
+            "episode_length_efficiency": normalize_task_score(0.0),
+            "positive_reward_ratio": normalize_task_score(0.0),
             "unique_actions": 0,
             "action_entropy": 0.0,
         }
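One behavioral consequence of this fallback block: a crashed episode no longer reports literal zeros for score-like fields, since every one of them now passes through the clamp. Under the assumed epsilon clamp:

```python
normalize_task_score(0.0)  # -> 0.001, never exactly 0.0
```

Downstream consumers that tested `score == 0.0` to detect failures would no longer match; the plain counters such as `safety_violations` and `unique_actions` are left unclamped and still report true zeros.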
@@ -776,20 +779,24 @@ def summarize_runs(
     return {
         "results": all_results,
         "episode_summaries": per_episode_results,
-        "mean_score": np.mean([item.get("score", 0.0) for item in all_results]),
-        "score_std": np.std([item.get("score", 0.0) for item in all_results]),
-        "mean_score_std": np.std([item.get("mean_score", 0.0) for item in per_episode_results])
+        "mean_score": normalize_task_score(np.mean([item.get("score", 0.5) for item in all_results])),
+        "score_std": normalize_task_score(np.std([item.get("score", 0.5) for item in all_results])),
+        "mean_score_std": normalize_task_score(np.std([item.get("mean_score", 0.5) for item in per_episode_results]))
         if per_episode_results
-        else 0.0,
-        "mean_reward_density": np.mean([item.get("reward_density", 0.0) for item in all_results]),
-        "global_reward_density": total_positive_rewards / total_reward_count
+        else normalize_task_score(0.0),
+        "mean_reward_density": normalize_task_score(np.mean([item.get("reward_density", 0.5) for item in all_results])),
+        "global_reward_density": normalize_task_score(total_positive_rewards / total_reward_count)
         if total_reward_count
-        else 0.0,
+        else normalize_task_score(0.0),
         "mean_avg_reward_per_step": round(float(np.mean([item.get("avg_reward_per_step", 0.0) for item in all_results])), 4),
         "mean_reward_variance": round(float(np.mean([item.get("reward_variance", 0.0) for item in all_results])), 4),
-        "mean_positive_reward_ratio": round(float(np.mean([item.get("positive_reward_ratio", 0.0) for item in all_results])), 4),
+        "mean_positive_reward_ratio": normalize_task_score(
+            np.mean([item.get("positive_reward_ratio", 0.5) for item in all_results])
+        ),
         "mean_action_entropy": round(float(np.mean([item.get("action_entropy", 0.0) for item in all_results])), 4),
-        "safety_violation_rate": total_safety_violations / max(1, total_steps),
+        "safety_violation_rate": normalize_task_score(total_safety_violations / total_steps)
+        if total_steps
+        else normalize_task_score(0.0),
         "total_runs": len(all_results),
         "episodes": len(per_episode_results),
         "requested_policy": requested_policy,
@@ -869,11 +876,13 @@ def main() -> None:
         episode_summaries.append(
             {
                 "episode_index": episode_index,
-                "mean_score": np.mean([item.get("score", 0.0) for item in episode_results]),
-                "mean_reward_density": np.mean([item.get("reward_density", 0.0) for item in episode_results]),
-                "safety_violation_rate": episode_safety_violations / episode_steps
+                "mean_score": normalize_task_score(np.mean([item.get("score", 0.5) for item in episode_results])),
+                "mean_reward_density": normalize_task_score(
+                    np.mean([item.get("reward_density", 0.5) for item in episode_results])
+                ),
+                "safety_violation_rate": normalize_task_score(episode_safety_violations / episode_steps)
                 if episode_steps
-                else 0.0,
+                else normalize_task_score(0.0),
             }
         )
     except Exception as exc:
|