Gavin-Wang commited on
Commit
2dac41e
·
verified ·
1 Parent(s): 96a1563

Upload train_abstract_grpo_gaussian.py

Browse files
Files changed (1) hide show
  1. train_abstract_grpo_gaussian.py +267 -0
train_abstract_grpo_gaussian.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Gaussian GRPO Training for Abstract Tokens
4
+ """
5
+
6
+ import argparse
7
+ import json
8
+ import torch
9
+ import glob
10
+ import os
11
+ import sys
12
+ import numpy as np
13
+ from pathlib import Path
14
+ from tqdm import tqdm
15
+ from torch.optim import AdamW
16
+ from abstract_model import AbstractModel
17
+
18
+ def extract_bracketed_answer(text):
19
+ """Extract answer from [FINAL ANSWER: X] format."""
20
+ import re
21
+ match = re.search(r'\[FINAL ANSWER:\s*(.*?)\]', text, re.IGNORECASE)
22
+ if match:
23
+ return match.group(1).strip()
24
+ return None
25
+
26
+ def normalize_answer(s):
27
+ """Normalize answer for robust comparison."""
28
+ import re
29
+ import string
30
+ s = str(s).strip().lower()
31
+ s = re.sub(r'\\$\\$.*?$\\`', '', s)
32
+ s = re.sub(r'\\$', '', s)
33
+ s = re.sub(r'\\text\{(.*?)\}', r'\1', s)
34
+ s = s.translate(str.maketrans('', '', string.punctuation))
35
+ return ' '.join(s.split())
36
+
37
+ def compute_reward(generated_text, reference_answer, mode_sequence):
38
+ """
39
+ Compute composite reward: Accuracy + Structure
40
+ """
41
+ bracketed = extract_bracketed_answer(generated_text)
42
+ gen_to_compare = bracketed if bracketed else generated_text
43
+
44
+ gen_norm = normalize_answer(gen_to_compare)
45
+ ref_norm = normalize_answer(reference_answer)
46
+
47
+ if not ref_norm:
48
+ answer_score = 0.0
49
+ elif gen_norm == ref_norm:
50
+ answer_score = 1.0
51
+ elif ref_norm in gen_norm.split():
52
+ answer_score = 1.0
53
+ else:
54
+ # Partial overlap
55
+ gen_words = set(gen_norm.split())
56
+ ref_words = set(ref_norm.split())
57
+ if len(ref_words) > 0:
58
+ answer_score = len(gen_words & ref_words) / len(ref_words)
59
+ else:
60
+ answer_score = 0.0
61
+
62
+ # Structure Reward: Did it use </think>?
63
+ has_transition = 'T' in mode_sequence
64
+ structure_score = 1.0 if has_transition else 0.0
65
+
66
+ total_reward = (0.7 * answer_score) + (0.3 * structure_score)
67
+
68
+ # Boost for perfect result
69
+ if answer_score == 1.0 and has_transition:
70
+ total_reward = 1.0
71
+
72
+ return total_reward
73
+
74
+ SYSTEM_PROMPTS = {
75
+ 'boolean_expressions': "Provide your final answer (True or False) at the end of your response in this exact format: [FINAL ANSWER: X].",
76
+ 'dyck_language': "Provide the completion sequence at the end of your response in this exact format: [FINAL ANSWER: X].",
77
+ 'causal_judgement': "Provide your answer (Yes or No) at the end of your response in this exact format: [FINAL ANSWER: X].",
78
+ 'formal_fallacies': "Provide your answer at the end of your response in this exact format: [FINAL ANSWER: X].",
79
+ 'logical_deduction_three_objects': "Provide your final answer at the end of your response in this exact format: [FINAL ANSWER: X].",
80
+ 'math_level1': "Provide your final numerical answer at the end of your response in this exact format: [FINAL ANSWER: X].",
81
+ 'prontoqa': "Provide your answer (True or False) at the end of your response in this exact format: [FINAL ANSWER: X].",
82
+ 'temporal_sequences': "Provide the next element(s) in the sequence at the end of your response in this exact format: [FINAL ANSWER: X].",
83
+ 'tracking_shuffled_objects_three_objects': "Provide your answer at the end of your response in this exact format: [FINAL ANSWER: X].",
84
+ 'web_of_lies': "Provide your answer (True or False) at the end of your response in this exact format: [FINAL ANSWER: X].",
85
+ }
86
+
87
+ def load_rl_data(data_dir, max_samples=None):
88
+ all_data = []
89
+ files = glob.glob(os.path.join(data_dir, "*.jsonl"))
90
+
91
+ print(f"Scanning {data_dir}...")
92
+ if not files:
93
+ print(f"CRITICAL WARNING: No .jsonl files found in {data_dir}")
94
+ return []
95
+
96
+ print(f"Found {len(files)} files.")
97
+
98
+ for f_path in files:
99
+ filename = os.path.basename(f_path).replace('.jsonl', '')
100
+ system_prompt = SYSTEM_PROMPTS.get(filename, None)
101
+
102
+ # Fuzzy match system prompt
103
+ if system_prompt is None:
104
+ filename_alt = filename.replace('_', '')
105
+ for key in SYSTEM_PROMPTS:
106
+ if key.replace('_', '') == filename_alt:
107
+ system_prompt = SYSTEM_PROMPTS[key]
108
+ break
109
+
110
+ try:
111
+ with open(f_path, 'r') as f:
112
+ for line in f:
113
+ try:
114
+ item = json.loads(line)
115
+ if 'prompt' in item and 'answer' in item:
116
+ item['system_prompt'] = system_prompt
117
+ all_data.append(item)
118
+ except:
119
+ continue
120
+ except Exception as e:
121
+ print(f"Error reading {f_path}: {e}")
122
+
123
+ if max_samples:
124
+ import random
125
+ random.shuffle(all_data)
126
+ all_data = all_data[:max_samples]
127
+
128
+ print(f"Loaded {len(all_data)} valid training samples.")
129
+ return all_data
130
+
131
+ def train(args):
132
+ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
133
+ print(f"Device: {device}")
134
+
135
+ train_data = load_rl_data(args.data_dir, args.max_samples)
136
+ if not train_data:
137
+ print("ERROR: Training data is empty. Exiting.")
138
+ sys.exit(1)
139
+
140
+ print(f"Loading Abstract model from {args.abstract_model}...")
141
+ model = AbstractModel.load_from_directory(args.abstract_model, args.sft_model, device=device)
142
+ try:
143
+ print("Compiling model backbone with torch.compile...")
144
+ model.model_backbone = torch.compile(model.model_backbone)
145
+ except Exception as e:
146
+ print(f"Warning: Could not compile model: {e}")
147
+ model.set_trainable_params('abstract')
148
+ model.train()
149
+
150
+ optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
151
+
152
+ global_step = 0
153
+
154
+ for epoch in range(args.epochs):
155
+ print(f"Epoch {epoch+1}/{args.epochs}")
156
+
157
+ np.random.shuffle(train_data)
158
+ progress = tqdm(range(0, len(train_data), args.batch_size))
159
+
160
+ for i in progress:
161
+ batch = train_data[i : i + args.batch_size]
162
+ batch_loss = 0.0
163
+
164
+ optimizer.zero_grad()
165
+
166
+ for item in batch:
167
+ try:
168
+ prompt = item['prompt']
169
+ reference = item['answer']
170
+ sys_prompt = item.get('system_prompt', None)
171
+
172
+ messages = []
173
+ if sys_prompt:
174
+ messages.append({"role": "system", "content": sys_prompt})
175
+ messages.append({"role": "user", "content": prompt})
176
+
177
+ formatted_prompt = model.tokenizer.apply_chat_template(
178
+ messages, tokenize=False, add_generation_prompt=True
179
+ )
180
+
181
+ input_ids = model.tokenizer(
182
+ formatted_prompt, return_tensors='pt', add_special_tokens=False
183
+ )['input_ids'].to(model.device).squeeze(0)
184
+
185
+ group_rewards = []
186
+ group_log_probs = []
187
+
188
+ for gen_idx in range(args.group_size):
189
+ result = model.forward(
190
+ input_ids,
191
+ max_length=args.max_length,
192
+ temperature=args.temperature,
193
+ sigma=args.sigma,
194
+ sample=True,
195
+ no_grad=False
196
+ )
197
+
198
+ gen_ids = result['generated_tokens'].tolist()
199
+ gen_text = model.tokenizer.decode(gen_ids, skip_special_tokens=True)
200
+ r = compute_reward(gen_text, reference, result['mode_sequence'])
201
+ group_rewards.append(r)
202
+
203
+ if len(result['log_probs']) > 0:
204
+ group_log_probs.append(result['log_probs'].sum())
205
+ else:
206
+ group_log_probs.append(torch.tensor(0.0, device=model.device, requires_grad=True))
207
+
208
+ rewards_np = np.array(group_rewards)
209
+ mean_r = rewards_np.mean()
210
+ std_r = rewards_np.std() + 1e-8
211
+ advantages = (rewards_np - mean_r) / std_r
212
+
213
+ prompt_loss = 0.0
214
+ valid_items = 0
215
+
216
+ for adv, log_prob_sum in zip(advantages, group_log_probs):
217
+ if log_prob_sum.requires_grad:
218
+ adv_tensor = torch.tensor(adv, device=model.device, dtype=log_prob_sum.dtype)
219
+ prompt_loss += -1.0 * (adv_tensor * log_prob_sum)
220
+ valid_items += 1
221
+
222
+ if valid_items > 0:
223
+ prompt_loss = prompt_loss / valid_items
224
+ prompt_loss.backward()
225
+ batch_loss += prompt_loss.item()
226
+
227
+ except Exception as e:
228
+ print(f"Error in batch: {e}")
229
+ continue
230
+
231
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
232
+ optimizer.step()
233
+ global_step += 1
234
+ optimizer.zero_grad()
235
+ torch.cuda.empty_cache()
236
+
237
+ if global_step % args.save_steps == 0:
238
+ save_path = os.path.join(args.output, f"step_{global_step}")
239
+ model.save_to_directory(save_path)
240
+
241
+ model.save_to_directory(os.path.join(args.output, "final"))
242
+ print("Done.")
243
+
244
+
245
+ if __name__ == "__main__":
246
+ torch.set_float32_matmul_precision('high')
247
+ parser = argparse.ArgumentParser()
248
+ parser.add_argument("--sft-model", required=True)
249
+ parser.add_argument("--abstract-model", required=True)
250
+ parser.add_argument("--data-dir", required=True)
251
+ parser.add_argument("--output", required=True)
252
+
253
+ parser.add_argument("--group-size", type=int, default=4)
254
+ parser.add_argument("--batch-size", type=int, default=1)
255
+ parser.add_argument("--lr", type=float, default=1e-5)
256
+ parser.add_argument("--epochs", type=int, default=1)
257
+ parser.add_argument("--max-length", type=int, default=256)
258
+ parser.add_argument("--temperature", type=float, default=1.0)
259
+ parser.add_argument("--sigma", type=float, default=0.1)
260
+
261
+ parser.add_argument("--max-samples", type=int, default=None)
262
+ parser.add_argument("--save-steps", type=int, default=50)
263
+
264
+ args = parser.parse_args()
265
+
266
+ os.makedirs(args.output, exist_ok=True)
267
+ train(args)