<html><head><meta charset='utf-8'><title>Minimal training script</title>
<style>
body { font-family: -apple-system, BlinkMacSystemFont, sans-serif; max-width: 1000px; margin: 40px auto; padding: 0 20px; }
pre { background: #0d1117; color: #c9d1d9; padding: 16px; border-radius: 8px; overflow-x: auto; }
code { font-family: ui-monospace, SFMono-Regular, Menlo, monospace; }
</style></head><body>
<h1>Minimal training script</h1>
<p>Key file: <code>scripts/train_grpo_fast.py</code>. The script collects turn states from a live FreeCiv session, labels each state with an oracle action index, and trains a LoRA adapter with Unsloth and TRL's GRPO trainer.</p>
<pre><code>from __future__ import annotations

import argparse
import os

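# Silence the HF tokenizers fork-parallelism warning, ask Unsloth to return
# per-token logits (the GRPO loss needs them), and skip Unsloth's update
# check at import time.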
| os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") | |
| os.environ.setdefault("UNSLOTH_RETURN_LOGITS", "1") | |
| os.environ.setdefault("UNSLOTH_DISABLE_AUTO_UPDATES", "1") | |
from unsloth import FastLanguageModel

from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

from freeciv_env.adapter import prepare_observation
from freeciv_env.grpo import SYSTEM_PROMPT, build_turn_prompt, oracle_action_index, reward_from_oracle
from freeciv_env.runtime import LiveFreecivSession


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--env-url", default="http://127.0.0.1")
    parser.add_argument("--model-id", default="Qwen/Qwen3.5-0.8B")
    parser.add_argument("--dataset-size", type=int, default=512)
    parser.add_argument("--max-steps", type=int, default=50)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--num-generations", type=int, default=4)
    parser.add_argument("--episode-horizon", type=int, default=4)
    parser.add_argument("--max-prompt-length", type=int, default=768)
    parser.add_argument("--max-completion-length", type=int, default=8)
    parser.add_argument("--learning-rate", type=float, default=5e-6)
    parser.add_argument("--lora-rank", type=int, default=16)
    parser.add_argument("--output-dir", default="outputs/qwen35_08b_grpo")
    parser.add_argument("--save-steps", type=int, default=50)
    return parser.parse_args()
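

# Roll live sessions until dataset_size (prompt, oracle index) pairs exist;
# each session contributes at most episode_horizon turns.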
def collect_dataset(env_url: str, dataset_size: int, episode_horizon: int) -> Dataset:
    rows = {"prompt": [], "best_index": []}
    while len(rows["prompt"]) < dataset_size:
        session = LiveFreecivSession(base_url=env_url, turn_timeout_s=120)
        try:
            snapshot = session.reset()
            for turn_index in range(episode_horizon):
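                # Wrap the raw snapshot into a structured observation with
                # legal_actions; reward/done are placeholders while collecting.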
                observation = prepare_observation(
                    snapshot,
                    reward=0.0,
                    done=False,
                    status="running",
                ).observation
                best_index = oracle_action_index(observation.legal_actions)
                rows["prompt"].append(build_turn_prompt(observation))
                rows["best_index"].append(best_index)
                if len(rows["prompt"]) >= dataset_size or turn_index + 1 >= episode_horizon:
                    break
                snapshot = session.end_turn()
        finally:
            session.close()
    return Dataset.from_dict(rows)
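

# Load the base model in 16-bit and attach a LoRA adapter; GRPO only updates
# the adapter weights.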
def load_model(model_id: str, max_seq_length: int, lora_rank: int):
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_id,
        max_seq_length=max_seq_length,
        load_in_4bit=False,
        load_in_16bit=True,
        full_finetuning=False,
        fast_inference=False,
    )
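    # LoRA on every attention and MLP projection; alpha = 2r is a common
    # scaling choice.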
    model = FastLanguageModel.get_peft_model(
        model,
        r=lora_rank,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=lora_rank * 2,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing=False,
        random_state=3407,
        max_seq_length=max_seq_length,
    )
    return model, tokenizer
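

# Render each prompt through the model's chat template so generation starts
# from an assistant turn; enable_thinking=False suppresses Qwen3-style
# reasoning blocks, keeping completions down to a bare action index.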
def apply_chat_template(dataset: Dataset, tokenizer) -> Dataset:
    def format_row(row):
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": row["prompt"]},
        ]
        return {
            "prompt": tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=False,
            )
        }

    return dataset.map(format_row)


def main() -> None:
    args = parse_args()
    max_seq_length = args.max_prompt_length + args.max_completion_length
    dataset = collect_dataset(args.env_url, args.dataset_size, args.episode_horizon)
    model, tokenizer = load_model(args.model_id, max_seq_length, args.lora_rank)
    dataset = apply_chat_template(dataset, tokenizer)
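    # beta=0.0 drops the KL term, so TRL keeps no reference model;
    # loss_type="dr_grpo" selects the Dr. GRPO loss variant.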
    training_args = GRPOConfig(
        learning_rate=args.learning_rate,
        weight_decay=0.01,
        warmup_ratio=0.05,
        lr_scheduler_type="cosine",
        optim="adamw_torch_fused",
        logging_steps=1,
        log_completions=False,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=1,
        num_generations=args.num_generations,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        max_steps=args.max_steps,
        save_steps=args.save_steps,
        max_grad_norm=0.3,
        bf16=True,
        report_to="none",
        beta=0.0,
| loss_type="dr_grpo",</code></pre> | |
</body></html>