# hackathon/training/rollout_collection.py
# Ev3Dev's picture
# Upload folder using huggingface_hub
# 5c3cfae verified
"""Collect trajectories with direct OpenEnv environment access."""
from __future__ import annotations
import argparse
import random
import uuid
from pathlib import Path
from typing import Dict, List, Sequence
from models import ActionType, ExperimentAction
from server.hackathon_environment import BioExperimentEnvironment
from training.evaluation import EvaluationSuite
from training.trajectory import Trajectory, TrajectoryDataset
# Scripted action order used by the "heuristic" rollout policy:
# heuristic_next_action() returns the first entry of this list that does not
# yet appear in the episode's pipeline history.
# NOTE(review): DIFFERENTIAL_EXPRESSION and VALIDATE_MARKER are handled by
# build_experiment_action() but are not part of this sequence — confirm that
# the omission is intentional.
HEURISTIC_SEQUENCE = [
    ActionType.COLLECT_SAMPLE,
    ActionType.PREPARE_LIBRARY,
    ActionType.SEQUENCE_CELLS,
    ActionType.RUN_QC,
    ActionType.FILTER_DATA,
    ActionType.NORMALIZE_DATA,
    ActionType.CLUSTER_CELLS,
    ActionType.TRAJECTORY_ANALYSIS,
    ActionType.MARKER_SELECTION,
    ActionType.SYNTHESIZE_CONCLUSION,
]
def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the rollout collection script.

    Args:
        argv: Argument strings to parse. When ``None`` (the default) argparse
            falls back to ``sys.argv[1:]``, so existing ``parse_args()`` calls
            keep working; passing an explicit list allows programmatic use
            and testing without touching the process arguments.

    Returns:
        Namespace with ``episodes``, ``policy``, ``max_steps``,
        ``output_dir`` and ``seed`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Run rollout episodes and persist trajectories."
    )
    parser.add_argument("--episodes", type=int, default=10, help="Number of episodes.")
    parser.add_argument(
        "--policy",
        choices=["random", "heuristic"],
        default="heuristic",
        help="Policy to use for rollouts.",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=None,
        help="Optional hard cutoff per episode (defaults to env limit).",
    )
    parser.add_argument(
        "--output-dir",
        default="training/rollouts",
        help="Directory for JSON trajectory outputs.",
    )
    parser.add_argument("--seed", type=int, default=None, help="RNG seed.")
    return parser.parse_args(argv)
def heuristic_next_action(history: Sequence[ActionType], step_index: int) -> ActionType:
    """Return the next action for the scripted heuristic policy.

    The policy walks ``HEURISTIC_SEQUENCE`` in order and returns the first
    action that is not yet present in ``history``. Once every scripted action
    has been executed it requests one marker validation (only when
    ``step_index`` is at least 2) and otherwise keeps returning
    ``SYNTHESIZE_CONCLUSION``.

    Args:
        history: Action types already executed in the current episode.
        step_index: Index of the step about to be taken.

    Returns:
        The action type the heuristic policy should execute next.
    """
    seen = set(history)
    for action in HEURISTIC_SEQUENCE:
        if action not in seen:
            return action
    # NOTE(review): SYNTHESIZE_CONCLUSION is the last scripted entry, so this
    # branch is only reachable after a conclusion has already been issued —
    # confirm the environment keeps the episode open at that point.
    if step_index >= 2 and ActionType.VALIDATE_MARKER not in seen:
        return ActionType.VALIDATE_MARKER
    # Nothing left to do: keep concluding. (The original had two identical
    # return branches here; they are collapsed into one.)
    return ActionType.SYNTHESIZE_CONCLUSION
def pick_action(policy: str, step_index: int, history: Sequence[ActionType]) -> ActionType:
    """Select the next action type according to the named policy.

    ``"random"`` draws uniformly from every ``ActionType`` member using the
    module-level ``random`` generator; any other policy name is delegated to
    ``heuristic_next_action``.
    """
    if policy != "random":
        return heuristic_next_action(history, step_index)
    candidates = list(ActionType)
    return random.choice(candidates)
def default_comparison_name(conditions: Sequence[str]) -> str:
    """Choose a default comparison label for a set of experimental conditions.

    Matching is case-insensitive:

    * both ``"healthy"`` and ``"ipf"`` present -> ``"IPF_vs_healthy"``
    * a treated condition and an untreated condition present ->
      ``"treated_vs_untreated"``
    * anything else -> ``"disease_vs_healthy"``

    Args:
        conditions: Condition names of the current task.

    Returns:
        The comparison label to use for differential expression.
    """
    normalized = {condition.lower() for condition in conditions}
    if {"healthy", "ipf"} <= normalized:
        return "IPF_vs_healthy"

    has_untreated = any("untreated" in condition for condition in normalized)
    # "untreated" contains "treated" as a substring, so a plain substring test
    # would report a treated arm even when only untreated samples exist.
    # Strip the "untreated" occurrences before looking for "treated".
    has_treated = any(
        "treated" in condition.replace("untreated", "") for condition in normalized
    )
    if has_treated and has_untreated:
        return "treated_vs_untreated"

    # Every remaining case, with or without a "healthy" condition, maps to
    # the generic label.
    return "disease_vs_healthy"
def build_experiment_action(
    action_type: ActionType,
    discovered_markers: Sequence[str],
    conditions: Sequence[str],
) -> ExperimentAction:
    """Wrap an action type into a fully populated ``ExperimentAction``.

    Args:
        action_type: The pipeline step to perform.
        discovered_markers: Markers found so far; the first one is used as the
            validation target, with ``"SPP1"`` as the fallback when empty.
        conditions: Task conditions, used to derive the comparison name for
            differential expression.

    Returns:
        An ``ExperimentAction`` with a fixed confidence of 0.75, the default
        method for the step (``None`` when the step has none) and its
        default parameters (empty when the step has none).
    """
    # Default tool/protocol per step; steps not listed run with method=None.
    method_by_action = {
        ActionType.PREPARE_LIBRARY: "10x_chromium",
        ActionType.RUN_QC: "scanpy.pp.calculate_qc_metrics",
        ActionType.FILTER_DATA: "scanpy.pp.filter_cells",
        ActionType.NORMALIZE_DATA: "scanpy.pp.normalize_total",
        ActionType.CLUSTER_CELLS: "scanpy.tl.leiden",
        ActionType.DIFFERENTIAL_EXPRESSION: "scanpy.tl.rank_genes_groups",
        ActionType.TRAJECTORY_ANALYSIS: "scanpy.tl.dpt",
        ActionType.MARKER_SELECTION: "scanpy.tl.rank_genes_groups",
        ActionType.VALIDATE_MARKER: "qPCR",
    }
    chosen_method = method_by_action.get(action_type)

    # Built fresh on every call so callers never share a mutable dict/list.
    step_parameters: Dict[str, object] = {}
    if action_type == ActionType.COLLECT_SAMPLE:
        step_parameters["n_samples"] = 6
    elif action_type == ActionType.DIFFERENTIAL_EXPRESSION:
        step_parameters["comparison"] = default_comparison_name(conditions)
    elif action_type == ActionType.VALIDATE_MARKER:
        target_marker = discovered_markers[0] if discovered_markers else "SPP1"
        step_parameters["marker"] = target_marker
    elif action_type == ActionType.SYNTHESIZE_CONCLUSION:
        step_parameters["claims"] = []

    return ExperimentAction(
        action_type=action_type,
        method=chosen_method,
        parameters=step_parameters,
        confidence=0.75,
    )
def run_episode(
    env: BioExperimentEnvironment,
    episode_id: str,
    policy: str,
    max_steps: int | None = None,
) -> Trajectory:
    """Roll out a single episode and record it as a ``Trajectory``.

    The environment is reset, then stepped with actions chosen by ``policy``
    until the observation reports ``done`` or ``max_steps`` steps have been
    taken (no cap when ``max_steps`` is ``None``). Each step is appended to
    the trajectory and summarised on stdout.

    Args:
        env: Environment instance to reset and step.
        episode_id: Identifier stored on the trajectory.
        policy: Policy name forwarded to ``pick_action``.
        max_steps: Optional hard cap on the number of steps.

    Returns:
        The trajectory containing every executed step.
    """
    obs = env.reset()
    trajectory = Trajectory(
        episode_id=episode_id,
        task=obs.task.model_dump(),
        metadata={
            "task_problem": obs.task.problem_statement,
            "policy": policy,
        },
    )

    steps_taken = 0
    finished = obs.done
    # Stop as soon as the env says so, or once the optional step cap is hit.
    while not finished and (max_steps is None or steps_taken < max_steps):
        executed_so_far = [record.action_type for record in obs.pipeline_history]
        chosen_type = pick_action(policy, steps_taken, executed_so_far)
        action = build_experiment_action(
            action_type=chosen_type,
            discovered_markers=obs.discovered_markers,
            conditions=obs.task.conditions,
        )

        obs = env.step(action)
        step_reward = obs.reward
        finished = obs.done
        steps_taken += 1

        trajectory.add_step(
            action=action,
            observation=obs,
            reward=step_reward,
            done=finished,
            reward_breakdown=obs.step_reward_breakdown,
        )
        print(
            f" step={obs.step_index:02d} "
            f"action={chosen_type.value:>28} "
            f"reward={step_reward:+.3f}"
        )

    return trajectory
def main() -> None:
    """Command-line entry point: collect rollouts, save them, print metrics.

    Parses the CLI options, optionally seeds the ``random`` module, runs the
    requested number of episodes against one ``BioExperimentEnvironment``,
    writes each trajectory to ``<output-dir>/<episode_id>.json`` and finally
    prints the online metrics and the dataset summary.
    """
    args = parse_args()
    if args.seed is not None:
        random.seed(args.seed)

    output_root = Path(args.output_dir)
    output_root.mkdir(parents=True, exist_ok=True)

    environment = BioExperimentEnvironment()
    collected: List[Trajectory] = []
    print(
        f"Starting rollout collection: episodes={args.episodes}, policy={args.policy}"
    )

    for episode_number in range(1, args.episodes + 1):
        print(f"Episode {episode_number}/{args.episodes}")
        episode = run_episode(
            env=environment,
            episode_id=str(uuid.uuid4()),
            policy=args.policy,
            max_steps=args.max_steps,
        )
        # Persist immediately so finished episodes survive a later failure.
        episode.save(output_root / f"{episode.episode_id}.json")
        collected.append(episode)

    dataset = TrajectoryDataset(collected)
    metrics = EvaluationSuite.online_metrics(collected)

    print("\nRun complete.")
    print(f"Saved trajectories to: {output_root}")
    print("Online metrics:")
    for entry in metrics:
        print(f" - {entry.name}: {entry.value:.4f}")
    print(f"Summary: {dataset.summary()}")
# Run the rollout collection only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()