# openenv_hack/scripts/train_grpo_fast.py
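"""Fast GRPO training loop for a live Freeciv OpenEnv environment.

Collects oracle-labelled turn prompts from a running Freeciv session,
then trains a LoRA adapter with Unsloth and TRL's GRPOTrainer. Typical
invocation (the port on --env-url is an assumption; use whatever the
OpenEnv server listens on):

    python scripts/train_grpo_fast.py --env-url http://127.0.0.1:8000
"""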
from __future__ import annotations
import argparse
import os
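# Set these before importing Unsloth: silence tokenizer fork warnings,
# have Unsloth return real logits (GRPO scores completions with them),
# and skip the startup auto-update check.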
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("UNSLOTH_RETURN_LOGITS", "1")
os.environ.setdefault("UNSLOTH_DISABLE_AUTO_UPDATES", "1")
from unsloth import FastLanguageModel
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
from freeciv_env.adapter import prepare_observation
from freeciv_env.grpo import SYSTEM_PROMPT, build_turn_prompt, oracle_action_index, reward_from_oracle
from freeciv_env.runtime import LiveFreecivSession

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--env-url", default="http://127.0.0.1")
    parser.add_argument("--model-id", default="Qwen/Qwen3.5-0.8B")
    parser.add_argument("--dataset-size", type=int, default=512)
    parser.add_argument("--max-steps", type=int, default=50)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--num-generations", type=int, default=4)
    parser.add_argument("--episode-horizon", type=int, default=4)
    parser.add_argument("--max-prompt-length", type=int, default=768)
    parser.add_argument("--max-completion-length", type=int, default=8)
    parser.add_argument("--learning-rate", type=float, default=5e-6)
    parser.add_argument("--lora-rank", type=int, default=16)
    parser.add_argument("--output-dir", default="outputs/qwen35_08b_grpo")
    parser.add_argument("--save-steps", type=int, default=50)
    return parser.parse_args()

def collect_dataset(env_url: str, dataset_size: int, episode_horizon: int) -> Dataset:
    """Roll live Freeciv episodes and label each turn with the oracle's action index."""
    rows = {"prompt": [], "best_index": []}
    while len(rows["prompt"]) < dataset_size:
        session = LiveFreecivSession(base_url=env_url, turn_timeout_s=120)
        try:
            snapshot = session.reset()
            for turn_index in range(episode_horizon):
                observation = prepare_observation(
                    snapshot,
                    reward=0.0,
                    done=False,
                    status="running",
                ).observation
                best_index = oracle_action_index(observation.legal_actions)
                rows["prompt"].append(build_turn_prompt(observation))
                rows["best_index"].append(best_index)
                # Stop before end_turn() once the dataset or the episode is full.
                if len(rows["prompt"]) >= dataset_size or turn_index + 1 >= episode_horizon:
                    break
                snapshot = session.end_turn()
        finally:
            session.close()
    return Dataset.from_dict(rows)

def load_model(model_id: str, max_seq_length: int, lora_rank: int):
    """Load the base model in 16-bit and wrap it with a LoRA adapter."""
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_id,
        max_seq_length=max_seq_length,
        load_in_4bit=False,
        load_in_16bit=True,
        full_finetuning=False,
        fast_inference=False,
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=lora_rank,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=lora_rank * 2,  # common 2x-rank heuristic
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing=False,
        random_state=3407,
        max_seq_length=max_seq_length,
    )
    return model, tokenizer

def apply_chat_template(dataset: Dataset, tokenizer) -> Dataset:
    """Wrap each raw prompt in the model's chat template, with thinking mode off."""
    def format_row(row):
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": row["prompt"]},
        ]
        return {
            "prompt": tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=False,  # Qwen3-style chat templates accept this flag
            )
        }

    return dataset.map(format_row)

def main() -> None:
    args = parse_args()
    max_seq_length = args.max_prompt_length + args.max_completion_length
    dataset = collect_dataset(args.env_url, args.dataset_size, args.episode_horizon)
    model, tokenizer = load_model(args.model_id, max_seq_length, args.lora_rank)
    dataset = apply_chat_template(dataset, tokenizer)
    training_args = GRPOConfig(
        learning_rate=args.learning_rate,
        weight_decay=0.01,
        warmup_ratio=0.05,
        lr_scheduler_type="cosine",
        optim="adamw_torch_fused",
        logging_steps=1,
        log_completions=False,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=1,
        num_generations=args.num_generations,  # batch size must be divisible by this
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        max_steps=args.max_steps,
        save_steps=args.save_steps,
        max_grad_norm=0.3,
        bf16=True,
        report_to="none",
        beta=0.0,  # no KL penalty against a reference model
        loss_type="dr_grpo",
        temperature=0.7,
        top_p=0.8,
        top_k=20,
        output_dir=args.output_dir,
    )
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=reward_from_oracle,
        train_dataset=dataset,
        args=training_args,
    )
    trainer.train()
    model.save_pretrained(f"{args.output_dir}/lora")
    tokenizer.save_pretrained(f"{args.output_dir}/lora")


if __name__ == "__main__":
    main()
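
# A sketch for reusing the saved adapter later (assumes the same Unsloth
# version; the path matches the --output-dir default):
#   model, tokenizer = FastLanguageModel.from_pretrained("outputs/qwen35_08b_grpo/lora")
#   FastLanguageModel.for_inference(model)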