| """Collect trajectories with direct OpenEnv environment access.""" | |
| from __future__ import annotations | |
| import argparse | |
| import random | |
| import uuid | |
| from pathlib import Path | |
| from typing import Dict, List, Sequence | |
| from models import ActionType, ExperimentAction | |
| from server.hackathon_environment import BioExperimentEnvironment | |
| from training.evaluation import EvaluationSuite | |
| from training.trajectory import Trajectory, TrajectoryDataset | |
# Canonical ordering of the single-cell analysis pipeline. The heuristic
# policy walks this list and issues the first action not yet present in the
# episode's action history (see heuristic_next_action).
HEURISTIC_SEQUENCE = [
    ActionType.COLLECT_SAMPLE,
    ActionType.PREPARE_LIBRARY,
    ActionType.SEQUENCE_CELLS,
    ActionType.RUN_QC,
    ActionType.FILTER_DATA,
    ActionType.NORMALIZE_DATA,
    ActionType.CLUSTER_CELLS,
    ActionType.TRAJECTORY_ANALYSIS,
    ActionType.MARKER_SELECTION,
    ActionType.SYNTHESIZE_CONCLUSION,
]
def parse_args() -> argparse.Namespace:
    """Parse and return the command-line options for the rollout collector."""
    ap = argparse.ArgumentParser(
        description="Run rollout episodes and persist trajectories."
    )
    ap.add_argument("--episodes", type=int, default=10, help="Number of episodes.")
    ap.add_argument(
        "--policy",
        choices=["random", "heuristic"],
        default="heuristic",
        help="Policy to use for rollouts.",
    )
    ap.add_argument(
        "--max-steps",
        type=int,
        default=None,
        help="Optional hard cutoff per episode (defaults to env limit).",
    )
    ap.add_argument(
        "--output-dir",
        default="training/rollouts",
        help="Directory for JSON trajectory outputs.",
    )
    ap.add_argument("--seed", type=int, default=None, help="RNG seed.")
    return ap.parse_args()
def heuristic_next_action(history: Sequence[ActionType], step_index: int) -> ActionType:
    """Return the next action for the heuristic policy.

    Walks HEURISTIC_SEQUENCE and returns the first action not already taken.
    Once the full sequence has been exhausted, validates a marker (if not yet
    done), then keeps synthesizing a conclusion until the episode ends.

    Args:
        history: Action types already taken this episode.
        step_index: Zero-based index of the current step.

    Returns:
        The next ActionType to execute.
    """
    seen = set(history)
    # Advance through the canonical pipeline in order.
    for action in HEURISTIC_SEQUENCE:
        if action not in seen:
            return action
    # Pipeline complete: validate one marker before wrapping up. The
    # step_index guard defends against a sparse/truncated history.
    if step_index >= 2 and ActionType.VALIDATE_MARKER not in seen:
        return ActionType.VALIDATE_MARKER
    # Nothing left to do — keep concluding until the environment terminates.
    # (The original had a redundant `if SYNTHESIZE_CONCLUSION in seen` branch
    # whose body returned the same value as the fall-through; removed.)
    return ActionType.SYNTHESIZE_CONCLUSION
def pick_action(policy: str, step_index: int, history: Sequence[ActionType]) -> ActionType:
    """Select the next action under *policy* ("random" or "heuristic")."""
    if policy != "random":
        # Any non-random policy falls back to the deterministic heuristic.
        return heuristic_next_action(history, step_index)
    return random.choice(list(ActionType))
def default_comparison_name(conditions: Sequence[str]) -> str:
    """Infer a differential-expression comparison label from condition names.

    Args:
        conditions: Experiment condition names (case-insensitive).

    Returns:
        "IPF_vs_healthy" when both "healthy" and "ipf" are present,
        "treated_vs_untreated" when both treated and untreated conditions
        appear, otherwise the generic "disease_vs_healthy".
    """
    normalized = {condition.lower() for condition in conditions}
    if {"healthy", "ipf"} <= normalized:
        return "IPF_vs_healthy"
    has_treated = any("treated" in condition for condition in normalized)
    has_untreated = any("untreated" in condition for condition in normalized)
    if has_treated and has_untreated:
        return "treated_vs_untreated"
    # The original had a redundant `any("healthy" ...)` branch that returned
    # the same value as the fall-through; collapsed into one default return.
    return "disease_vs_healthy"
