from __future__ import annotations

import argparse
import os

# Set environment toggles before importing unsloth so the flags are seen at
# import time.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("UNSLOTH_RETURN_LOGITS", "1")
os.environ.setdefault("UNSLOTH_DISABLE_AUTO_UPDATES", "1")

from unsloth import FastLanguageModel
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

from freeciv_env.adapter import prepare_observation
from freeciv_env.grpo import (
    SYSTEM_PROMPT,
    build_turn_prompt,
    oracle_action_index,
    reward_from_oracle,
)
from freeciv_env.runtime import LiveFreecivSession


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--env-url", default="http://127.0.0.1")
    parser.add_argument("--model-id", default="Qwen/Qwen3.5-0.8B")
    parser.add_argument("--dataset-size", type=int, default=512)
    parser.add_argument("--max-steps", type=int, default=50)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--num-generations", type=int, default=4)
    parser.add_argument("--episode-horizon", type=int, default=4)
    parser.add_argument("--max-prompt-length", type=int, default=768)
    parser.add_argument("--max-completion-length", type=int, default=8)
    parser.add_argument("--learning-rate", type=float, default=5e-6)
    parser.add_argument("--lora-rank", type=int, default=16)
    parser.add_argument("--output-dir", default="outputs/qwen35_08b_grpo")
    parser.add_argument("--save-steps", type=int, default=50)
    return parser.parse_args()


def collect_dataset(env_url: str, dataset_size: int, episode_horizon: int) -> Dataset:
    """Roll short live episodes and record a (prompt, oracle action index) row per turn."""
    rows = {"prompt": [], "best_index": []}
    while len(rows["prompt"]) < dataset_size:
        session = LiveFreecivSession(base_url=env_url, turn_timeout_s=120)
        try:
            snapshot = session.reset()
            for turn_index in range(episode_horizon):
                observation = prepare_observation(
                    snapshot,
                    reward=0.0,
                    done=False,
                    status="running",
                ).observation
                best_index = oracle_action_index(observation.legal_actions)
                rows["prompt"].append(build_turn_prompt(observation))
                rows["best_index"].append(best_index)
                # Skip the (expensive) end-of-turn step when the dataset is
                # full or this was the final turn of the episode horizon.
                if len(rows["prompt"]) >= dataset_size or turn_index + 1 >= episode_horizon:
                    break
                snapshot = session.end_turn()
        finally:
            session.close()
    return Dataset.from_dict(rows)


def load_model(model_id: str, max_seq_length: int, lora_rank: int):
    """Load the base model in 16-bit and attach LoRA adapters for GRPO training."""
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_id,
        max_seq_length=max_seq_length,
        load_in_4bit=False,
        load_in_16bit=True,
        full_finetuning=False,
        fast_inference=False,
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=lora_rank,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_alpha=lora_rank * 2,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing=False,
        random_state=3407,
        max_seq_length=max_seq_length,
    )
    return model, tokenizer


def apply_chat_template(dataset: Dataset, tokenizer) -> Dataset:
    """Wrap each raw turn prompt in the system/user chat template, with thinking disabled."""

    def format_row(row):
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": row["prompt"]},
        ]
        return {
            "prompt": tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=False,
            )
        }

    return dataset.map(format_row)


def main() -> None:
    args = parse_args()
    max_seq_length = args.max_prompt_length + args.max_completion_length
    dataset = collect_dataset(args.env_url, args.dataset_size, args.episode_horizon)
    model, tokenizer = load_model(args.model_id, max_seq_length, args.lora_rank)
    dataset = apply_chat_template(dataset, tokenizer)
    training_args = GRPOConfig(
        learning_rate=args.learning_rate,
        weight_decay=0.01,
        warmup_ratio=0.05,
        lr_scheduler_type="cosine",
        optim="adamw_torch_fused",
        logging_steps=1,
        log_completions=False,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=1,
        num_generations=args.num_generations,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        max_steps=args.max_steps,
        save_steps=args.save_steps,
        max_grad_norm=0.3,
        bf16=True,
        report_to="none",
        beta=0.0,  # no KL penalty, so no reference model is kept in memory
        loss_type="dr_grpo",
        temperature=0.7,
        top_p=0.8,
        top_k=20,
        output_dir=args.output_dir,
    )
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=reward_from_oracle,
        train_dataset=dataset,
        args=training_args,
    )
    trainer.train()
    # Save only the LoRA adapter weights (plus tokenizer) rather than the full model.
    model.save_pretrained(f"{args.output_dir}/lora")
    tokenizer.save_pretrained(f"{args.output_dir}/lora")


if __name__ == "__main__":
    main()
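
# ---------------------------------------------------------------------------
# Example invocation (a sketch, not part of the script: the filename below is
# hypothetical, and --env-url must point at an already-running Freeciv
# environment server):
#
#   python train_grpo_freeciv.py \
#       --env-url http://127.0.0.1 \
#       --dataset-size 512 \
#       --max-steps 50 \
#       --batch-size 16 \
#       --num-generations 4 \
#       --output-dir outputs/qwen35_08b_grpo
#
# Note: TRL's GRPOTrainer requires the effective batch size to be divisible
# by num_generations (here 16 / 4 = 4 unique prompts per optimizer step), so
# adjust --batch-size and --num-generations together.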