"""Train and run quantized self-driving lab models with Unsloth.

This keeps the same OpenEnv prompt + reward wiring as `training_script.py`,
but arranges the Unsloth path in the more typical pattern:

1. patch GRPO support
2. load a quantized model
3. apply LoRA adapters
4. train with an explicit OpenEnv reward function

NOTE: Unsloth must be imported before trl, transformers, and peft. Import
this module before `training_script` so that `import unsloth` runs first.
"""
|
| |
|
| | from __future__ import annotations
|
| |
|
| | import argparse
|
| | import random
|
| | from pathlib import Path
|
| | from typing import Any, Dict, Optional, Sequence
|
| |
|
| |
|
| | import unsloth
|
| |
|
| | import training_script as base
|
| |
|
# Defaults for the Unsloth-specific CLI options registered by
# `build_argument_parser` and for model loading.
DEFAULT_OUTPUT_DIR = "training/grpo-unsloth-output"
DEFAULT_MAX_SEQ_LENGTH = 2048
DEFAULT_LORA_R = 16
DEFAULT_LORA_ALPHA = 16
DEFAULT_LORA_DROPOUT = 0.0

# Module names passed as `target_modules` when the LoRA adapters are attached.
LORA_TARGET_MODULES = [
    *("q_proj", "k_proj", "v_proj", "o_proj"),
    *("gate_proj", "up_proj", "down_proj"),
]
|
| |
|
| |
|
def require_unsloth():
    """Import Unsloth on demand and return `(FastLanguageModel, PatchFastRL)`.

    Raises:
        RuntimeError: chained from the underlying ImportError. The message is
            chosen from the import error text: a vllm compatibility hint when
            it mentions vllm, an installation hint when it mentions unsloth,
            and a generic failure message otherwise.
    """
    try:
        from unsloth import FastLanguageModel, PatchFastRL
    except ImportError as exc:
        lowered = str(exc).lower()
        if "vllm" in lowered:
            detail = (
                f"Unsloth failed: {exc}. "
                "unsloth_zoo expects vllm.lora.models. Install a compatible vllm:\n"
                " pip install 'vllm==0.8.2' # requires torch 2.6\n"
                " pip install 'vllm==0.7.3' # alternative\n"
                "If torch>=2.10 conflicts, use a separate env with torch 2.6–2.8."
            )
        elif "unsloth" in lowered:
            detail = "Unsloth is not installed. Run `uv sync` or `pip install unsloth`."
        else:
            detail = f"Failed to import Unsloth: {exc}"
        raise RuntimeError(detail) from exc
    return FastLanguageModel, PatchFastRL
|
| |
|
| |
|
def _call_unsloth_from_pretrained(FastLanguageModel, **kwargs: Any):
    """Call `FastLanguageModel.from_pretrained`, dropping unsupported optional kwargs.

    When the call raises a TypeError whose message names one of the optional
    keyword arguments (`fast_inference`, `trust_remote_code`) that was actually
    passed, that argument is removed and the call is retried.

    Args:
        FastLanguageModel: object exposing a `from_pretrained(**kwargs)` callable.
        **kwargs: keyword arguments forwarded to `from_pretrained`.

    Returns:
        Whatever `from_pretrained` returns.

    Raises:
        TypeError: re-raised unchanged when the message does not mention any
            optional keyword that is still being passed.
    """
    optional_keys = ("fast_inference", "trust_remote_code")
    while True:
        try:
            return FastLanguageModel.from_pretrained(**kwargs)
        except TypeError as exc:
            message = str(exc)
            # Check every optional key on each failure, not just one per
            # attempt, so the order in which the loader rejects them is
            # irrelevant.
            rejected = [
                key for key in optional_keys if key in kwargs and key in message
            ]
            if not rejected:
                raise
            # Each retry removes at least one key, so the loop terminates after
            # at most len(optional_keys) retries.
            for key in rejected:
                del kwargs[key]
