Spaces:
Runtime error
Runtime error
File size: 5,923 Bytes
8dc7642 | <!doctype html>
<html><head><meta charset='utf-8'><title>Minimal training script</title>
<style>
body { font-family: -apple-system, BlinkMacSystemFont, sans-serif; max-width: 1000px; margin: 40px auto; padding: 0 20px; }
pre { background: #0d1117; color: #c9d1d9; padding: 16px; border-radius: 8px; overflow-x: auto; }
code { font-family: ui-monospace, SFMono-Regular, Menlo, monospace; }
</style></head><body>
<h1>Minimal training script</h1>
<p>Key file: <code>scripts/train_grpo_fast.py</code></p>
<pre><code>from __future__ import annotations
import argparse
import os
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("UNSLOTH_RETURN_LOGITS", "1")
os.environ.setdefault("UNSLOTH_DISABLE_AUTO_UPDATES", "1")
from unsloth import FastLanguageModel
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
from freeciv_env.adapter import prepare_observation
from freeciv_env.grpo import SYSTEM_PROMPT, build_turn_prompt, oracle_action_index, reward_from_oracle
from freeciv_env.runtime import LiveFreecivSession
def parse_args():
    """Define and evaluate the command-line interface for the GRPO run.

    Returns the parsed ``argparse.Namespace`` with environment, model,
    generation, and optimizer settings.
    """
    # (flag, add_argument keyword options) — order preserved for --help output.
    specs = [
        ("--env-url", dict(default="http://127.0.0.1")),
        ("--model-id", dict(default="Qwen/Qwen3.5-0.8B")),
        ("--dataset-size", dict(type=int, default=512)),
        ("--max-steps", dict(type=int, default=50)),
        ("--batch-size", dict(type=int, default=16)),
        ("--num-generations", dict(type=int, default=4)),
        ("--episode-horizon", dict(type=int, default=4)),
        ("--max-prompt-length", dict(type=int, default=768)),
        ("--max-completion-length", dict(type=int, default=8)),
        ("--learning-rate", dict(type=float, default=5e-6)),
        ("--lora-rank", dict(type=int, default=16)),
        ("--output-dir", dict(default="outputs/qwen35_08b_grpo")),
        ("--save-steps", dict(type=int, default=50)),
    ]
    ap = argparse.ArgumentParser()
    for flag, options in specs:
        ap.add_argument(flag, **options)
    return ap.parse_args()
def collect_dataset(env_url: str, dataset_size: int, episode_horizon: int) -> Dataset:
    """Roll out live environment sessions and record (prompt, oracle index) pairs.

    Repeatedly opens a ``LiveFreecivSession``, plays up to ``episode_horizon``
    turns per session, and labels each turn's prompt with the oracle's chosen
    action index, until ``dataset_size`` rows have been collected.
    """
    prompts: list = []
    best_indices: list = []
    while len(prompts) < dataset_size:
        session = LiveFreecivSession(base_url=env_url, turn_timeout_s=120)
        try:
            snapshot = session.reset()
            for step in range(episode_horizon):
                obs = prepare_observation(
                    snapshot,
                    reward=0.0,
                    done=False,
                    status="running",
                ).observation
                # Label first, then store the rendered prompt alongside it.
                oracle_idx = oracle_action_index(obs.legal_actions)
                prompts.append(build_turn_prompt(obs))
                best_indices.append(oracle_idx)
                # Stop early if we have enough rows or this was the last turn;
                # only advance the environment when another turn is needed.
                if len(prompts) >= dataset_size or step == episode_horizon - 1:
                    break
                snapshot = session.end_turn()
        finally:
            # Always release the live session, even on a mid-episode failure.
            session.close()
    return Dataset.from_dict({"prompt": prompts, "best_index": best_indices})
def load_model(model_id: str, max_seq_length: int, lora_rank: int):
    """Load the base model via Unsloth and attach LoRA adapters.

    Returns the (peft_model, tokenizer) pair ready for training.
    """
    base_model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_id,
        max_seq_length=max_seq_length,
        load_in_4bit=False,
        load_in_16bit=True,
        full_finetuning=False,
        fast_inference=False,
    )
    # Adapt every attention projection plus the MLP projections.
    adapter_targets = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ]
    peft_model = FastLanguageModel.get_peft_model(
        base_model,
        r=lora_rank,
        target_modules=adapter_targets,
        lora_alpha=lora_rank * 2,  # alpha = 2r scaling
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing=False,
        random_state=3407,
        max_seq_length=max_seq_length,
    )
    return peft_model, tokenizer
def apply_chat_template(dataset: Dataset, tokenizer) -> Dataset:
    """Render each raw prompt into the tokenizer's chat format.

    Every row's prompt is wrapped with the shared system message and the
    generation cue; the result stays untokenized text for the trainer.
    """

    def _render(example):
        chat = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["prompt"]},
        ]
        rendered = tokenizer.apply_chat_template(
            chat,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        return {"prompt": rendered}

    return dataset.map(_render)
def main() -> None:
args = parse_args()
max_seq_length = args.max_prompt_length + args.max_completion_length
dataset = collect_dataset(args.env_url, args.dataset_size, args.episode_horizon)
model, tokenizer = load_model(args.model_id, max_seq_length, args.lora_rank)
dataset = apply_chat_template(dataset, tokenizer)
training_args = GRPOConfig(
learning_rate=args.learning_rate,
weight_decay=0.01,
warmup_ratio=0.05,
lr_scheduler_type="cosine",
optim="adamw_torch_fused",
logging_steps=1,
log_completions=False,
per_device_train_batch_size=args.batch_size,
gradient_accumulation_steps=1,
num_generations=args.num_generations,
max_prompt_length=args.max_prompt_length,
max_completion_length=args.max_completion_length,
max_steps=args.max_steps,
save_steps=args.save_steps,
max_grad_norm=0.3,
bf16=True,
report_to="none",
beta=0.0,
loss_type="dr_grpo",</code></pre>
</body></html>