<!--
openenv_hack / pres /training_script.html
thomasm6m6's picture
Initial Freeciv OpenEnv Space
8dc7642 verified
-->
<!doctype html>
<html><head><meta charset='utf-8'><title>Minimal training script</title>
<style>
body { font-family: -apple-system, BlinkMacSystemFont, sans-serif; max-width: 1000px; margin: 40px auto; padding: 0 20px; }
pre { background: #0d1117; color: #c9d1d9; padding: 16px; border-radius: 8px; overflow-x: auto; }
code { font-family: ui-monospace, SFMono-Regular, Menlo, monospace; }
</style></head><body>
<h1>Minimal training script</h1>
<p>Key file: <code>scripts/train_grpo_fast.py</code></p>
<pre><code>from __future__ import annotations
import argparse
import os
os.environ.setdefault(&quot;TOKENIZERS_PARALLELISM&quot;, &quot;false&quot;)
os.environ.setdefault(&quot;UNSLOTH_RETURN_LOGITS&quot;, &quot;1&quot;)
os.environ.setdefault(&quot;UNSLOTH_DISABLE_AUTO_UPDATES&quot;, &quot;1&quot;)
from unsloth import FastLanguageModel
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
from freeciv_env.adapter import prepare_observation
from freeciv_env.grpo import SYSTEM_PROMPT, build_turn_prompt, oracle_action_index, reward_from_oracle
from freeciv_env.runtime import LiveFreecivSession
def parse_args():
    """Parse the command-line options for the GRPO training run.

    Flags, types, and defaults are unchanged from the original script;
    they are listed as (flag, kwargs) pairs so adding an option is a
    one-line change and the flag list is easy to scan.
    """
    parser = argparse.ArgumentParser()
    for flag, kwargs in (
        ("--env-url", {"default": "http://127.0.0.1"}),
        # NOTE(review): verify this model id exists on the Hub before running.
        ("--model-id", {"default": "Qwen/Qwen3.5-0.8B"}),
        ("--dataset-size", {"type": int, "default": 512}),
        ("--max-steps", {"type": int, "default": 50}),
        ("--batch-size", {"type": int, "default": 16}),
        ("--num-generations", {"type": int, "default": 4}),
        ("--episode-horizon", {"type": int, "default": 4}),
        ("--max-prompt-length", {"type": int, "default": 768}),
        ("--max-completion-length", {"type": int, "default": 8}),
        ("--learning-rate", {"type": float, "default": 5e-6}),
        ("--lora-rank", {"type": int, "default": 16}),
        ("--output-dir", {"default": "outputs/qwen35_08b_grpo"}),
        ("--save-steps", {"type": int, "default": 50}),
    ):
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()
def collect_dataset(env_url: str, dataset_size: int, episode_horizon: int) -> Dataset:
    """Roll out live Freeciv episodes and record (prompt, oracle-index) pairs.

    Opens a fresh session per episode, plays at most ``episode_horizon``
    turns, and stops once ``dataset_size`` prompts have been gathered.
    The session is always closed, even if an episode raises mid-turn.
    """
    prompts: list = []
    labels: list = []
    while dataset_size > len(prompts):
        session = LiveFreecivSession(base_url=env_url, turn_timeout_s=120)
        try:
            snapshot = session.reset()
            for turn in range(episode_horizon):
                obs = prepare_observation(
                    snapshot,
                    reward=0.0,
                    done=False,
                    status="running",
                ).observation
                # Query the oracle first, then record prompt and label together.
                choice = oracle_action_index(obs.legal_actions)
                prompts.append(build_turn_prompt(obs))
                labels.append(choice)
                # Skip end_turn() when enough rows exist or this was the last turn.
                if len(prompts) >= dataset_size or turn + 1 >= episode_horizon:
                    break
                snapshot = session.end_turn()
        finally:
            session.close()
    return Dataset.from_dict({"prompt": prompts, "best_index": labels})
def load_model(model_id: str, max_seq_length: int, lora_rank: int):
    """Load the base model in 16-bit and attach LoRA adapters.

    Returns the PEFT-wrapped model and its tokenizer. LoRA targets all
    attention projections plus the MLP projections.
    """
    base, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_id,
        max_seq_length=max_seq_length,
        load_in_4bit=False,
        load_in_16bit=True,
        full_finetuning=False,
        fast_inference=False,
    )
    attention_and_mlp = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ]
    peft_model = FastLanguageModel.get_peft_model(
        base,
        r=lora_rank,
        target_modules=attention_and_mlp,
        lora_alpha=lora_rank * 2,  # conventional alpha = 2r scaling
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing=False,
        random_state=3407,
        max_seq_length=max_seq_length,
    )
    return peft_model, tokenizer
def apply_chat_template(dataset: Dataset, tokenizer) -> Dataset:
    """Render each raw prompt through the tokenizer's chat template.

    Every row's ``prompt`` column is replaced by the templated string:
    a system turn (SYSTEM_PROMPT) plus one user turn, with the
    generation prompt appended and thinking mode disabled.
    """
    def to_chat(example):
        chat = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["prompt"]},
        ]
        rendered = tokenizer.apply_chat_template(
            chat,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        return {"prompt": rendered}
    return dataset.map(to_chat)
def main() -&gt; None:
args = parse_args()
max_seq_length = args.max_prompt_length + args.max_completion_length
dataset = collect_dataset(args.env_url, args.dataset_size, args.episode_horizon)
model, tokenizer = load_model(args.model_id, max_seq_length, args.lora_rank)
dataset = apply_chat_template(dataset, tokenizer)
training_args = GRPOConfig(
learning_rate=args.learning_rate,
weight_decay=0.01,
warmup_ratio=0.05,
lr_scheduler_type=&quot;cosine&quot;,
optim=&quot;adamw_torch_fused&quot;,
logging_steps=1,
log_completions=False,
per_device_train_batch_size=args.batch_size,
gradient_accumulation_steps=1,
num_generations=args.num_generations,
max_prompt_length=args.max_prompt_length,
max_completion_length=args.max_completion_length,
max_steps=args.max_steps,
save_steps=args.save_steps,
max_grad_norm=0.3,
bf16=True,
report_to=&quot;none&quot;,
beta=0.0,
loss_type=&quot;dr_grpo&quot;,</code></pre>
</body></html>