|
| |
|
| |
|
def build_argument_parser() -> argparse.ArgumentParser:
    """Extend the base training parser with the Unsloth-specific options.

    Starts from `base.build_argument_parser()`, replaces its description,
    changes the `output_dir` default, and registers the sequence-length,
    4-bit, LoRA and merged-export flags.
    """
    parser = base.build_argument_parser()
    parser.description = (
        "Train a GRPO policy with Unsloth quantized loading for faster H100 runs."
    )
    parser.set_defaults(output_dir=DEFAULT_OUTPUT_DIR)

    parser.add_argument(
        "--max-seq-length",
        type=int,
        default=DEFAULT_MAX_SEQ_LENGTH,
        help="Context length passed to Unsloth model loading.",
    )
    parser.add_argument(
        "--disable-4bit",
        action="store_true",
        help="Disable 4-bit quantized loading and use the wider base weights.",
    )

    # The three LoRA hyper-parameters differ only in flag, type, default and
    # one word of help text.
    lora_options = (
        ("--lora-r", int, DEFAULT_LORA_R, "rank"),
        ("--lora-alpha", int, DEFAULT_LORA_ALPHA, "alpha"),
        ("--lora-dropout", float, DEFAULT_LORA_DROPOUT, "dropout"),
    )
    for flag, value_type, default_value, label in lora_options:
        parser.add_argument(
            flag,
            type=value_type,
            default=default_value,
            help=f"LoRA {label} used for the quantized GRPO policy.",
        )

    parser.add_argument(
        "--save-merged-16bit",
        action="store_true",
        help="Also export a merged 16-bit model after training if supported.",
    )
    return parser
|
| |
|
| |
|
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse `argv` with the Unsloth training parser.

    `argv` is handed straight to `ArgumentParser.parse_args`; `None` lets
    argparse fall back to the process command line.
    """
    parser = build_argument_parser()
    return parser.parse_args(argv)
|
| |
|
| |
|
def make_training_args(**overrides: Any) -> argparse.Namespace:
    """Build a training namespace from parser defaults plus keyword overrides.

    Raises:
        ValueError: if an override name is not a destination known to the parser.
    """
    values = vars(build_argument_parser().parse_args([]))
    unexpected = sorted(name for name in overrides if name not in values)
    if unexpected:
        raise ValueError(f"Unknown training args: {', '.join(unexpected)}")
    return argparse.Namespace(**{**values, **overrides})
|
| |
|
| |
|
def load_model_artifacts(
    model_id: str,
    *,
    trust_remote_code: bool,
    max_seq_length: int = DEFAULT_MAX_SEQ_LENGTH,
    load_in_4bit: bool = True,
    fast_inference: bool = False,
    prepare_for_inference: bool = False,
):
    """Load a model and tokenizer through Unsloth.

    Args:
        model_id: model name forwarded as `model_name`.
        trust_remote_code: forwarded to the loader (dropped if unsupported).
        max_seq_length: context length forwarded to the loader.
        load_in_4bit: whether to request 4-bit quantized weights.
        fast_inference: forwarded to the loader (dropped if unsupported).
        prepare_for_inference: when true, call `FastLanguageModel.for_inference`
            on the loaded model; an AttributeError from that call is ignored.

    Returns:
        `(tokenizer, model)` -- note the order is the reverse of what the
        loader itself returns.
    """
    FastLanguageModel, _ = require_unsloth()
    runtime = base.resolve_torch_runtime()

    print(f"Loading Unsloth tokenizer+model for {model_id} ...")
    loader_kwargs = {
        "model_name": model_id,
        "max_seq_length": max_seq_length,
        "dtype": runtime["dtype"],
        "load_in_4bit": load_in_4bit,
        "fast_inference": fast_inference,
        "trust_remote_code": trust_remote_code,
    }
    model, tokenizer = _call_unsloth_from_pretrained(FastLanguageModel, **loader_kwargs)

    # Fall back to EOS as the padding token when the tokenizer defines none.
    if tokenizer.pad_token is None:
        if tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token

    if prepare_for_inference:
        try:
            FastLanguageModel.for_inference(model)
        except AttributeError:
            # Best effort: the inference switch is treated as optional.
            pass

    # Resolve a device for the log line: the model attribute first, then the
    # first parameter, then the runtime default.
    device = getattr(model, "device", None)
    if device is None:
        first_parameter = next(model.parameters(), None)
        if first_parameter is None:
            device = runtime["device"]
        else:
            device = first_parameter.device
    print(f"Loaded model on device: {device}")
    return tokenizer, model
|
| |
|
| |
|
def build_openenv_reward(args: argparse.Namespace) -> base.OpenEnvReward:
    """Construct the `base.OpenEnvReward` callable from the parsed arguments."""
    reward_options = {
        "reward_backend": args.reward_backend,
        "base_url": args.base_url,
        "domain_randomise": args.domain_randomise,
    }
    return base.OpenEnvReward(**reward_options)
|
| |
|
| |
|
def prepare_prompt_examples(args: argparse.Namespace) -> Dict[str, Any]:
    """Resolve the scenarios and build the prompt examples for GRPO.

    Returns:
        A dict with `scenario_names` (from `base.selected_scenarios`) and
        `examples` (from `base.build_prompt_examples`).
    """
    selected = base.selected_scenarios(args.scenario_name)
    prompt_examples = base.build_prompt_examples(
        dataset_episodes=args.dataset_episodes,
        rollout_steps=args.rollout_steps,
        collection_policy=args.collection_policy,
        scenario_names=selected,
        seed=args.seed,
        domain_randomise=args.domain_randomise,
    )
    return dict(scenario_names=selected, examples=prompt_examples)
|
| |
|
| |
|
def patch_unsloth_grpo():
    """Apply Unsloth's `PatchFastRL("GRPO", ...)` and return `FastLanguageModel`."""
    fast_language_model, patch_fast_rl = require_unsloth()
    patch_fast_rl("GRPO", fast_language_model)
    return fast_language_model
