| | """Minimal Colab entrypoint for Unsloth GRPO against a remote OpenEnv Space.
|
| |
|
| | This keeps the repo's prompt formatting and action parsing logic, but builds
|
| | prompt states by interacting with a deployed OpenEnv Hugging Face Space instead
|
| | of the local in-process environment. That makes the Colab workflow match the
|
| | remote environment users actually want to train against.
|
| | """
|
| |
|
| | from __future__ import annotations
|
| |
|
| | import argparse
|
| | import json
|
| | import random
|
| | from typing import Any, Dict, List, Optional, Sequence
|
| |
|
| | from client import BioExperimentEnv
|
| | import training_script as base
|
| |
|
# Default 4-bit quantized Llama 3.2 3B Instruct checkpoint published by Unsloth.
DEFAULT_MODEL_ID = "unsloth/Llama-3.2-3B-Instruct-bnb-4bit"
# Where trained adapters, tokenizer files, and training plots are written.
DEFAULT_OUTPUT_DIR = "artifacts/grpo-unsloth-llama32-3b-space"
# Hugging Face Space repo hosting the remote OpenEnv server to train against.
DEFAULT_SPACE_REPO_ID = "Ev3Dev/hackathon"
|
| |
|
| |
|
def hf_space_repo_to_base_url(repo_id: str) -> str:
    """Convert `owner/space-name` to the standard `hf.space` URL.

    Each segment is trimmed, lowercased, and has underscores replaced with
    hyphens, matching how Hugging Face derives Space subdomains.
    """
    segments = [
        part.strip().lower().replace("_", "-")
        for part in repo_id.split("/", 1)
    ]
    owner, space = segments
    return "https://{0}-{1}.hf.space".format(owner, space)
|
| |
|
| |
|
def require_unsloth_base():
    """Lazily import the Unsloth-dependent training helpers.

    Keeping these imports inside the function lets the rest of this module
    (argument parsing, dry runs) work without Unsloth installed.  `unsloth`
    is imported before `training_unsloth` — presumably for its import-time
    runtime patches; verify before reordering.
    """
    import unsloth  # noqa: F401 -- imported for side effects, not used directly
    import training_unsloth as unsloth_base

    return unsloth_base
|
| |
|
| |
|
def build_argument_parser() -> argparse.ArgumentParser:
    """Build the CLI parser for remote-Space GRPO training.

    Defaults target a small Colab-friendly run; LoRA hyperparameter defaults
    come from `unsloth_defaults()`.

    Returns:
        A fully configured `argparse.ArgumentParser`.
    """
    # Hoisted: the original called unsloth_defaults() once per LoRA option.
    lora_defaults = unsloth_defaults()

    parser = argparse.ArgumentParser(
        description="Train Unsloth Llama 3.2 3B on a remote OpenEnv Hugging Face Space."
    )
    parser.add_argument("--model-id", default=DEFAULT_MODEL_ID)
    parser.add_argument("--output-dir", default=DEFAULT_OUTPUT_DIR)
    # Dataset collection: how many remote episodes, and how many steps each.
    parser.add_argument("--dataset-episodes", type=int, default=8)
    parser.add_argument("--rollout-steps", type=int, default=6)
    parser.add_argument(
        "--collection-policy",
        choices=["random", "heuristic"],
        default="heuristic",
    )
    # Empty --base-url means "derive it from --space-repo-id" (see parse_args).
    parser.add_argument("--base-url", default="")
    parser.add_argument(
        "--space-repo-id",
        default=DEFAULT_SPACE_REPO_ID,
        help="Hugging Face Space repo id, for example `Ev3Dev/hackathon`.",
    )
    # GRPO / sequence-length knobs.
    parser.add_argument("--num-generations", type=int, default=2)
    parser.add_argument("--max-completion-length", type=int, default=160)
    parser.add_argument("--max-prompt-length", type=int, default=1280)
    parser.add_argument("--max-seq-length", type=int, default=2048)
    # Optimizer / schedule knobs.
    parser.add_argument("--per-device-train-batch-size", type=int, default=1)
    parser.add_argument("--gradient-accumulation-steps", type=int, default=4)
    parser.add_argument("--learning-rate", type=float, default=5e-6)
    parser.add_argument("--num-train-epochs", type=float, default=1.0)
    parser.add_argument("--logging-steps", type=int, default=1)
    parser.add_argument("--save-steps", type=int, default=25)
    parser.add_argument("--plot-metric-key", default=None)
    parser.add_argument("--seed", type=int, default=42)
    # Early-exit / loading modes.
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--load-model-only", action="store_true")
    parser.add_argument("--trust-remote-code", action="store_true")
    parser.add_argument("--disable-4bit", action="store_true")
    # LoRA hyperparameters.
    parser.add_argument("--lora-r", type=int, default=lora_defaults["lora_r"])
    parser.add_argument(
        "--lora-alpha", type=int, default=lora_defaults["lora_alpha"]
    )
    parser.add_argument(
        "--lora-dropout", type=float, default=lora_defaults["lora_dropout"]
    )
    return parser
