"""Train and run quantized self-driving lab models with Unsloth.
This keeps the same OpenEnv prompt + reward wiring as `training_script.py`,
but arranges the Unsloth path in the more typical pattern:
1. patch GRPO support
2. load a quantized model
3. apply LoRA adapters
4. train with an explicit OpenEnv reward function
NOTE: Unsloth must be imported before trl, transformers, and peft. Import this
module before training_script.
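
Example invocation (illustrative: the model id is a placeholder, and
--model-id is assumed to come from training_script's shared parser; the
remaining flags are defined in this module):

    python training_unsloth.py \
        --model-id Qwen/Qwen2.5-1.5B-Instruct \
        --max-seq-length 2048 --lora-r 16 --lora-alpha 16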
"""
from __future__ import annotations
import argparse
import random
from pathlib import Path
from typing import Any, Dict, Optional, Sequence
# Unsloth must be imported before trl/transformers/peft for optimizations.
import unsloth # noqa: F401
import training_script as base
DEFAULT_OUTPUT_DIR = "training/grpo-unsloth-output"
DEFAULT_MAX_SEQ_LENGTH = 2048
DEFAULT_LORA_R = 16
DEFAULT_LORA_ALPHA = 16
DEFAULT_LORA_DROPOUT = 0.0
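# Attention (q/k/v/o) and MLP (gate/up/down) projection module names used by
# Llama-family decoders; adjust this list for other architectures.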
LORA_TARGET_MODULES = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
def require_unsloth():
try:
from unsloth import FastLanguageModel, PatchFastRL
except ImportError as exc:
msg = str(exc)
        if "vllm" in msg.lower():
raise RuntimeError(
f"Unsloth failed: {exc}. "
"unsloth_zoo expects vllm.lora.models. Install a compatible vllm:\n"
" pip install 'vllm==0.8.2' # requires torch 2.6\n"
" pip install 'vllm==0.7.3' # alternative\n"
"If torch>=2.10 conflicts, use a separate env with torch 2.6–2.8."
) from exc
if "unsloth" in msg.lower():
raise RuntimeError(
"Unsloth is not installed. Run `uv sync` or `pip install unsloth`."
) from exc
raise RuntimeError(f"Failed to import Unsloth: {exc}") from exc
return FastLanguageModel, PatchFastRL
def _call_unsloth_from_pretrained(FastLanguageModel, **kwargs: Any):
    """Call ``FastLanguageModel.from_pretrained``, dropping optional kwargs that
    older Unsloth releases reject with a ``TypeError``."""
    optional_keys = ("fast_inference", "trust_remote_code")
    kwargs = dict(kwargs)
    while True:
        try:
            return FastLanguageModel.from_pretrained(**kwargs)
        except TypeError as exc:
            # Drop whichever optional kwargs the error names and retry; once no
            # optional kwarg matches, the error is genuine, so re-raise it.
            rejected = [k for k in optional_keys if k in kwargs and k in str(exc)]
            if not rejected:
                raise
            for key in rejected:
                kwargs.pop(key, None)
def build_argument_parser() -> argparse.ArgumentParser:
parser = base.build_argument_parser()
parser.description = (
"Train a GRPO policy with Unsloth quantized loading for faster H100 runs."
)
parser.set_defaults(output_dir=DEFAULT_OUTPUT_DIR)
parser.add_argument(
"--max-seq-length",
type=int,
default=DEFAULT_MAX_SEQ_LENGTH,
help="Context length passed to Unsloth model loading.",
)
parser.add_argument(
"--disable-4bit",
action="store_true",
        help="Disable 4-bit quantized loading and use the higher-precision base weights.",
)
parser.add_argument(
"--lora-r",
type=int,
default=DEFAULT_LORA_R,
help="LoRA rank used for the quantized GRPO policy.",
)
parser.add_argument(
"--lora-alpha",
type=int,
default=DEFAULT_LORA_ALPHA,
help="LoRA alpha used for the quantized GRPO policy.",
)
parser.add_argument(
"--lora-dropout",
type=float,
default=DEFAULT_LORA_DROPOUT,
help="LoRA dropout used for the quantized GRPO policy.",
)
parser.add_argument(
"--save-merged-16bit",
action="store_true",
help="Also export a merged 16-bit model after training if supported.",
)
return parser
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
return build_argument_parser().parse_args(argv)
def make_training_args(**overrides: Any) -> argparse.Namespace:
parser = build_argument_parser()
defaults = vars(parser.parse_args([]))
unknown = sorted(set(overrides) - set(defaults))
if unknown:
raise ValueError(f"Unknown training args: {', '.join(unknown)}")
defaults.update(overrides)
return argparse.Namespace(**defaults)
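# Example programmatic use (illustrative; ``dry_run`` is inherited from
# training_script's parser, ``output_dir`` is defined above):
#
#     args = make_training_args(output_dir="training/tmp", dry_run=True)
#     run_training(args)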
def load_model_artifacts(
model_id: str,
*,
trust_remote_code: bool,
max_seq_length: int = DEFAULT_MAX_SEQ_LENGTH,
load_in_4bit: bool = True,
fast_inference: bool = False,
prepare_for_inference: bool = False,
):
FastLanguageModel, _ = require_unsloth()
runtime = base.resolve_torch_runtime()
print(f"Loading Unsloth tokenizer+model for {model_id} ...")
model, tokenizer = _call_unsloth_from_pretrained(
FastLanguageModel,
model_name=model_id,
max_seq_length=max_seq_length,
dtype=runtime["dtype"],
load_in_4bit=load_in_4bit,
fast_inference=fast_inference,
trust_remote_code=trust_remote_code,
)
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
tokenizer.pad_token = tokenizer.eos_token
if prepare_for_inference:
try:
FastLanguageModel.for_inference(model)
        except AttributeError:
            # Older Unsloth builds lack for_inference; generation still works without it.
            pass
device = getattr(model, "device", None)
if device is None:
try:
device = next(model.parameters()).device
except StopIteration:
device = runtime["device"]
print(f"Loaded model on device: {device}")
return tokenizer, model
def build_openenv_reward(args: argparse.Namespace) -> base.OpenEnvReward:
"""Return the OpenEnv-compatible reward callable used by GRPO."""
return base.OpenEnvReward(
reward_backend=args.reward_backend,
base_url=args.base_url,
domain_randomise=args.domain_randomise,
)
def prepare_prompt_examples(args: argparse.Namespace) -> Dict[str, Any]:
"""Build the OpenEnv rollout states that seed GRPO prompts."""
scenario_names = base.selected_scenarios(args.scenario_name)
examples = base.build_prompt_examples(
dataset_episodes=args.dataset_episodes,
rollout_steps=args.rollout_steps,
collection_policy=args.collection_policy,
scenario_names=scenario_names,
seed=args.seed,
domain_randomise=args.domain_randomise,
)
return {
"scenario_names": scenario_names,
"examples": examples,
}
def patch_unsloth_grpo():
"""Patch TRL GRPO to use Unsloth's optimized kernels."""
FastLanguageModel, PatchFastRL = require_unsloth()
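    # PatchFastRL patches TRL's GRPO trainer in place, so it should run before
    # the trainer is constructed; run_training calls this helper first.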
PatchFastRL("GRPO", FastLanguageModel)
return FastLanguageModel
def apply_lora_adapters(FastLanguageModel, model: Any, args: argparse.Namespace) -> Any:
"""Apply LoRA adapters in the usual Unsloth configuration style."""
return FastLanguageModel.get_peft_model(
model,
r=args.lora_r,
target_modules=LORA_TARGET_MODULES,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
bias="none",
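        # Newer Unsloth builds also accept use_gradient_checkpointing="unsloth",
        # which offloads activations for longer contexts (check your installed version).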
use_gradient_checkpointing=True,
random_state=args.seed,
)
def build_grpo_config(
args: argparse.Namespace,
runtime: Dict[str, Any],
):
import inspect
base._guard_invalid_torchao_version()
base._guard_partial_vllm_install()
from trl import GRPOConfig
supported_params = set(inspect.signature(GRPOConfig.__init__).parameters)
config_kwargs = {
"output_dir": args.output_dir,
"learning_rate": args.learning_rate,
"per_device_train_batch_size": args.per_device_train_batch_size,
"gradient_accumulation_steps": args.gradient_accumulation_steps,
"num_generations": args.num_generations,
"max_completion_length": args.max_completion_length,
"num_train_epochs": args.num_train_epochs,
"logging_steps": args.logging_steps,
"save_steps": args.save_steps,
"bf16": runtime["bf16"],
"fp16": runtime["fp16"],
"report_to": "none",
"remove_unused_columns": False,
}
    # Only add max_prompt_length if this TRL version supports it; UnslothGRPOTrainer
    # can fail when forwarding it to the parent config, so we pass it only when it is
    # explicitly supported.
if "max_prompt_length" in supported_params:
config_kwargs["max_prompt_length"] = None # text-only; avoids image_token_id crash
if (
"max_length" in supported_params
and "max_prompt_length" not in supported_params
and "max_completion_length" not in supported_params
):
config_kwargs["max_length"] = getattr(args, "max_prompt_length", 1024) + args.max_completion_length
filtered_kwargs = {k: v for k, v in config_kwargs.items() if k in supported_params}
skipped = sorted(set(config_kwargs) - set(filtered_kwargs))
if skipped:
print(f"GRPOConfig compatibility: skipping unsupported fields {', '.join(skipped)}")
return GRPOConfig(**filtered_kwargs)
def build_unsloth_grpo_trainer(
*,
model: Any,
tokenizer: Any,
reward_func: Any,
train_dataset: Any,
args: argparse.Namespace,
runtime: Dict[str, Any],
):
base._guard_invalid_torchao_version()
base._guard_partial_vllm_install()
from trl import GRPOTrainer
config = build_grpo_config(args, runtime)
return GRPOTrainer(
model=model,
reward_funcs=reward_func,
args=config,
train_dataset=train_dataset,
processing_class=tokenizer,
)
def generate_action_with_model(
model: Any,
tokenizer: Any,
prompt_or_observation: str | base.ExperimentObservation,
*,
max_new_tokens: int = base.DEFAULT_COMPLETION_TOKEN_BUDGET,
temperature: float = 0.2,
top_p: float = 0.9,
do_sample: bool = True,
) -> Dict[str, Any]:
import torch
if isinstance(prompt_or_observation, base.ExperimentObservation):
prompt = base.build_training_prompt(prompt_or_observation)
else:
prompt = str(prompt_or_observation)
model_device = getattr(model, "device", None)
if model_device is None:
try:
model_device = next(model.parameters()).device
except StopIteration:
model_device = base.resolve_torch_runtime()["device"]
inputs = tokenizer(prompt, return_tensors="pt")
inputs = {key: value.to(model_device) for key, value in inputs.items()}
prompt_tokens = inputs["input_ids"].shape[1]
generation_kwargs = {
"max_new_tokens": max_new_tokens,
"do_sample": do_sample,
"temperature": temperature,
"top_p": top_p,
"pad_token_id": tokenizer.pad_token_id,
}
with torch.no_grad():
output_ids = model.generate(**inputs, **generation_kwargs)
new_tokens = output_ids[0][prompt_tokens:]
response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
action = base.parse_action_completion(response_text)
if action is not None and isinstance(prompt_or_observation, base.ExperimentObservation):
action = base.ensure_conclusion_claims(prompt_or_observation, action)
return {
"prompt": prompt,
"response_text": response_text,
"action": action,
}
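# Example single-step usage (illustrative; the model id is a placeholder):
#
#     tokenizer, model = load_model_artifacts(
#         "Qwen/Qwen2.5-1.5B-Instruct",
#         trust_remote_code=False,
#         prepare_for_inference=True,
#     )
#     result = generate_action_with_model(model, tokenizer, "Plan the next assay step.")
#     print(result["action"])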
def run_training(args: argparse.Namespace) -> Dict[str, Any]:
random.seed(args.seed)
runtime = base.resolve_torch_runtime()
if args.load_model_only:
tokenizer, model = load_model_artifacts(
args.model_id,
trust_remote_code=args.trust_remote_code,
max_seq_length=args.max_seq_length,
load_in_4bit=not args.disable_4bit,
fast_inference=False,
prepare_for_inference=True,
)
device = getattr(model, "device", "unknown")
print(f"Unsloth model ready: {args.model_id}")
print(f"Tokenizer vocab size: {len(tokenizer)}")
print(f"Model device: {device}")
print(f"Runtime device name: {runtime['device_name']}")
return {
"args": args,
"runtime": runtime,
"tokenizer": tokenizer,
"model": model,
}
prompt_data = prepare_prompt_examples(args)
scenario_names = prompt_data["scenario_names"]
examples = prompt_data["examples"]
env_reward = build_openenv_reward(args)
if args.dry_run:
base.run_dry_run_preview(examples, env_reward, args.output_dir)
return {
"args": args,
"runtime": runtime,
"scenario_names": scenario_names,
"examples": examples,
"reward_fn": env_reward,
}
from datasets import Dataset
    # 1. Patch GRPO support so TRL picks up Unsloth's optimized kernels.
    FastLanguageModel = patch_unsloth_grpo()
train_dataset = Dataset.from_list(examples)
    # 2. Load the model with Unsloth quantized loading.
tokenizer, model = load_model_artifacts(
args.model_id,
trust_remote_code=args.trust_remote_code,
max_seq_length=args.max_seq_length,
load_in_4bit=not args.disable_4bit,
fast_inference=False,
)
    # 3. Apply LoRA adapters.
model = apply_lora_adapters(FastLanguageModel, model, args)
print(
f"Unsloth training runtime: device={runtime['device']} "
f"name={runtime['device_name']} "
f"dtype={runtime['dtype']} "
f"load_in_4bit={not args.disable_4bit}"
)
print(
"OpenEnv reward: "
f"backend={args.reward_backend} scenarios={len(scenario_names)} "
f"examples={len(examples)}"
)
    # 4. Train with GRPO against the OpenEnv reward function.
trainer = build_unsloth_grpo_trainer(
model=model,
tokenizer=tokenizer,
reward_func=env_reward,
train_dataset=train_dataset,
args=args,
runtime=runtime,
)
    # Workaround: UnslothGRPOTrainer expects vision token IDs when truncating to
    # max_prompt_length; text-only models don't define them. Setting the attributes
    # to None keeps the trainer's protected-token list empty.
for attr in ("image_token_id", "vision_start_token_id", "vision_end_token_id"):
if not hasattr(trainer, attr):
setattr(trainer, attr, None)
trainer.train()
trainer.save_model(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
if args.save_merged_16bit:
merged_dir = Path(args.output_dir) / "merged_16bit"
try:
model.save_pretrained_merged(
str(merged_dir),
tokenizer,
save_method="merged_16bit",
)
print(f"Saved merged 16-bit model to {merged_dir}")
except AttributeError:
print("Merged 16-bit export is not available in this Unsloth build; skipping.")
if args.push_to_hub:
from huggingface_hub import HfApi
api = HfApi()
api.create_repo(repo_id=args.push_to_hub, repo_type="model", exist_ok=True)
print(f"Pushing model to HuggingFace Hub: {args.push_to_hub}")
api.upload_folder(
folder_path=args.output_dir,
repo_id=args.push_to_hub,
repo_type="model",
create_pr=False,
)
print(f"Model pushed to https://huggingface.co/{args.push_to_hub}")
plot_paths = base.save_training_plots(
trainer.state.log_history,
args.output_dir,
metric_key=args.plot_metric_key,
)
print("Saved training plots:")
for plot_name, plot_path in plot_paths.items():
print(f" - {plot_name}: {plot_path}")
return {
"args": args,
"runtime": runtime,
"scenario_names": scenario_names,
"examples": examples,
"reward_fn": env_reward,
"train_dataset": train_dataset,
"tokenizer": tokenizer,
"model": model,
"trainer": trainer,
"plot_paths": plot_paths,
}
def main() -> None:
run_training(parse_args())
if __name__ == "__main__":
main()