|
| |
|
| |
|
def apply_lora_adapters(FastLanguageModel, model: Any, args: argparse.Namespace) -> Any:
    """Wrap `model` with LoRA adapters via `FastLanguageModel.get_peft_model`.

    Rank, alpha, dropout and the random seed come from `args`; the target
    modules come from `LORA_TARGET_MODULES`.
    """
    lora_settings = {
        "r": args.lora_r,
        "target_modules": LORA_TARGET_MODULES,
        "lora_alpha": args.lora_alpha,
        "lora_dropout": args.lora_dropout,
        "bias": "none",
        "use_gradient_checkpointing": True,
        "random_state": args.seed,
    }
    return FastLanguageModel.get_peft_model(model, **lora_settings)
|
| |
|
| |
|
def build_grpo_config(
    args: argparse.Namespace,
    runtime: Dict[str, Any],
):
    """Build a `trl.GRPOConfig`, passing only fields this TRL version accepts.

    The accepted field names are read from the `GRPOConfig.__init__`
    signature; requested fields that are not accepted are dropped and
    reported on stdout.
    """
    import inspect

    base._guard_invalid_torchao_version()
    base._guard_partial_vllm_install()
    from trl import GRPOConfig

    accepted = frozenset(inspect.signature(GRPOConfig.__init__).parameters)
    requested = {
        "output_dir": args.output_dir,
        "learning_rate": args.learning_rate,
        "per_device_train_batch_size": args.per_device_train_batch_size,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "num_generations": args.num_generations,
        "max_completion_length": args.max_completion_length,
        "num_train_epochs": args.num_train_epochs,
        "logging_steps": args.logging_steps,
        "save_steps": args.save_steps,
        "bf16": runtime["bf16"],
        "fp16": runtime["fp16"],
        "report_to": "none",
        "remove_unused_columns": False,
    }

    if "max_prompt_length" in accepted:
        requested["max_prompt_length"] = None
    elif "max_length" in accepted and "max_completion_length" not in accepted:
        # Only a combined length is available: prompt budget plus completion budget.
        requested["max_length"] = (
            getattr(args, "max_prompt_length", 1024) + args.max_completion_length
        )

    usable = {}
    dropped = []
    for field_name, field_value in requested.items():
        if field_name in accepted:
            usable[field_name] = field_value
        else:
            dropped.append(field_name)
    if dropped:
        print(
            f"GRPOConfig compatibility: skipping unsupported fields {', '.join(sorted(dropped))}"
        )
    return GRPOConfig(**usable)
