| """Train and run quantized self-driving lab models with Unsloth. | |
| This keeps the same OpenEnv prompt + reward wiring as `training_script.py`, | |
| but arranges the Unsloth path in the more typical pattern: | |
| 1. patch GRPO support | |
| 2. load a quantized model | |
| 3. apply LoRA adapters | |
| 4. train with an explicit OpenEnv reward function | |
| NOTE: Unsloth must be imported before trl, transformers, peft. Import this | |
| module before training_script. | |
| """ | |

from __future__ import annotations

import argparse
import random
from pathlib import Path
from typing import Any, Dict, Optional, Sequence

# Unsloth must be imported before trl/transformers/peft for its optimizations.
import unsloth  # noqa: F401

import training_script as base

DEFAULT_OUTPUT_DIR = "training/grpo-unsloth-output"
DEFAULT_MAX_SEQ_LENGTH = 2048
DEFAULT_LORA_R = 16
DEFAULT_LORA_ALPHA = 16
DEFAULT_LORA_DROPOUT = 0.0
LORA_TARGET_MODULES = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]


def require_unsloth():
    """Import Unsloth lazily, translating common failure modes into install hints."""
    try:
        from unsloth import FastLanguageModel, PatchFastRL
    except ImportError as exc:
        msg = str(exc)
        if "vllm" in msg.lower():
            raise RuntimeError(
                f"Unsloth failed: {exc}. "
                "unsloth_zoo expects vllm.lora.models. Install a compatible vllm:\n"
                "  pip install 'vllm==0.8.2'  # requires torch 2.6\n"
                "  pip install 'vllm==0.7.3'  # alternative\n"
                "If torch>=2.10 conflicts, use a separate env with torch 2.6–2.8."
            ) from exc
        if "unsloth" in msg.lower():
            raise RuntimeError(
                "Unsloth is not installed. Run `uv sync` or `pip install unsloth`."
            ) from exc
        raise RuntimeError(f"Failed to import Unsloth: {exc}") from exc
    return FastLanguageModel, PatchFastRL


def _call_unsloth_from_pretrained(FastLanguageModel, **kwargs: Any):
    """Call ``FastLanguageModel.from_pretrained``, retrying without optional
    kwargs (``fast_inference``, ``trust_remote_code``) that older Unsloth
    releases reject."""
    optional_keys = ("fast_inference", "trust_remote_code")
    while True:
        try:
            return FastLanguageModel.from_pretrained(**kwargs)
        except TypeError as exc:
            # Drop whichever optional kwargs the TypeError names, regardless
            # of order, then retry. If none of them are implicated, re-raise.
            removable = [k for k in optional_keys if k in kwargs and k in str(exc)]
            if not removable:
                raise
            kwargs = dict(kwargs)
            for key in removable:
                kwargs.pop(key, None)


def build_argument_parser() -> argparse.ArgumentParser:
    parser = base.build_argument_parser()
    parser.description = (
        "Train a GRPO policy with Unsloth quantized loading for faster H100 runs."
    )
    parser.set_defaults(output_dir=DEFAULT_OUTPUT_DIR)
    parser.add_argument(
        "--max-seq-length",
        type=int,
        default=DEFAULT_MAX_SEQ_LENGTH,
        help="Context length passed to Unsloth model loading.",
    )
    parser.add_argument(
        "--disable-4bit",
        action="store_true",
        help="Disable 4-bit quantized loading and use the higher-precision base weights.",
    )
    parser.add_argument(
        "--lora-r",
        type=int,
        default=DEFAULT_LORA_R,
        help="LoRA rank used for the quantized GRPO policy.",
    )
    parser.add_argument(
        "--lora-alpha",
        type=int,
        default=DEFAULT_LORA_ALPHA,
        help="LoRA alpha used for the quantized GRPO policy.",
    )
    parser.add_argument(
        "--lora-dropout",
        type=float,
        default=DEFAULT_LORA_DROPOUT,
        help="LoRA dropout used for the quantized GRPO policy.",
    )
    parser.add_argument(
        "--save-merged-16bit",
        action="store_true",
        help="Also export a merged 16-bit model after training if supported.",
    )
    return parser


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    return build_argument_parser().parse_args(argv)


def make_training_args(**overrides: Any) -> argparse.Namespace:
    """Build a training Namespace from parser defaults plus keyword overrides."""
    parser = build_argument_parser()
    defaults = vars(parser.parse_args([]))
    unknown = sorted(set(overrides) - set(defaults))
    if unknown:
        raise ValueError(f"Unknown training args: {', '.join(unknown)}")
    defaults.update(overrides)
    return argparse.Namespace(**defaults)
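

# Example (the model id is hypothetical): build args programmatically for a
# notebook or test run without touching sys.argv:
#
#   args = make_training_args(model_id="Qwen/Qwen2.5-1.5B-Instruct", dry_run=True)
#   run_training(args)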


def load_model_artifacts(
    model_id: str,
    *,
    trust_remote_code: bool,
    max_seq_length: int = DEFAULT_MAX_SEQ_LENGTH,
    load_in_4bit: bool = True,
    fast_inference: bool = False,
    prepare_for_inference: bool = False,
):
    FastLanguageModel, _ = require_unsloth()
    runtime = base.resolve_torch_runtime()
    print(f"Loading Unsloth tokenizer+model for {model_id} ...")
    model, tokenizer = _call_unsloth_from_pretrained(
        FastLanguageModel,
        model_name=model_id,
        max_seq_length=max_seq_length,
        dtype=runtime["dtype"],
        load_in_4bit=load_in_4bit,
        fast_inference=fast_inference,
        trust_remote_code=trust_remote_code,
    )
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token
    if prepare_for_inference:
        try:
            FastLanguageModel.for_inference(model)
        except AttributeError:
            pass
    device = getattr(model, "device", None)
    if device is None:
        try:
            device = next(model.parameters()).device
        except StopIteration:
            device = runtime["device"]
    print(f"Loaded model on device: {device}")
    return tokenizer, model


def build_openenv_reward(args: argparse.Namespace) -> base.OpenEnvReward:
    """Return the OpenEnv-compatible reward callable used by GRPO."""
    return base.OpenEnvReward(
        reward_backend=args.reward_backend,
        base_url=args.base_url,
        domain_randomise=args.domain_randomise,
    )


def prepare_prompt_examples(args: argparse.Namespace) -> Dict[str, Any]:
    """Build the OpenEnv rollout states that seed GRPO prompts."""
    scenario_names = base.selected_scenarios(args.scenario_name)
    examples = base.build_prompt_examples(
        dataset_episodes=args.dataset_episodes,
        rollout_steps=args.rollout_steps,
        collection_policy=args.collection_policy,
        scenario_names=scenario_names,
        seed=args.seed,
        domain_randomise=args.domain_randomise,
    )
    return {
        "scenario_names": scenario_names,
        "examples": examples,
    }


def patch_unsloth_grpo():
    """Patch TRL's GRPO support to use Unsloth's optimized kernels."""
    FastLanguageModel, PatchFastRL = require_unsloth()
    PatchFastRL("GRPO", FastLanguageModel)
    return FastLanguageModel
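
# PatchFastRL rewrites TRL's GRPO classes in place, so it is meant to run
# before GRPOConfig/GRPOTrainer are imported. run_training() below preserves
# that ordering by calling patch_unsloth_grpo() first and importing trl lazily
# inside build_grpo_config() and build_unsloth_grpo_trainer().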


def apply_lora_adapters(FastLanguageModel, model: Any, args: argparse.Namespace) -> Any:
    """Apply LoRA adapters in the usual Unsloth configuration style."""
    return FastLanguageModel.get_peft_model(
        model,
        r=args.lora_r,
        target_modules=LORA_TARGET_MODULES,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        use_gradient_checkpointing=True,
        random_state=args.seed,
    )
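
# With the defaults above (r=16, alpha=16) the effective LoRA scaling factor
# alpha / r is 1.0; raising --lora-alpha relative to --lora-r strengthens the
# adapters' contribution to each patched projection without adding parameters.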


def build_grpo_config(
    args: argparse.Namespace,
    runtime: Dict[str, Any],
):
    import inspect

    base._guard_invalid_torchao_version()
    base._guard_partial_vllm_install()
    from trl import GRPOConfig

    supported_params = set(inspect.signature(GRPOConfig.__init__).parameters)
    config_kwargs = {
        "output_dir": args.output_dir,
        "learning_rate": args.learning_rate,
        "per_device_train_batch_size": args.per_device_train_batch_size,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "num_generations": args.num_generations,
        "max_completion_length": args.max_completion_length,
        "num_train_epochs": args.num_train_epochs,
        "logging_steps": args.logging_steps,
        "save_steps": args.save_steps,
        "bf16": runtime["bf16"],
        "fp16": runtime["fp16"],
        "report_to": "none",
        "remove_unused_columns": False,
    }
    # Pass max_prompt_length only when this TRL version explicitly supports it:
    # UnslothGRPOTrainer can fail when the field is forwarded to a parent class
    # that does not know it.
    if "max_prompt_length" in supported_params:
        config_kwargs["max_prompt_length"] = None  # text-only; avoids image_token_id crash
    if (
        "max_length" in supported_params
        and "max_prompt_length" not in supported_params
        and "max_completion_length" not in supported_params
    ):
        config_kwargs["max_length"] = (
            getattr(args, "max_prompt_length", 1024) + args.max_completion_length
        )
    filtered_kwargs = {k: v for k, v in config_kwargs.items() if k in supported_params}
    skipped = sorted(set(config_kwargs) - set(filtered_kwargs))
    if skipped:
        print(f"GRPOConfig compatibility: skipping unsupported fields {', '.join(skipped)}")
    return GRPOConfig(**filtered_kwargs)


def build_unsloth_grpo_trainer(
    *,
    model: Any,
    tokenizer: Any,
    reward_func: Any,
    train_dataset: Any,
    args: argparse.Namespace,
    runtime: Dict[str, Any],
):
    base._guard_invalid_torchao_version()
    base._guard_partial_vllm_install()
    from trl import GRPOTrainer

    config = build_grpo_config(args, runtime)
    return GRPOTrainer(
        model=model,
        reward_funcs=reward_func,
        args=config,
        train_dataset=train_dataset,
        processing_class=tokenizer,
    )


def generate_action_with_model(
    model: Any,
    tokenizer: Any,
    prompt_or_observation: str | base.ExperimentObservation,
    *,
    max_new_tokens: int = base.DEFAULT_COMPLETION_TOKEN_BUDGET,
    temperature: float = 0.2,
    top_p: float = 0.9,
    do_sample: bool = True,
) -> Dict[str, Any]:
    import torch

    if isinstance(prompt_or_observation, base.ExperimentObservation):
        prompt = base.build_training_prompt(prompt_or_observation)
    else:
        prompt = str(prompt_or_observation)
    model_device = getattr(model, "device", None)
    if model_device is None:
        try:
            model_device = next(model.parameters()).device
        except StopIteration:
            model_device = base.resolve_torch_runtime()["device"]
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {key: value.to(model_device) for key, value in inputs.items()}
    prompt_tokens = inputs["input_ids"].shape[1]
    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "temperature": temperature,
        "top_p": top_p,
        "pad_token_id": tokenizer.pad_token_id,
    }
    with torch.no_grad():
        output_ids = model.generate(**inputs, **generation_kwargs)
    new_tokens = output_ids[0][prompt_tokens:]
    response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    action = base.parse_action_completion(response_text)
    if action is not None and isinstance(prompt_or_observation, base.ExperimentObservation):
        action = base.ensure_conclusion_claims(prompt_or_observation, action)
    return {
        "prompt": prompt,
        "response_text": response_text,
        "action": action,
    }
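
# Example single-step inference (a sketch; "some/model-id" is a placeholder
# and the observation is assumed to come from the base module's environment):
#
#   tokenizer, model = load_model_artifacts(
#       "some/model-id", trust_remote_code=False, prepare_for_inference=True
#   )
#   result = generate_action_with_model(model, tokenizer, observation)
#   if result["action"] is None:
#       print("Completion did not parse:", result["response_text"])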


def run_training(args: argparse.Namespace) -> Dict[str, Any]:
    random.seed(args.seed)
    runtime = base.resolve_torch_runtime()
    if args.load_model_only:
        tokenizer, model = load_model_artifacts(
            args.model_id,
            trust_remote_code=args.trust_remote_code,
            max_seq_length=args.max_seq_length,
            load_in_4bit=not args.disable_4bit,
            fast_inference=False,
            prepare_for_inference=True,
        )
        device = getattr(model, "device", "unknown")
        print(f"Unsloth model ready: {args.model_id}")
        print(f"Tokenizer vocab size: {len(tokenizer)}")
        print(f"Model device: {device}")
        print(f"Runtime device name: {runtime['device_name']}")
        return {
            "args": args,
            "runtime": runtime,
            "tokenizer": tokenizer,
            "model": model,
        }
    prompt_data = prepare_prompt_examples(args)
    scenario_names = prompt_data["scenario_names"]
    examples = prompt_data["examples"]
    env_reward = build_openenv_reward(args)
    if args.dry_run:
        base.run_dry_run_preview(examples, env_reward, args.output_dir)
        return {
            "args": args,
            "runtime": runtime,
            "scenario_names": scenario_names,
            "examples": examples,
            "reward_fn": env_reward,
        }
    from datasets import Dataset

    # Steps follow the module docstring.
    # 1. Patch GRPO support.
    FastLanguageModel = patch_unsloth_grpo()
    train_dataset = Dataset.from_list(examples)
    # 2. Load the model with Unsloth quantized loading.
    tokenizer, model = load_model_artifacts(
        args.model_id,
        trust_remote_code=args.trust_remote_code,
        max_seq_length=args.max_seq_length,
        load_in_4bit=not args.disable_4bit,
        fast_inference=False,
    )
    # 3. Apply LoRA adapters.
    model = apply_lora_adapters(FastLanguageModel, model, args)
    print(
        f"Unsloth training runtime: device={runtime['device']} "
        f"name={runtime['device_name']} "
        f"dtype={runtime['dtype']} "
        f"load_in_4bit={not args.disable_4bit}"
    )
    print(
        "OpenEnv reward: "
        f"backend={args.reward_backend} scenarios={len(scenario_names)} "
        f"examples={len(examples)}"
    )
    # 4. Train with GRPO against the OpenEnv reward function.
    trainer = build_unsloth_grpo_trainer(
        model=model,
        tokenizer=tokenizer,
        reward_func=env_reward,
        train_dataset=train_dataset,
        args=args,
        runtime=runtime,
    )
    # Workaround: UnslothGRPOTrainer expects vision token IDs for
    # max_prompt_length truncation, which text-only models don't have. Set
    # them to None so the protected-token list stays empty.
    for attr in ("image_token_id", "vision_start_token_id", "vision_end_token_id"):
        if not hasattr(trainer, attr):
            setattr(trainer, attr, None)
    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    if args.save_merged_16bit:
        merged_dir = Path(args.output_dir) / "merged_16bit"
        try:
            model.save_pretrained_merged(
                str(merged_dir),
                tokenizer,
                save_method="merged_16bit",
            )
            print(f"Saved merged 16-bit model to {merged_dir}")
        except AttributeError:
            print("Merged 16-bit export is not available in this Unsloth build; skipping.")
    if args.push_to_hub:
        from huggingface_hub import HfApi

        api = HfApi()
        api.create_repo(repo_id=args.push_to_hub, repo_type="model", exist_ok=True)
        print(f"Pushing model to HuggingFace Hub: {args.push_to_hub}")
        api.upload_folder(
            folder_path=args.output_dir,
            repo_id=args.push_to_hub,
            repo_type="model",
            create_pr=False,
        )
        print(f"Model pushed to https://huggingface.co/{args.push_to_hub}")
    plot_paths = base.save_training_plots(
        trainer.state.log_history,
        args.output_dir,
        metric_key=args.plot_metric_key,
    )
    print("Saved training plots:")
    for plot_name, plot_path in plot_paths.items():
        print(f"  - {plot_name}: {plot_path}")
    return {
        "args": args,
        "runtime": runtime,
        "scenario_names": scenario_names,
        "examples": examples,
        "reward_fn": env_reward,
        "train_dataset": train_dataset,
        "tokenizer": tokenizer,
        "model": model,
        "trainer": trainer,
        "plot_paths": plot_paths,
    }


def main() -> None:
    run_training(parse_args())


if __name__ == "__main__":
    main()