"""Train and run quantized self-driving lab models with Unsloth.



This keeps the same OpenEnv prompt + reward wiring as `training_script.py`,

but arranges the Unsloth path in the more typical pattern:

1. patch GRPO support

2. load a quantized model

3. apply LoRA adapters

4. train with an explicit OpenEnv reward function



NOTE: Unsloth must be imported before trl, transformers, peft. Import this

module before training_script.

"""

from __future__ import annotations

import argparse
import random
from pathlib import Path
from typing import Any, Dict, Optional, Sequence

# Unsloth must be imported before trl/transformers/peft for optimizations.
import unsloth  # noqa: F401

import training_script as base

DEFAULT_OUTPUT_DIR = "training/grpo-unsloth-output"
DEFAULT_MAX_SEQ_LENGTH = 2048
DEFAULT_LORA_R = 16
DEFAULT_LORA_ALPHA = 16
DEFAULT_LORA_DROPOUT = 0.0
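# Attention and MLP projection layers of Llama/Qwen-style decoder blocks; this
# is the target set commonly used in Unsloth LoRA examples.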
LORA_TARGET_MODULES = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]


def require_unsloth():
    try:
        from unsloth import FastLanguageModel, PatchFastRL
    except ImportError as exc:
        msg = str(exc)
        if "vllm.lora" in msg or "vllm" in msg.lower():
            raise RuntimeError(
                f"Unsloth failed: {exc}. "
                "unsloth_zoo expects vllm.lora.models. Install a compatible vllm:\n"
                "  pip install 'vllm==0.8.2'   # requires torch 2.6\n"
                "  pip install 'vllm==0.7.3'    # alternative\n"
                "If torch>=2.10 conflicts, use a separate env with torch 2.6–2.8."
            ) from exc
        if "unsloth" in msg.lower():
            raise RuntimeError(
                "Unsloth is not installed. Run `uv sync` or `pip install unsloth`."
            ) from exc
        raise RuntimeError(f"Failed to import Unsloth: {exc}") from exc
    return FastLanguageModel, PatchFastRL


def _call_unsloth_from_pretrained(FastLanguageModel, **kwargs: Any):
    """Call from_pretrained, retrying without optional kwargs older Unsloth builds reject."""
    optional_keys = ("fast_inference", "trust_remote_code")
    while True:
        try:
            return FastLanguageModel.from_pretrained(**kwargs)
        except TypeError as exc:
            # Strip every optional kwarg the error names, regardless of order,
            # then retry; once nothing can be stripped, re-raise.
            rejected = [key for key in optional_keys if key in kwargs and key in str(exc)]
            if not rejected:
                raise
            kwargs = {key: value for key, value in kwargs.items() if key not in rejected}


def build_argument_parser() -> argparse.ArgumentParser:
    parser = base.build_argument_parser()
    parser.description = (
        "Train a GRPO policy with Unsloth quantized loading for faster H100 runs."
    )
    parser.set_defaults(output_dir=DEFAULT_OUTPUT_DIR)
    parser.add_argument(
        "--max-seq-length",
        type=int,
        default=DEFAULT_MAX_SEQ_LENGTH,
        help="Context length passed to Unsloth model loading.",
    )
    parser.add_argument(
        "--disable-4bit",
        action="store_true",
        help="Disable 4-bit quantized loading and use the wider base weights.",
    )
    parser.add_argument(
        "--lora-r",
        type=int,
        default=DEFAULT_LORA_R,
        help="LoRA rank used for the quantized GRPO policy.",
    )
    parser.add_argument(
        "--lora-alpha",
        type=int,
        default=DEFAULT_LORA_ALPHA,
        help="LoRA alpha used for the quantized GRPO policy.",
    )
    parser.add_argument(
        "--lora-dropout",
        type=float,
        default=DEFAULT_LORA_DROPOUT,
        help="LoRA dropout used for the quantized GRPO policy.",
    )
    parser.add_argument(
        "--save-merged-16bit",
        action="store_true",
        help="Also export a merged 16-bit model after training if supported.",
    )
    return parser


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    return build_argument_parser().parse_args(argv)


def make_training_args(**overrides: Any) -> argparse.Namespace:
    parser = build_argument_parser()
    defaults = vars(parser.parse_args([]))
    unknown = sorted(set(overrides) - set(defaults))
    if unknown:
        raise ValueError(f"Unknown training args: {', '.join(unknown)}")
    defaults.update(overrides)
    return argparse.Namespace(**defaults)
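

# Example programmatic use (illustrative; override keys are argparse dests,
# e.g. lora_r for --lora-r):
#
#   args = make_training_args(lora_r=32, dry_run=True)
#   run_training(args)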


def load_model_artifacts(
    model_id: str,
    *,
    trust_remote_code: bool,
    max_seq_length: int = DEFAULT_MAX_SEQ_LENGTH,
    load_in_4bit: bool = True,
    fast_inference: bool = False,
    prepare_for_inference: bool = False,
):
    FastLanguageModel, _ = require_unsloth()
    runtime = base.resolve_torch_runtime()

    print(f"Loading Unsloth tokenizer+model for {model_id} ...")
    model, tokenizer = _call_unsloth_from_pretrained(
        FastLanguageModel,
        model_name=model_id,
        max_seq_length=max_seq_length,
        dtype=runtime["dtype"],
        load_in_4bit=load_in_4bit,
        fast_inference=fast_inference,
        trust_remote_code=trust_remote_code,
    )
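    # Many causal-LM tokenizers ship without a pad token; fall back to EOS so
    # padded batching and generation work.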
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    if prepare_for_inference:
        try:
            FastLanguageModel.for_inference(model)
        except AttributeError:
            pass

    device = getattr(model, "device", None)
    if device is None:
        try:
            device = next(model.parameters()).device
        except StopIteration:
            device = runtime["device"]
    print(f"Loaded model on device: {device}")
    return tokenizer, model


def build_openenv_reward(args: argparse.Namespace) -> base.OpenEnvReward:
    """Return the OpenEnv-compatible reward callable used by GRPO."""
    return base.OpenEnvReward(
        reward_backend=args.reward_backend,
        base_url=args.base_url,
        domain_randomise=args.domain_randomise,
    )


def prepare_prompt_examples(args: argparse.Namespace) -> Dict[str, Any]:
    """Build the OpenEnv rollout states that seed GRPO prompts."""
    scenario_names = base.selected_scenarios(args.scenario_name)
    examples = base.build_prompt_examples(
        dataset_episodes=args.dataset_episodes,
        rollout_steps=args.rollout_steps,
        collection_policy=args.collection_policy,
        scenario_names=scenario_names,
        seed=args.seed,
        domain_randomise=args.domain_randomise,
    )
    return {
        "scenario_names": scenario_names,
        "examples": examples,
    }


def patch_unsloth_grpo():
    """Patch TRL GRPO to use Unsloth's optimized kernels."""
    FastLanguageModel, PatchFastRL = require_unsloth()
    PatchFastRL("GRPO", FastLanguageModel)
    return FastLanguageModel


def apply_lora_adapters(FastLanguageModel, model: Any, args: argparse.Namespace) -> Any:
    """Apply LoRA adapters in the usual Unsloth configuration style."""
    return FastLanguageModel.get_peft_model(
        model,
        r=args.lora_r,
        target_modules=LORA_TARGET_MODULES,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        use_gradient_checkpointing=True,
        random_state=args.seed,
    )


def build_grpo_config(
    args: argparse.Namespace,
    runtime: Dict[str, Any],
):
    import inspect

    base._guard_invalid_torchao_version()
    base._guard_partial_vllm_install()
    from trl import GRPOConfig

    supported_params = set(inspect.signature(GRPOConfig.__init__).parameters)
    config_kwargs = {
        "output_dir": args.output_dir,
        "learning_rate": args.learning_rate,
        "per_device_train_batch_size": args.per_device_train_batch_size,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "num_generations": args.num_generations,
        "max_completion_length": args.max_completion_length,
        "num_train_epochs": args.num_train_epochs,
        "logging_steps": args.logging_steps,
        "save_steps": args.save_steps,
        "bf16": runtime["bf16"],
        "fp16": runtime["fp16"],
        "report_to": "none",
        "remove_unused_columns": False,
    }
    # Only pass max_prompt_length when this TRL version explicitly supports it;
    # UnslothGRPOTrainer can fail when forwarding it to its parent class.
    if "max_prompt_length" in supported_params:
        config_kwargs["max_prompt_length"] = None  # text-only; avoids image_token_id crash
    if (
        "max_length" in supported_params
        and "max_prompt_length" not in supported_params
        and "max_completion_length" not in supported_params
    ):
        config_kwargs["max_length"] = getattr(args, "max_prompt_length", 1024) + args.max_completion_length
    filtered_kwargs = {k: v for k, v in config_kwargs.items() if k in supported_params}
    skipped = sorted(set(config_kwargs) - set(filtered_kwargs))
    if skipped:
        print(f"GRPOConfig compatibility: skipping unsupported fields {', '.join(skipped)}")
    return GRPOConfig(**filtered_kwargs)


def build_unsloth_grpo_trainer(
    *,
    model: Any,
    tokenizer: Any,
    reward_func: Any,
    train_dataset: Any,
    args: argparse.Namespace,
    runtime: Dict[str, Any],
):
    base._guard_invalid_torchao_version()
    base._guard_partial_vllm_install()
    from trl import GRPOTrainer

    config = build_grpo_config(args, runtime)
    return GRPOTrainer(
        model=model,
        reward_funcs=reward_func,
        args=config,
        train_dataset=train_dataset,
        processing_class=tokenizer,
    )


def generate_action_with_model(
    model: Any,
    tokenizer: Any,
    prompt_or_observation: str | base.ExperimentObservation,
    *,
    max_new_tokens: int = base.DEFAULT_COMPLETION_TOKEN_BUDGET,
    temperature: float = 0.2,
    top_p: float = 0.9,
    do_sample: bool = True,
) -> Dict[str, Any]:
    import torch

    if isinstance(prompt_or_observation, base.ExperimentObservation):
        prompt = base.build_training_prompt(prompt_or_observation)
    else:
        prompt = str(prompt_or_observation)

    model_device = getattr(model, "device", None)
    if model_device is None:
        try:
            model_device = next(model.parameters()).device
        except StopIteration:
            model_device = base.resolve_torch_runtime()["device"]

    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {key: value.to(model_device) for key, value in inputs.items()}
    prompt_tokens = inputs["input_ids"].shape[1]

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "temperature": temperature,
        "top_p": top_p,
        "pad_token_id": tokenizer.pad_token_id,
    }
    with torch.no_grad():
        output_ids = model.generate(**inputs, **generation_kwargs)

    new_tokens = output_ids[0][prompt_tokens:]
    response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    action = base.parse_action_completion(response_text)
    if action is not None and isinstance(prompt_or_observation, base.ExperimentObservation):
        action = base.ensure_conclusion_claims(prompt_or_observation, action)
    return {
        "prompt": prompt,
        "response_text": response_text,
        "action": action,
    }
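

# Example single-step inference (illustrative; `observation` is assumed to be
# an ExperimentObservation produced by the base module's rollout helpers):
#
#   tokenizer, model = load_model_artifacts(
#       "unsloth/Qwen2.5-3B-Instruct-bnb-4bit",
#       trust_remote_code=False,
#       prepare_for_inference=True,
#   )
#   result = generate_action_with_model(model, tokenizer, observation)
#   print(result["response_text"], result["action"])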


def run_training(args: argparse.Namespace) -> Dict[str, Any]:
    random.seed(args.seed)
    runtime = base.resolve_torch_runtime()

    if args.load_model_only:
        tokenizer, model = load_model_artifacts(
            args.model_id,
            trust_remote_code=args.trust_remote_code,
            max_seq_length=args.max_seq_length,
            load_in_4bit=not args.disable_4bit,
            fast_inference=False,
            prepare_for_inference=True,
        )
        device = getattr(model, "device", "unknown")
        print(f"Unsloth model ready: {args.model_id}")
        print(f"Tokenizer vocab size: {len(tokenizer)}")
        print(f"Model device: {device}")
        print(f"Runtime device name: {runtime['device_name']}")
        return {
            "args": args,
            "runtime": runtime,
            "tokenizer": tokenizer,
            "model": model,
        }

    prompt_data = prepare_prompt_examples(args)
    scenario_names = prompt_data["scenario_names"]
    examples = prompt_data["examples"]
    env_reward = build_openenv_reward(args)

    if args.dry_run:
        base.run_dry_run_preview(examples, env_reward, args.output_dir)
        return {
            "args": args,
            "runtime": runtime,
            "scenario_names": scenario_names,
            "examples": examples,
            "reward_fn": env_reward,
        }

    from datasets import Dataset

    # 1. Patch GRPO support.
    FastLanguageModel = patch_unsloth_grpo()
    train_dataset = Dataset.from_list(examples)

    # 2. Load the model with Unsloth quantized loading.
    tokenizer, model = load_model_artifacts(
        args.model_id,
        trust_remote_code=args.trust_remote_code,
        max_seq_length=args.max_seq_length,
        load_in_4bit=not args.disable_4bit,
        fast_inference=False,
    )
    # 3. Apply LoRA adapters.
    model = apply_lora_adapters(FastLanguageModel, model, args)

    print(
        f"Unsloth training runtime: device={runtime['device']} "
        f"name={runtime['device_name']} "
        f"dtype={runtime['dtype']} "
        f"load_in_4bit={not args.disable_4bit}"
    )
    print(
        "OpenEnv reward: "
        f"backend={args.reward_backend} scenarios={len(scenario_names)} "
        f"examples={len(examples)}"
    )

    # 4. Train with GRPO against the OpenEnv reward function.
    trainer = build_unsloth_grpo_trainer(
        model=model,
        tokenizer=tokenizer,
        reward_func=env_reward,
        train_dataset=train_dataset,
        args=args,
        runtime=runtime,
    )
    # Workaround: UnslothGRPOTrainer expects vision token IDs for max_prompt_length
    # truncation; text-only models don't have them. Set to None so protected=[].
    for attr in ("image_token_id", "vision_start_token_id", "vision_end_token_id"):
        if not hasattr(trainer, attr):
            setattr(trainer, attr, None)
    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    if args.save_merged_16bit:
        merged_dir = Path(args.output_dir) / "merged_16bit"
        try:
            model.save_pretrained_merged(
                str(merged_dir),
                tokenizer,
                save_method="merged_16bit",
            )
            print(f"Saved merged 16-bit model to {merged_dir}")
        except AttributeError:
            print("Merged 16-bit export is not available in this Unsloth build; skipping.")

    if args.push_to_hub:
        from huggingface_hub import HfApi

        api = HfApi()
        api.create_repo(repo_id=args.push_to_hub, repo_type="model", exist_ok=True)
        print(f"Pushing model to HuggingFace Hub: {args.push_to_hub}")
        api.upload_folder(
            folder_path=args.output_dir,
            repo_id=args.push_to_hub,
            repo_type="model",
            create_pr=False,
        )
        print(f"Model pushed to https://huggingface.co/{args.push_to_hub}")

    plot_paths = base.save_training_plots(
        trainer.state.log_history,
        args.output_dir,
        metric_key=args.plot_metric_key,
    )
    print("Saved training plots:")
    for plot_name, plot_path in plot_paths.items():
        print(f"  - {plot_name}: {plot_path}")

    return {
        "args": args,
        "runtime": runtime,
        "scenario_names": scenario_names,
        "examples": examples,
        "reward_fn": env_reward,
        "train_dataset": train_dataset,
        "tokenizer": tokenizer,
        "model": model,
        "trainer": trainer,
        "plot_paths": plot_paths,
    }


def main() -> None:
    run_training(parse_args())


if __name__ == "__main__":
    main()