---
# method name
method: latmem

model:
  # base llm
  reasoner_model_name: Qwen/Qwen2.5-1.5B-Instruct
  # load trained model (checkpoint path; null = start from the base model)
  load_model_path: null

  # max prompt/inference augmentation num (null = no limit)
  max_prompt_aug_num: null
  max_inference_aug_num: 5

  # processor configs
  weaver:
    weaver_model_name: Qwen/Qwen2.5-1.5B-Instruct
    prompt_latents_len: 8
    inference_latents_len: 8

    # canonical lowercase booleans (yamllint `truthy`); `True` is
    # parser-dependent across YAML 1.1/1.2 tooling
    use_peft: true
    peft_config:
      r: 16
      lora_alpha: 32
      target_modules: ["q_proj", "v_proj"]
      lora_dropout: 0.1
      bias: "none"
      task_type: "CAUSAL_LM"

  # trigger model configs
  trigger:
    trigger_model_name: Qwen/Qwen2.5-0.5B-Instruct

    use_peft: true
    peft_config:
      r: 16
      lora_alpha: 32
      target_modules: ["q_proj", "v_proj"]
      lora_dropout: 0.1
      bias: "none"
      task_type: "CAUSAL_LM"

datasets:
  kodcode:
    # which split config below is active: sft or grpo
    mode: sft
    sft:
      cache_path: dataset/kodcode_sft
      # train/valid/test ratios; these sum to 1.0
      train_ratio: 0.7
      valid_ratio: 0.1
      test_ratio: 0.2
    grpo:
      cache_path: dataset/kodcode_grpo
      train_ratio: 0.7
      valid_ratio: 0.1
      test_ratio: 0.2

# training/evaluation configs
run:

  seed: 42
  use_wandb: true

  # route
  mode: train
  train_weaver: true
  train_weaver_method: sft    # sft or grpo
  train_trigger: false
  train_trigger_method: grpo

  # processor training configs
  weaver:

    # sft configs
    sft:
      max_epochs: 2
      batch_size: 4
      grad_accum_steps: 1

      # optimizer configs
      optim: adamw_torch
      schedular: cosine    # NOTE(review): "schedular" (sic) — key name is read by the consumer; confirm before renaming to "scheduler"
      warmup_ratio: 0.1
      # written with a dot (1.0e-5): YAML 1.1 loaders such as PyYAML parse a
      # bare "1e-5" as the STRING "1e-5", not a float
      lr: 1.0e-5

      # logging
      logging_strategy: steps
      logging_steps: 1
      eval_strategy: epoch
      eval_steps: 100    # only consulted when eval_strategy is "steps"
      save_strategy: epoch
      save_steps: 100    # only consulted when save_strategy is "steps"

      assistant_only_loss: false   # used only in conversational dataset
      max_length: 1024  # max sequence length

    # grpo configs
    grpo:
      max_epochs: 1
      batch_size: 8
      num_generations: 8
      num_iterations: 1
      grad_accum_steps: 1
      beta: 0.0    # presumably the KL coefficient; 0.0 disables it — confirm against trainer
      loss_type: bnpo

      # optimizer configs
      optim: adamw_torch
      schedular: cosine    # NOTE(review): "schedular" (sic) — key consumed by code as-is
      warmup_ratio: 0.1
      lr: 1.0e-5    # dotted exponent form so YAML 1.1 loaders parse a float

      # duration
      logging_strategy: steps
      logging_steps: 1
      eval_strategy: epoch
      eval_steps: 100
      save_strategy: epoch
      save_steps: 100

      # rewards (weight as float, consistent with run.trigger below)
      reward_funcs:
        - name: accuracy
          weight: 1.0

  # trigger training configs
  trigger:

    grpo:
      max_epochs: 2
      batch_size: 8
      num_generations: 8
      num_iterations: 1
      grad_accum_steps: 1
      beta: 0.0

      # optimizer configs
      optim: adamw_torch
      lr: 1.0e-5    # dotted exponent form so YAML 1.1 loaders parse a float
      schedular: cosine    # NOTE(review): "schedular" (sic) — key consumed by code as-is
      warmup_ratio: 0.1

      # duration
      logging_strategy: steps
      logging_steps: 1
      eval_strategy: steps
      eval_steps: 100
      save_strategy: steps
      save_steps: 100

      # rewards
      reward_funcs:
        - name: accuracy
          weight: 1.0

  # generation config for GRPO training and evaluation
  generation:
    max_turns: 1
    max_start_length: 1024     # Maximum length of the initial prompt.
    max_prompt_length: 4096    # Maximum prompt length during multi-turn interactions (includes all conversation history across turns).
    max_response_length: 1024
    max_obs_length: 512
    do_sample: false    # greedy decoding; temperature likely ignored when sampling is off — confirm consumer
    temperature: 1.0
    eval_batch_size: 8