|
| |
|
| |
|
def unsloth_defaults() -> Dict[str, float]:
    """Return the default LoRA hyperparameters used for Unsloth fine-tuning."""
    defaults: Dict[str, float] = dict(
        lora_r=16,
        lora_alpha=16,
        lora_dropout=0.0,
    )
    return defaults
|
| |
|
| |
|
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse CLI args, deriving --base-url from --space-repo-id when omitted."""
    namespace = build_argument_parser().parse_args(argv)
    # An empty base URL means the caller wants it computed from the repo id.
    namespace.base_url = namespace.base_url or hf_space_repo_to_base_url(
        namespace.space_repo_id
    )
    return namespace
|
| |
|
| |
|
def make_training_args(**overrides: Any) -> argparse.Namespace:
    """Build a training-args Namespace programmatically (e.g. from a notebook).

    Starts from the CLI defaults, applies `overrides`, and derives `base_url`
    from `space_repo_id` when not supplied.

    Raises:
        ValueError: if any override key is not a recognized training argument.
    """
    defaults = vars(build_argument_parser().parse_args([]))
    unexpected = sorted(key for key in overrides if key not in defaults)
    if unexpected:
        raise ValueError(f"Unknown training args: {', '.join(unexpected)}")

    merged = {**defaults, **overrides}
    args = argparse.Namespace(**merged)
    if not getattr(args, "base_url", ""):
        args.base_url = hf_space_repo_to_base_url(args.space_repo_id)
    return args
|
| |
|
| |
|
def build_remote_prompt_examples(args: argparse.Namespace) -> List[Dict[str, str]]:
    """Collect prompt states directly from the remote OpenEnv server.

    Runs `args.dataset_episodes` episodes against the Space at `args.base_url`,
    choosing actions with `args.collection_policy`, and records one training
    example per visited state: the rendered prompt, the JSON-serialized action
    history so far, and the policy's reference action for that state.
    """
    rng = random.Random(args.seed)
    examples: List[Dict[str, str]] = []

    for _episode_idx in range(args.dataset_episodes):
        # One fresh client (and thus server-side episode) per collected episode.
        with BioExperimentEnv(base_url=args.base_url) as env:
            result = env.reset()
            obs = result.observation
            history_actions: List[base.ExperimentAction] = []

            for step_idx in range(args.rollout_steps):
                if obs.done:
                    break

                # Pick the next action from the collection policy, constrained
                # to what the current observation exposes.
                next_action = base.build_experiment_action(
                    action_type=base.pick_action(
                        args.collection_policy,
                        step_idx,
                        [action.action_type for action in history_actions],
                    ),
                    discovered_markers=obs.discovered_markers,
                    candidate_mechanisms=obs.candidate_mechanisms,
                    conditions=obs.task.conditions,
                )
                examples.append(
                    {
                        "prompt": base.build_training_prompt(obs),
                        # Serialized so the reward function can replay the same
                        # prefix on a fresh episode before scoring a candidate.
                        "history_actions": json.dumps(
                            [action.model_dump() for action in history_actions]
                        ),
                        "reference_action": base.action_completion_json(next_action),
                        "problem_statement": obs.task.problem_statement,
                        # Random tag (not derived from episode index) — only
                        # used to label examples, seeded via args.seed.
                        "episode_tag": f"remote-{rng.randrange(10**9):09d}",
                    }
                )

                history_actions.append(next_action)
                result = env.step(next_action)
                obs = result.observation
                if result.done:
                    break

    return examples
|
| |
|
| |
|
class RemoteSpaceReward:
    """Reward function that replays each candidate against the remote Space.

    For each completion, a fresh remote episode is started, the serialized
    history prefix is replayed, and the parsed candidate action is executed;
    the environment's reward becomes the score.  Unparseable completions and
    remote failures map to fixed penalties so GRPO always gets a number.
    """

    def __init__(
        self,
        *,
        base_url: str,
        invalid_action_penalty: float = base.INVALID_ACTION_PENALTY,
        environment_error_penalty: float = base.ENVIRONMENT_ERROR_PENALTY,
    ) -> None:
        # Trainers introspect `__name__` on reward callables for logging.
        self.__name__ = "remote_space_reward"
        self.base_url = base_url
        self.invalid_action_penalty = invalid_action_penalty
        self.environment_error_penalty = environment_error_penalty

    def __call__(
        self,
        completions: List[Any],
        history_actions: Optional[List[str]] = None,
        **_: Any,
    ) -> List[float]:
        """Score a batch of completions; returns one float per completion."""
        history_columns = base.normalise_column(history_actions, len(completions))
        rewards: List[float] = []

        for completion, current_history in zip(completions, history_columns):
            action = base.parse_action_completion(base.completion_to_text(completion))
            if action is None:
                # Unparseable model output: penalize instead of crashing.
                rewards.append(self.invalid_action_penalty)
                continue

            try:
                rewards.append(self._score_remote(action, current_history))
            except Exception:
                # Deliberate broad catch: a flaky Space / network error becomes
                # a fixed penalty rather than aborting the training step.
                rewards.append(self.environment_error_penalty)

        return rewards

    @staticmethod
    def _coalesce_reward(*candidates: Optional[float]) -> float:
        """Return the first candidate reward that is not None, else 0.0.

        Uses explicit `is not None` checks so a legitimate reward of 0.0 is
        not skipped.  The previous `a or b or 0.0` pattern treated 0.0 as
        missing and could fall through to a stale observation reward, and the
        final `float(result.observation.reward)` raised TypeError on None,
        which __call__ silently converted into the error penalty.
        """
        for value in candidates:
            if value is not None:
                return float(value)
        return 0.0

    def _score_remote(
        self,
        action: base.ExperimentAction,
        history_actions: Optional[str],
    ) -> float:
        """Replay the history prefix, then execute `action` and return its reward."""
        with BioExperimentEnv(base_url=self.base_url) as env:
            result = env.reset()
            obs = result.observation

            for previous_action in base.decode_history_actions(history_actions):
                result = env.step(previous_action)
                obs = result.observation
                if result.done:
                    # Episode ended during replay: the candidate cannot be
                    # applied, so score with the terminal reward instead.
                    return self._coalesce_reward(result.reward, obs.reward)

            action = base.ensure_conclusion_claims(obs, action)
            result = env.step(action)
            # Prefer the step reward; fall back to the observation's reward.
            return self._coalesce_reward(result.reward, result.observation.reward)
|
| |
|
| |
|
def run_dry_run_preview(
    examples: Sequence[Dict[str, str]],
    reward_fn: RemoteSpaceReward,
    output_dir: str,
    base_url: str,
) -> None:
    """Print a dataset summary and score the first reference action remotely.

    Raises:
        ValueError: if no examples were collected.
    """
    if not examples:
        raise ValueError("No training prompts were generated for the dry run.")

    first = examples[0]
    # Score the reference action exactly as the trainer would score a rollout.
    completion = [{"role": "assistant", "content": first["reference_action"]}]
    scores = reward_fn(
        completions=[completion],
        history_actions=[first["history_actions"]],
    )
    sample_reward = scores[0]

    summary = [
        f"Built {len(examples)} remote prompt states.",
        f"Remote OpenEnv Space: {base_url}",
        f"Output directory: {output_dir}",
        f"Sample reward for reference action: {sample_reward:+.3f}",
        "\nSample prompt:\n",
        first["prompt"],
    ]
    for line in summary:
        print(line)
|
| |
|
| |
|
def run_training(args: argparse.Namespace) -> Dict[str, Any]:
    """End-to-end GRPO training against the remote OpenEnv Space.

    Honors two early-exit modes: `--load-model-only` (return model/tokenizer
    without touching the Space) and `--dry-run` (collect prompts and preview a
    sample reward without loading the model).  Returns a dict of artifacts;
    the exact keys depend on which path was taken.
    """
    random.seed(args.seed)
    runtime = base.resolve_torch_runtime()
    unsloth_base = require_unsloth_base()

    if args.load_model_only:
        # Inference-ready load; skips dataset collection and training entirely.
        tokenizer, model = unsloth_base.load_model_artifacts(
            args.model_id,
            trust_remote_code=args.trust_remote_code,
            max_seq_length=args.max_seq_length,
            load_in_4bit=not args.disable_4bit,
            fast_inference=False,
            prepare_for_inference=True,
        )
        return {
            "args": args,
            "runtime": runtime,
            "tokenizer": tokenizer,
            "model": model,
        }

    examples = build_remote_prompt_examples(args)
    reward_fn = RemoteSpaceReward(base_url=args.base_url)

    if args.dry_run:
        # Preview dataset + one remote reward without loading the model.
        run_dry_run_preview(examples, reward_fn, args.output_dir, args.base_url)
        return {
            "args": args,
            "runtime": runtime,
            "examples": examples,
            "reward_fn": reward_fn,
        }

    # Imported lazily so dry runs don't require the full training stack.
    from datasets import Dataset

    FastLanguageModel = unsloth_base.patch_unsloth_grpo()
    train_dataset = Dataset.from_list(examples)

    tokenizer, model = unsloth_base.load_model_artifacts(
        args.model_id,
        trust_remote_code=args.trust_remote_code,
        max_seq_length=args.max_seq_length,
        load_in_4bit=not args.disable_4bit,
        fast_inference=False,
    )
    model = unsloth_base.apply_lora_adapters(FastLanguageModel, model, args)

    print(
        f"Training runtime: device={runtime['device']} "
        f"name={runtime['device_name']} "
        f"dtype={runtime['dtype']} "
        f"load_in_4bit={not args.disable_4bit}"
    )
    print(f"Remote OpenEnv Space: {args.base_url}")
    print(f"Collected remote prompt states: {len(examples)}")

    trainer = unsloth_base.build_unsloth_grpo_trainer(
        model=model,
        tokenizer=tokenizer,
        reward_func=reward_fn,
        train_dataset=train_dataset,
        args=args,
        runtime=runtime,
    )
    # Stub out vision token-id attributes when absent — presumably some trainer
    # versions access them even for text-only models; confirm against the
    # pinned TRL/Unsloth versions before removing.
    for attr in ("image_token_id", "vision_start_token_id", "vision_end_token_id"):
        if not hasattr(trainer, attr):
            setattr(trainer, attr, None)

    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    plot_paths = base.save_training_plots(
        trainer.state.log_history,
        args.output_dir,
        metric_key=args.plot_metric_key,
    )
    print("Saved training plots:")
    for plot_name, plot_path in plot_paths.items():
        print(f" - {plot_name}: {plot_path}")

    return {
        "args": args,
        "runtime": runtime,
        "examples": examples,
        "reward_fn": reward_fn,
        "train_dataset": train_dataset,
        "tokenizer": tokenizer,
        "model": model,
        "trainer": trainer,
        "plot_paths": plot_paths,
    }
|
| |
|
| |
|
def main() -> None:
    """CLI entrypoint: parse arguments, then run training."""
    args = parse_args()
    run_training(args)
|
| |
|
| |
|
if __name__ == "__main__":  # script entrypoint guard
    main()
|
| |
|