|
| |
|
| |
|
def build_unsloth_grpo_trainer(
    *,
    model: Any,
    tokenizer: Any,
    reward_func: Any,
    train_dataset: Any,
    args: argparse.Namespace,
    runtime: Dict[str, Any],
):
    """Create a `trl.GRPOTrainer` for the given model, dataset and reward.

    The install guards from `base` run before TRL is imported; the trainer
    config comes from `build_grpo_config(args, runtime)` and the tokenizer is
    passed as `processing_class`.
    """
    base._guard_invalid_torchao_version()
    base._guard_partial_vllm_install()
    from trl import GRPOTrainer

    trainer_kwargs = {
        "model": model,
        "reward_funcs": reward_func,
        "args": build_grpo_config(args, runtime),
        "train_dataset": train_dataset,
        "processing_class": tokenizer,
    }
    return GRPOTrainer(**trainer_kwargs)
|
| |
|
| |
|
def generate_action_with_model(
    model: Any,
    tokenizer: Any,
    prompt_or_observation: str | base.ExperimentObservation,
    *,
    max_new_tokens: int = base.DEFAULT_COMPLETION_TOKEN_BUDGET,
    temperature: float = 0.2,
    top_p: float = 0.9,
    do_sample: bool = True,
) -> Dict[str, Any]:
    """Generate one completion and parse it into an action.

    Args:
        model: model exposing `generate` (and optionally `device` / `parameters`).
        tokenizer: tokenizer used to encode the prompt and decode the completion.
        prompt_or_observation: either a ready prompt string or a
            `base.ExperimentObservation`, which is turned into a prompt with
            `base.build_training_prompt`.
        max_new_tokens: generation budget forwarded to `model.generate`.
        temperature: sampling temperature; only forwarded when `do_sample` is true.
        top_p: nucleus-sampling threshold; only forwarded when `do_sample` is true.
        do_sample: whether to sample instead of decoding greedily.

    Returns:
        A dict with `prompt`, `response_text` (decoded new tokens, stripped) and
        `action` (the result of `base.parse_action_completion`, which may be
        `None`; for observation inputs a non-None action is additionally passed
        through `base.ensure_conclusion_claims`).
    """
    import torch

    from_observation = isinstance(prompt_or_observation, base.ExperimentObservation)
    if from_observation:
        prompt = base.build_training_prompt(prompt_or_observation)
    else:
        prompt = str(prompt_or_observation)

    # Resolve the device to place inputs on: the model attribute first, then
    # the first parameter, then the runtime default.
    model_device = getattr(model, "device", None)
    if model_device is None:
        try:
            model_device = next(model.parameters()).device
        except StopIteration:
            model_device = base.resolve_torch_runtime()["device"]

    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {key: value.to(model_device) for key, value in inputs.items()}
    prompt_tokens = inputs["input_ids"].shape[1]

    generation_kwargs: Dict[str, Any] = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.pad_token_id,
    }
    if do_sample:
        # temperature/top_p are sampling-only settings. Forwarding them with
        # do_sample=False does not change greedy output but is flagged as an
        # inconsistent generation config by Transformers.
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = top_p

    with torch.no_grad():
        output_ids = model.generate(**inputs, **generation_kwargs)

    # Keep only the tokens produced after the prompt.
    new_tokens = output_ids[0][prompt_tokens:]
    response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    action = base.parse_action_completion(response_text)
    if action is not None and from_observation:
        action = base.ensure_conclusion_claims(prompt_or_observation, action)
    return {
        "prompt": prompt,
        "response_text": response_text,
        "action": action,
    }