def build_experiment_action(
    action_type: ActionType,
    discovered_markers: Sequence[str],
    conditions: Sequence[str],
) -> ExperimentAction:
    """Construct a fully parameterized ExperimentAction for *action_type*.

    Fills in a sensible default method and parameter set per action; marker
    validation targets the first discovered marker (or "SPP1" as a fallback).
    """
    # Static method choice per action type; absent entries mean no method.
    method_by_action = {
        ActionType.PREPARE_LIBRARY: "10x_chromium",
        ActionType.RUN_QC: "scanpy.pp.calculate_qc_metrics",
        ActionType.FILTER_DATA: "scanpy.pp.filter_cells",
        ActionType.NORMALIZE_DATA: "scanpy.pp.normalize_total",
        ActionType.CLUSTER_CELLS: "scanpy.tl.leiden",
        ActionType.DIFFERENTIAL_EXPRESSION: "scanpy.tl.rank_genes_groups",
        ActionType.TRAJECTORY_ANALYSIS: "scanpy.tl.dpt",
        ActionType.MARKER_SELECTION: "scanpy.tl.rank_genes_groups",
        ActionType.VALIDATE_MARKER: "qPCR",
    }
    method = method_by_action.get(action_type)

    # Parameters that depend on episode state are built explicitly.
    parameters: Dict[str, object] = {}
    if action_type == ActionType.COLLECT_SAMPLE:
        parameters = {"n_samples": 6}
    elif action_type == ActionType.DIFFERENTIAL_EXPRESSION:
        parameters = {"comparison": default_comparison_name(conditions)}
    elif action_type == ActionType.VALIDATE_MARKER:
        parameters = {"marker": discovered_markers[0] if discovered_markers else "SPP1"}
    elif action_type == ActionType.SYNTHESIZE_CONCLUSION:
        parameters = {"claims": []}

    return ExperimentAction(
        action_type=action_type,
        method=method,
        parameters=parameters,
        confidence=0.75,
    )
def run_episode(
    env: BioExperimentEnvironment,
    episode_id: str,
    policy: str,
    max_steps: int | None = None,
) -> Trajectory:
    """Roll out a single episode against *env* and record it as a Trajectory.

    Args:
        env: Environment to interact with (reset here, then stepped).
        episode_id: Identifier stored on the resulting trajectory.
        policy: Policy name forwarded to pick_action.
        max_steps: Optional hard cutoff; None defers to the env's own limit.

    Returns:
        The populated Trajectory for this episode.
    """
    obs = env.reset()
    trajectory = Trajectory(
        episode_id=episode_id,
        task=obs.task.model_dump(),
        metadata={
            "task_problem": obs.task.problem_statement,
            "policy": policy,
        },
    )
    steps_taken = 0
    while not obs.done:
        # Honor the optional per-episode cutoff before acting again.
        if max_steps is not None and steps_taken >= max_steps:
            break
        taken_types = [rec.action_type for rec in obs.pipeline_history]
        chosen = pick_action(policy, steps_taken, taken_types)
        action = build_experiment_action(
            action_type=chosen,
            discovered_markers=obs.discovered_markers,
            conditions=obs.task.conditions,
        )
        obs = env.step(action)
        steps_taken += 1
        trajectory.add_step(
            action=action,
            observation=obs,
            reward=obs.reward,
            done=obs.done,
            reward_breakdown=obs.step_reward_breakdown,
        )
        print(
            f" step={obs.step_index:02d} "
            f"action={chosen.value:>28} "
            f"reward={obs.reward:+.3f}"
        )
    return trajectory
def main() -> None:
    """CLI entry point: collect rollout episodes, persist them, print metrics."""
    args = parse_args()
    if args.seed is not None:
        random.seed(args.seed)

    output_path = Path(args.output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    environment = BioExperimentEnvironment()
    collected: List[Trajectory] = []
    print(
        f"Starting rollout collection: episodes={args.episodes}, policy={args.policy}"
    )
    for episode_index in range(args.episodes):
        print(f"Episode {episode_index + 1}/{args.episodes}")
        trajectory = run_episode(
            env=environment,
            episode_id=str(uuid.uuid4()),
            policy=args.policy,
            max_steps=args.max_steps,
        )
        # Persist each episode immediately so a crash loses at most one.
        trajectory.save(output_path / f"{trajectory.episode_id}.json")
        collected.append(trajectory)

    dataset = TrajectoryDataset(collected)
    metrics = EvaluationSuite.online_metrics(collected)
    print("\nRun complete.")
    print(f"Saved trajectories to: {output_path}")
    print("Online metrics:")
    for metric in metrics:
        print(f" - {metric.name}: {metric.value:.4f}")
    print(f"Summary: {dataset.summary()}")
| if __name__ == "__main__": | |
| main() | |