Minimal GRPO training script

Key file: `scripts/train_grpo_fast.py`

from __future__ import annotations

import argparse
import os

# Environment defaults, set before `unsloth`, `datasets` and `trl` are imported
# below (presumably so those libraries see them at import time — keep this
# ordering). `setdefault` leaves any value already exported by the caller as-is.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("UNSLOTH_RETURN_LOGITS", "1")
os.environ.setdefault("UNSLOTH_DISABLE_AUTO_UPDATES", "1")

from unsloth import FastLanguageModel
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

from freeciv_env.adapter import prepare_observation
from freeciv_env.grpo import SYSTEM_PROMPT, build_turn_prompt, oracle_action_index, reward_from_oracle
from freeciv_env.runtime import LiveFreecivSession


def parse_args():
    """Parse the command-line options for a GRPO training run.

    Returns:
        An ``argparse.Namespace`` with one attribute per option registered
        below; dashes in flag names become underscores (``--env-url`` ->
        ``env_url``).
    """
    cli = argparse.ArgumentParser()
    # (flag, converter, default). A ``None`` converter means the value is
    # kept as the string argparse reads from the command line.
    option_table = (
        ("--env-url", None, "http://127.0.0.1"),
        ("--model-id", None, "Qwen/Qwen3.5-0.8B"),
        ("--dataset-size", int, 512),
        ("--max-steps", int, 50),
        ("--batch-size", int, 16),
        ("--num-generations", int, 4),
        ("--episode-horizon", int, 4),
        ("--max-prompt-length", int, 768),
        ("--max-completion-length", int, 8),
        ("--learning-rate", float, 5e-6),
        ("--lora-rank", int, 16),
        ("--output-dir", None, "outputs/qwen35_08b_grpo"),
        ("--save-steps", int, 50),
    )
    for flag, converter, default_value in option_table:
        if converter is None:
            cli.add_argument(flag, default=default_value)
        else:
            cli.add_argument(flag, type=converter, default=default_value)
    return cli.parse_args()



def collect_dataset(
    env_url: str,
    dataset_size: int,
    episode_horizon: int,
    *,
    turn_timeout_s: float = 120,
) -> Dataset:
    """Roll out live Freeciv episodes and record one prompt row per turn.

    Each episode starts with ``session.reset()`` and advances with
    ``session.end_turn()`` for at most ``episode_horizon`` recorded turns.
    New episodes (fresh sessions) are started until ``dataset_size`` rows
    exist. Every session is closed, even if a step raises.

    Args:
        env_url: Base URL passed to ``LiveFreecivSession``.
        dataset_size: Number of rows to collect; ``<= 0`` yields an empty dataset.
        episode_horizon: Maximum number of turns recorded per episode.
        turn_timeout_s: Per-turn timeout forwarded to ``LiveFreecivSession``.

    Returns:
        A ``Dataset`` with columns ``prompt`` (from ``build_turn_prompt``) and
        ``best_index`` (from ``oracle_action_index`` over the turn's legal actions).

    Raises:
        ValueError: if rows are still needed but ``episode_horizon < 1``; no
            episode could contribute a row, so the loop would never finish.
    """
    if dataset_size > 0 and episode_horizon < 1:
        raise ValueError(
            f"episode_horizon must be >= 1 to collect rows, got {episode_horizon}"
        )
    rows = {"prompt": [], "best_index": []}
    while len(rows["prompt"]) < dataset_size:
        session = LiveFreecivSession(base_url=env_url, turn_timeout_s=turn_timeout_s)
        try:
            snapshot = session.reset()
            for turn_index in range(episode_horizon):
                observation = prepare_observation(
                    snapshot,
                    reward=0.0,
                    done=False,
                    status="running",
                ).observation
                best_index = oracle_action_index(observation.legal_actions)
                rows["prompt"].append(build_turn_prompt(observation))
                rows["best_index"].append(best_index)
                # Stop before end_turn() once the dataset is full or this was the
                # episode's last recorded turn, so no extra turn is played.
                if len(rows["prompt"]) >= dataset_size or turn_index + 1 >= episode_horizon:
                    break
                snapshot = session.end_turn()
        finally:
            session.close()
    return Dataset.from_dict(rows)



def load_model(model_id: str, max_seq_length: int, lora_rank: int):
    """Load ``model_id`` through Unsloth and attach LoRA adapters to it.

    The base model is loaded with ``load_in_16bit=True`` (4-bit loading,
    full fine-tuning and fast inference all off). LoRA rank is ``lora_rank``
    with ``lora_alpha = 2 * lora_rank``, no dropout, ``bias="none"``,
    gradient checkpointing disabled and a fixed ``random_state`` of 3407.

    Returns:
        ``(model, tokenizer)``: the adapter-wrapped model and its tokenizer.
    """
    # Module names targeted by LoRA, in the order passed to get_peft_model.
    projection_names = ["q_proj", "k_proj", "v_proj", "o_proj"]
    mlp_names = ["gate_proj", "up_proj", "down_proj"]

    base_model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_id,
        max_seq_length=max_seq_length,
        load_in_4bit=False,
        load_in_16bit=True,
        full_finetuning=False,
        fast_inference=False,
    )
    adapted_model = FastLanguageModel.get_peft_model(
        base_model,
        r=lora_rank,
        target_modules=projection_names + mlp_names,
        lora_alpha=2 * lora_rank,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing=False,
        random_state=3407,
        max_seq_length=max_seq_length,
    )
    return adapted_model, tokenizer



def apply_chat_template(dataset: Dataset, tokenizer) -> Dataset:
    """Rewrite each row's ``prompt`` into the tokenizer's chat-formatted text.

    The raw prompt becomes the user turn after a ``SYSTEM_PROMPT`` system turn.
    The template is rendered to a string (``tokenize=False``) with a generation
    prompt appended, and ``enable_thinking=False`` is forwarded to the template.
    ``Dataset.map`` merges the returned ``prompt`` back into each row, so other
    columns (e.g. ``best_index``) are kept.
    """
    render_options = dict(
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )

    def _to_chat_text(example):
        conversation = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["prompt"]},
        ]
        chat_text = tokenizer.apply_chat_template(conversation, **render_options)
        return {"prompt": chat_text}

    return dataset.map(_to_chat_text)



def main() -> None:
    args = parse_args()
    max_seq_length = args.max_prompt_length + args.max_completion_length
    dataset = collect_dataset(args.env_url, args.dataset_size, args.episode_horizon)
    model, tokenizer = load_model(args.model_id, max_seq_length, args.lora_rank)
    dataset = apply_chat_template(dataset, tokenizer)

    training_args = GRPOConfig(
        learning_rate=args.learning_rate,
        weight_decay=0.01,
        warmup_ratio=0.05,
        lr_scheduler_type="cosine",
        optim="adamw_torch_fused",
        logging_steps=1,
        log_completions=False,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=1,
        num_generations=args.num_generations,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        max_steps=args.max_steps,
        save_steps=args.save_steps,
        max_grad_norm=0.3,
        bf16=True,
        report_to="none",
        beta=0.0,
        loss_type="dr_grpo",