File size: 5,341 Bytes
8dc7642
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""GRPO fine-tuning of a small Qwen model against a live Freeciv environment.

Collects (prompt, oracle-action) rows from live sessions, then trains a LoRA
adapter with TRL's GRPOTrainer using an oracle-based reward.
"""
from __future__ import annotations

import argparse
import os

# These must be set before the `unsloth` import below so they are already in
# the environment when the library initializes at import time.
# - TOKENIZERS_PARALLELISM: silence the HF tokenizers fork-parallelism warning.
# - UNSLOTH_RETURN_LOGITS / UNSLOTH_DISABLE_AUTO_UPDATES: unsloth behavior
#   toggles — presumably logits are needed for the GRPO loss and auto-update
#   checks are unwanted in training jobs; confirm against unsloth docs.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("UNSLOTH_RETURN_LOGITS", "1")
os.environ.setdefault("UNSLOTH_DISABLE_AUTO_UPDATES", "1")

from unsloth import FastLanguageModel
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

from freeciv_env.adapter import prepare_observation
from freeciv_env.grpo import SYSTEM_PROMPT, build_turn_prompt, oracle_action_index, reward_from_oracle
from freeciv_env.runtime import LiveFreecivSession


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--env-url", default="http://127.0.0.1")
    parser.add_argument("--model-id", default="Qwen/Qwen3.5-0.8B")
    parser.add_argument("--dataset-size", type=int, default=512)
    parser.add_argument("--max-steps", type=int, default=50)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--num-generations", type=int, default=4)
    parser.add_argument("--episode-horizon", type=int, default=4)
    parser.add_argument("--max-prompt-length", type=int, default=768)
    parser.add_argument("--max-completion-length", type=int, default=8)
    parser.add_argument("--learning-rate", type=float, default=5e-6)
    parser.add_argument("--lora-rank", type=int, default=16)
    parser.add_argument("--output-dir", default="outputs/qwen35_08b_grpo")
    parser.add_argument("--save-steps", type=int, default=50)
    return parser.parse_args()



def collect_dataset(env_url: str, dataset_size: int, episode_horizon: int) -> Dataset:
    """Harvest (prompt, oracle best-action index) training rows from live play.

    Repeatedly opens a fresh session, plays up to ``episode_horizon`` turns,
    and records one row per observed turn until ``dataset_size`` rows exist.
    Sessions are always closed, even when observation handling raises.
    """
    prompts: list = []
    oracle_indices: list = []
    while len(prompts) < dataset_size:
        session = LiveFreecivSession(base_url=env_url, turn_timeout_s=120)
        try:
            snapshot = session.reset()
            for step in range(episode_horizon):
                observation = prepare_observation(
                    snapshot,
                    reward=0.0,
                    done=False,
                    status="running",
                ).observation
                oracle_idx = oracle_action_index(observation.legal_actions)
                prompts.append(build_turn_prompt(observation))
                oracle_indices.append(oracle_idx)
                quota_filled = len(prompts) >= dataset_size
                episode_done = step + 1 >= episode_horizon
                if quota_filled or episode_done:
                    break
                # Only advance the game when another row will actually be used.
                snapshot = session.end_turn()
        finally:
            session.close()
    return Dataset.from_dict({"prompt": prompts, "best_index": oracle_indices})



def load_model(model_id: str, max_seq_length: int, lora_rank: int):
    """Load the base model in 16-bit precision and attach LoRA adapters.

    Returns the ``(model, tokenizer)`` pair produced by Unsloth, with LoRA
    applied to the attention and MLP projection layers.
    """
    lora_targets = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ]
    base_model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_id,
        max_seq_length=max_seq_length,
        load_in_4bit=False,
        load_in_16bit=True,
        full_finetuning=False,
        fast_inference=False,
    )
    peft_model = FastLanguageModel.get_peft_model(
        base_model,
        r=lora_rank,
        target_modules=lora_targets,
        lora_alpha=lora_rank * 2,  # conventional 2x-rank scaling
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing=False,
        random_state=3407,
        max_seq_length=max_seq_length,
    )
    return peft_model, tokenizer



def apply_chat_template(dataset: Dataset, tokenizer) -> Dataset:
    """Render every raw prompt through the tokenizer's chat template.

    Each row's prompt becomes a system+user conversation serialized as a
    single generation-ready string, with model "thinking" disabled.
    """
    def _render(example):
        conversation = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["prompt"]},
        ]
        rendered = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        return {"prompt": rendered}

    return dataset.map(_render)



def main() -> None:
    """Entry point: collect rollout data, load the model, and run GRPO.

    Saves the trained LoRA adapter and tokenizer under
    ``<output_dir>/lora`` when training finishes.
    """
    args = parse_args()
    seq_budget = args.max_prompt_length + args.max_completion_length

    raw_dataset = collect_dataset(args.env_url, args.dataset_size, args.episode_horizon)
    model, tokenizer = load_model(args.model_id, seq_budget, args.lora_rank)
    train_dataset = apply_chat_template(raw_dataset, tokenizer)

    grpo_config = GRPOConfig(
        learning_rate=args.learning_rate,
        weight_decay=0.01,
        warmup_ratio=0.05,
        lr_scheduler_type="cosine",
        optim="adamw_torch_fused",
        logging_steps=1,
        log_completions=False,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=1,
        num_generations=args.num_generations,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        max_steps=args.max_steps,
        save_steps=args.save_steps,
        max_grad_norm=0.3,
        bf16=True,
        report_to="none",
        beta=0.0,
        loss_type="dr_grpo",
        temperature=0.7,
        top_p=0.8,
        top_k=20,
        output_dir=args.output_dir,
    )

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=reward_from_oracle,
        train_dataset=train_dataset,
        args=grpo_config,
    )
    trainer.train()

    adapter_dir = f"{args.output_dir}/lora"
    model.save_pretrained(adapter_dir)
    tokenizer.save_pretrained(adapter_dir)


if __name__ == "__main__":
    main()