"""Minimal Colab entrypoint for Unsloth GRPO against a remote OpenEnv Space.
This keeps the repo's prompt formatting and action parsing logic, but builds
prompt states by interacting with a deployed OpenEnv Hugging Face Space instead
of the local in-process environment. That makes the Colab workflow match the
remote environment users actually want to train against.
"""
from __future__ import annotations
import argparse
import json
import random
from typing import Any, Dict, List, Optional, Sequence
from client import BioExperimentEnv
import training_script as base
DEFAULT_MODEL_ID = "unsloth/Llama-3.2-3B-Instruct-bnb-4bit"
DEFAULT_OUTPUT_DIR = "artifacts/grpo-unsloth-llama32-3b-space"
DEFAULT_SPACE_REPO_ID = "Ev3Dev/hackathon"
def hf_space_repo_to_base_url(repo_id: str) -> str:
"""Convert `owner/space-name` to the standard `hf.space` URL."""
owner, space_name = repo_id.split("/", 1)
normalized_owner = owner.strip().lower().replace("_", "-")
normalized_space = space_name.strip().lower().replace("_", "-")
return f"https://{normalized_owner}-{normalized_space}.hf.space"
def require_unsloth_base():
# Unsloth must be imported before trl / transformers / peft.
import unsloth # noqa: F401
import training_unsloth as unsloth_base
return unsloth_base
def build_argument_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="Train Unsloth Llama 3.2 3B on a remote OpenEnv Hugging Face Space."
)
parser.add_argument("--model-id", default=DEFAULT_MODEL_ID)
parser.add_argument("--output-dir", default=DEFAULT_OUTPUT_DIR)
parser.add_argument("--dataset-episodes", type=int, default=8)
parser.add_argument("--rollout-steps", type=int, default=6)
parser.add_argument(
"--collection-policy",
choices=["random", "heuristic"],
default="heuristic",
)
parser.add_argument("--base-url", default="")
parser.add_argument(
"--space-repo-id",
default=DEFAULT_SPACE_REPO_ID,
help="Hugging Face Space repo id, for example `Ev3Dev/hackathon`.",
)
parser.add_argument("--num-generations", type=int, default=2)
parser.add_argument("--max-completion-length", type=int, default=160)
parser.add_argument("--max-prompt-length", type=int, default=1280)
parser.add_argument("--max-seq-length", type=int, default=2048)
parser.add_argument("--per-device-train-batch-size", type=int, default=1)
parser.add_argument("--gradient-accumulation-steps", type=int, default=4)
parser.add_argument("--learning-rate", type=float, default=5e-6)
parser.add_argument("--num-train-epochs", type=float, default=1.0)
parser.add_argument("--logging-steps", type=int, default=1)
parser.add_argument("--save-steps", type=int, default=25)
parser.add_argument("--plot-metric-key", default=None)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--dry-run", action="store_true")
parser.add_argument("--load-model-only", action="store_true")
parser.add_argument("--trust-remote-code", action="store_true")
parser.add_argument("--disable-4bit", action="store_true")
parser.add_argument("--lora-r", type=int, default=unsloth_defaults()["lora_r"])
parser.add_argument(
"--lora-alpha", type=int, default=unsloth_defaults()["lora_alpha"]
)
parser.add_argument(
"--lora-dropout", type=float, default=unsloth_defaults()["lora_dropout"]
)
return parser
def unsloth_defaults() -> Dict[str, float]:
return {
"lora_r": 16,
"lora_alpha": 16,
"lora_dropout": 0.0,
}
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
args = build_argument_parser().parse_args(argv)
if not args.base_url:
args.base_url = hf_space_repo_to_base_url(args.space_repo_id)
return args
def make_training_args(**overrides: Any) -> argparse.Namespace:
parser = build_argument_parser()
defaults = vars(parser.parse_args([]))
unknown = sorted(set(overrides) - set(defaults))
if unknown:
raise ValueError(f"Unknown training args: {', '.join(unknown)}")
defaults.update(overrides)
args = argparse.Namespace(**defaults)
if not getattr(args, "base_url", ""):
args.base_url = hf_space_repo_to_base_url(args.space_repo_id)
return args
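# Notebook-style usage (illustrative): build an args namespace without the CLI
# and hand it to run_training() below, e.g.
#   args = make_training_args(dry_run=True, dataset_episodes=2)
#   run_training(args)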
def build_remote_prompt_examples(args: argparse.Namespace) -> List[Dict[str, str]]:
"""Collect prompt states directly from the remote OpenEnv server."""
rng = random.Random(args.seed)
examples: List[Dict[str, str]] = []
for _episode_idx in range(args.dataset_episodes):
with BioExperimentEnv(base_url=args.base_url) as env:
result = env.reset()
obs = result.observation
history_actions: List[base.ExperimentAction] = []
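            # Roll the episode forward for a few steps, emitting one
            # (prompt state, reference action) example per step.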
for step_idx in range(args.rollout_steps):
if obs.done:
break
next_action = base.build_experiment_action(
action_type=base.pick_action(
args.collection_policy,
step_idx,
[action.action_type for action in history_actions],
),
discovered_markers=obs.discovered_markers,
candidate_mechanisms=obs.candidate_mechanisms,
conditions=obs.task.conditions,
)
examples.append(
{
"prompt": base.build_training_prompt(obs),
"history_actions": json.dumps(
[action.model_dump() for action in history_actions]
),
"reference_action": base.action_completion_json(next_action),
"problem_statement": obs.task.problem_statement,
"episode_tag": f"remote-{rng.randrange(10**9):09d}",
}
)
history_actions.append(next_action)
result = env.step(next_action)
obs = result.observation
if result.done:
break
return examples
class RemoteSpaceReward:
"""Reward function that replays each candidate against the remote Space."""
def __init__(
self,
*,
base_url: str,
invalid_action_penalty: float = base.INVALID_ACTION_PENALTY,
environment_error_penalty: float = base.ENVIRONMENT_ERROR_PENALTY,
) -> None:
self.__name__ = "remote_space_reward"
self.base_url = base_url
self.invalid_action_penalty = invalid_action_penalty
self.environment_error_penalty = environment_error_penalty
def __call__(
self,
completions: List[Any],
history_actions: Optional[List[str]] = None,
**_: Any,
) -> List[float]:
history_columns = base.normalise_column(history_actions, len(completions))
rewards: List[float] = []
for completion, current_history in zip(completions, history_columns):
action = base.parse_action_completion(base.completion_to_text(completion))
if action is None:
rewards.append(self.invalid_action_penalty)
continue
try:
rewards.append(self._score_remote(action, current_history))
except Exception:
rewards.append(self.environment_error_penalty)
return rewards
def _score_remote(
self,
action: base.ExperimentAction,
history_actions: Optional[str],
) -> float:
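        # Replay from a fresh remote episode: reset, re-apply the prompt's
        # history actions, then step the candidate action and read its reward.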
with BioExperimentEnv(base_url=self.base_url) as env:
result = env.reset()
obs = result.observation
for previous_action in base.decode_history_actions(history_actions):
result = env.step(previous_action)
obs = result.observation
if result.done:
return float(result.reward or obs.reward or 0.0)
action = base.ensure_conclusion_claims(obs, action)
result = env.step(action)
if result.reward is not None:
return float(result.reward)
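            # If the observation carries no reward either, float(None) raises
            # and __call__ maps the failure to the environment-error penalty.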
return float(result.observation.reward)
def run_dry_run_preview(
examples: Sequence[Dict[str, str]],
reward_fn: RemoteSpaceReward,
output_dir: str,
base_url: str,
) -> None:
if not examples:
raise ValueError("No training prompts were generated for the dry run.")
sample = examples[0]
sample_reward = reward_fn(
completions=[[{"role": "assistant", "content": sample["reference_action"]}]],
history_actions=[sample["history_actions"]],
)[0]
print(f"Built {len(examples)} remote prompt states.")
print(f"Remote OpenEnv Space: {base_url}")
print(f"Output directory: {output_dir}")
print(f"Sample reward for reference action: {sample_reward:+.3f}")
print("\nSample prompt:\n")
print(sample["prompt"])
def run_training(args: argparse.Namespace) -> Dict[str, Any]:
random.seed(args.seed)
runtime = base.resolve_torch_runtime()
unsloth_base = require_unsloth_base()
if args.load_model_only:
tokenizer, model = unsloth_base.load_model_artifacts(
args.model_id,
trust_remote_code=args.trust_remote_code,
max_seq_length=args.max_seq_length,
load_in_4bit=not args.disable_4bit,
fast_inference=False,
prepare_for_inference=True,
)
return {
"args": args,
"runtime": runtime,
"tokenizer": tokenizer,
"model": model,
}
examples = build_remote_prompt_examples(args)
reward_fn = RemoteSpaceReward(base_url=args.base_url)
if args.dry_run:
run_dry_run_preview(examples, reward_fn, args.output_dir, args.base_url)
return {
"args": args,
"runtime": runtime,
"examples": examples,
"reward_fn": reward_fn,
}
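    # `datasets` is imported lazily: only a full training run needs it.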
from datasets import Dataset
FastLanguageModel = unsloth_base.patch_unsloth_grpo()
train_dataset = Dataset.from_list(examples)
tokenizer, model = unsloth_base.load_model_artifacts(
args.model_id,
trust_remote_code=args.trust_remote_code,
max_seq_length=args.max_seq_length,
load_in_4bit=not args.disable_4bit,
fast_inference=False,
)
model = unsloth_base.apply_lora_adapters(FastLanguageModel, model, args)
print(
f"Training runtime: device={runtime['device']} "
f"name={runtime['device_name']} "
f"dtype={runtime['dtype']} "
f"load_in_4bit={not args.disable_4bit}"
)
print(f"Remote OpenEnv Space: {args.base_url}")
print(f"Collected remote prompt states: {len(examples)}")
trainer = unsloth_base.build_unsloth_grpo_trainer(
model=model,
tokenizer=tokenizer,
reward_func=reward_fn,
train_dataset=train_dataset,
args=args,
runtime=runtime,
)
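    # Defensive shim: some Unsloth/TRL GRPO code paths probe vision token ids
    # even for text-only models, so stub out any that are missing.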
for attr in ("image_token_id", "vision_start_token_id", "vision_end_token_id"):
if not hasattr(trainer, attr):
setattr(trainer, attr, None)
trainer.train()
trainer.save_model(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
plot_paths = base.save_training_plots(
trainer.state.log_history,
args.output_dir,
metric_key=args.plot_metric_key,
)
print("Saved training plots:")
for plot_name, plot_path in plot_paths.items():
print(f" - {plot_name}: {plot_path}")
return {
"args": args,
"runtime": runtime,
"examples": examples,
"reward_fn": reward_fn,
"train_dataset": train_dataset,
"tokenizer": tokenizer,
"model": model,
"trainer": trainer,
"plot_paths": plot_paths,
}
def main() -> None:
run_training(parse_args())
if __name__ == "__main__":
main()