|
| |
|
| |
|
def run_training(args: argparse.Namespace) -> Dict[str, Any]:
    """Run the Unsloth GRPO workflow selected by `args` and return its artifacts.

    Three modes, checked in this order:

    * `args.load_model_only`: load the model for inference, print a summary,
      and return `args`, `runtime`, `tokenizer`, `model`.
    * `args.dry_run`: build prompts and reward, run the base dry-run preview,
      and return `args`, `runtime`, `scenario_names`, `examples`, `reward_fn`.
    * otherwise: patch GRPO, load the model, attach LoRA adapters, train,
      save (optionally merged 16-bit and Hub upload), write plots, and return
      everything above plus `train_dataset`, `trainer` and `plot_paths`.
    """
    random.seed(args.seed)
    runtime = base.resolve_torch_runtime()
    summary: Dict[str, Any] = {"args": args, "runtime": runtime}

    # --- Mode 1: only load the model -------------------------------------
    if args.load_model_only:
        tokenizer, model = load_model_artifacts(
            args.model_id,
            trust_remote_code=args.trust_remote_code,
            max_seq_length=args.max_seq_length,
            load_in_4bit=not args.disable_4bit,
            fast_inference=False,
            prepare_for_inference=True,
        )
        print(f"Unsloth model ready: {args.model_id}")
        print(f"Tokenizer vocab size: {len(tokenizer)}")
        print(f"Model device: {getattr(model, 'device', 'unknown')}")
        print(f"Runtime device name: {runtime['device_name']}")
        return {**summary, "tokenizer": tokenizer, "model": model}

    prompt_data = prepare_prompt_examples(args)
    scenario_names = prompt_data["scenario_names"]
    examples = prompt_data["examples"]
    env_reward = build_openenv_reward(args)
    summary.update(
        scenario_names=scenario_names,
        examples=examples,
        reward_fn=env_reward,
    )

    # --- Mode 2: dry run --------------------------------------------------
    if args.dry_run:
        base.run_dry_run_preview(examples, env_reward, args.output_dir)
        return summary

    # --- Mode 3: full training -------------------------------------------
    from datasets import Dataset

    # GRPO is patched before the model is loaded and the trainer is built.
    FastLanguageModel = patch_unsloth_grpo()
    train_dataset = Dataset.from_list(examples)

    tokenizer, model = load_model_artifacts(
        args.model_id,
        trust_remote_code=args.trust_remote_code,
        max_seq_length=args.max_seq_length,
        load_in_4bit=not args.disable_4bit,
        fast_inference=False,
    )
    model = apply_lora_adapters(FastLanguageModel, model, args)

    runtime_fields = (
        "Unsloth training runtime:",
        f"device={runtime['device']}",
        f"name={runtime['device_name']}",
        f"dtype={runtime['dtype']}",
        f"load_in_4bit={not args.disable_4bit}",
    )
    print(" ".join(runtime_fields))
    print(
        f"OpenEnv reward: backend={args.reward_backend} "
        f"scenarios={len(scenario_names)} examples={len(examples)}"
    )

    trainer = build_unsloth_grpo_trainer(
        model=model,
        tokenizer=tokenizer,
        reward_func=env_reward,
        train_dataset=train_dataset,
        args=args,
        runtime=runtime,
    )

    # Make sure the vision-token attributes exist on the trainer, defaulting
    # any that are absent to None.
    for token_attr in ("image_token_id", "vision_start_token_id", "vision_end_token_id"):
        if not hasattr(trainer, token_attr):
            setattr(trainer, token_attr, None)

    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    if args.save_merged_16bit:
        merged_dir = Path(args.output_dir) / "merged_16bit"
        try:
            model.save_pretrained_merged(
                str(merged_dir),
                tokenizer,
                save_method="merged_16bit",
            )
            print(f"Saved merged 16-bit model to {merged_dir}")
        except AttributeError:
            print("Merged 16-bit export is not available in this Unsloth build; skipping.")

    if args.push_to_hub:
        from huggingface_hub import HfApi

        hub_api = HfApi()
        hub_api.create_repo(repo_id=args.push_to_hub, repo_type="model", exist_ok=True)
        print(f"Pushing model to HuggingFace Hub: {args.push_to_hub}")
        hub_api.upload_folder(
            folder_path=args.output_dir,
            repo_id=args.push_to_hub,
            repo_type="model",
            create_pr=False,
        )
        print(f"Model pushed to https://huggingface.co/{args.push_to_hub}")

    plot_paths = base.save_training_plots(
        trainer.state.log_history,
        args.output_dir,
        metric_key=args.plot_metric_key,
    )
    print("Saved training plots:")
    for plot_name, plot_path in plot_paths.items():
        print(f" - {plot_name}: {plot_path}")

    return {
        **summary,
        "train_dataset": train_dataset,
        "tokenizer": tokenizer,
        "model": model,
        "trainer": trainer,
        "plot_paths": plot_paths,
    }
|
| |
|
| |
|
def main() -> None:
    """Command-line entry point: parse the process arguments and run training."""
    cli_args = parse_args()
    run_training(cli_args)


if __name__ == "__main__":
    main()
|
| |
|