szxllm committed on
Commit
263d741
·
verified ·
1 Parent(s): 5963aaa

Upload 3 files

Browse files
Files changed (3) hide show
  1. grpo_dataloader.py +191 -0
  2. grpo_r1_train.py +320 -0
  3. math_verifier.py +270 -0
grpo_dataloader.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GRPO专用数据加载器
3
+ """
4
+ import torch
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from datasets import load_dataset, interleave_datasets
7
+ from typing import Optional, List
8
+ import logging
9
+ import os
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ from data_config import (
14
+ GRPO_DATASETS,
15
+ GRPO_PROMPT_MIX,
16
+ HF_CACHE_DIR
17
+ )
18
+
19
+
20
class GRPOPromptDataset(Dataset):
    """Prompt-only dataset used during the GRPO generation phase.

    The dataset is assembled from the mix named ``mix_name`` in
    ``GRPO_PROMPT_MIX``. Each dataset listed in the mix is loaded with
    ``load_dataset`` using its entry in ``GRPO_DATASETS``; when more than one
    loads successfully they are combined with ``interleave_datasets`` using
    the weights configured for the datasets that actually loaded.

    Args:
        mix_name: Key into ``GRPO_PROMPT_MIX``.
        tokenizer: Callable tokenizer; required.
        max_length: Length passed to the tokenizer for truncation and padding.
        max_samples: Optional cap on the total number of prompts.

    Raises:
        ValueError: If ``tokenizer`` is None, ``mix_name`` is unknown, or no
            dataset could be loaded.
    """

    def __init__(
        self,
        mix_name: str = 'default',
        tokenizer=None,
        max_length: int = 512,
        max_samples: Optional[int] = None
    ):
        super().__init__()

        if tokenizer is None:
            raise ValueError("tokenizer cannot be None")

        self.tokenizer = tokenizer
        self.max_length = max_length

        # Look up the mix configuration.
        if mix_name not in GRPO_PROMPT_MIX:
            raise ValueError(
                f"Unknown mix: {mix_name}. "
                f"Available: {list(GRPO_PROMPT_MIX.keys())}"
            )

        mix_config = GRPO_PROMPT_MIX[mix_name]
        dataset_names = mix_config.get('datasets', [])
        weights = mix_config.get('weights', [])

        logger.info(f"Loading GRPO prompt mix: {mix_name}")
        logger.info(f"  Datasets: {dataset_names}")
        logger.info(f"  Weights: {weights}")

        # Load every dataset of the mix. ``loaded_weights`` is kept in
        # lock-step with ``all_datasets`` so that a skipped dataset does not
        # shift the remaining weights onto the wrong datasets.
        all_datasets = []
        loaded_weights = []

        for position, name in enumerate(dataset_names):
            if name not in GRPO_DATASETS:
                logger.warning(f"Dataset {name} not found")
                continue

            config = GRPO_DATASETS[name]

            # Verify the local data file exists. Only a single path string is
            # checked; list/dict ``data_files`` values are handed to
            # ``load_dataset`` unchecked.
            data_file = config.get('data_files')
            if (
                isinstance(data_file, str)
                and data_file
                and not os.path.exists(data_file)
            ):
                logger.error(f"Data file not found: {data_file}")
                logger.error(f"请先运行 download_grpo_datasets.py 下载数据")
                continue

            try:
                load_kwargs = {
                    'path': config['hf_path'],
                    'split': config.get('split', 'train'),
                    'cache_dir': HF_CACHE_DIR,
                }

                if 'data_files' in config:
                    load_kwargs['data_files'] = config['data_files']

                ds = load_dataset(**load_kwargs)

                # Per-dataset sample cap.
                if config.get('max_samples'):
                    ds = ds.select(range(min(len(ds), config['max_samples'])))

                all_datasets.append(ds)
                loaded_weights.append(
                    weights[position] if position < len(weights) else None
                )
                logger.info(f"  Loaded {name}: {len(ds)} samples")

            except Exception as e:
                # Best effort: a dataset that fails to load is skipped.
                logger.error(f"Error loading {name}: {e}")
                continue

        if not all_datasets:
            raise ValueError("No datasets loaded successfully")

        # Merge the loaded datasets.
        if len(all_datasets) == 1:
            self.dataset = all_datasets[0]
        else:
            probabilities = self._mixing_probabilities(loaded_weights)
            self.dataset = interleave_datasets(
                all_datasets,
                probabilities=probabilities,
                seed=42,
                stopping_strategy='all_exhausted'
            )

        # Global sample cap.
        if max_samples and len(self.dataset) > max_samples:
            self.dataset = self.dataset.select(range(max_samples))

        logger.info(f"Total prompts: {len(self.dataset)}")

    @staticmethod
    def _mixing_probabilities(weights):
        """Normalise ``weights`` into probabilities that sum to 1.

        Falls back to a uniform distribution when a weight is missing
        (``None``), negative, or when the weights do not sum to a positive
        number.
        """
        count = len(weights)
        if count == 0:
            return []
        try:
            total = float(sum(weights))
        except TypeError:
            # At least one weight is missing or not numeric.
            total = 0.0
        if total <= 0 or any(w < 0 for w in weights):
            return [1.0 / count] * count
        return [w / total for w in weights]

    def __len__(self):
        """Return the number of prompts in the merged dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the tokenized prompt at ``idx``.

        Returns:
            A dict with ``input_ids`` and ``attention_mask`` (the tokenizer
            output with its leading dimension squeezed away) and
            ``prompt_text``, or ``None`` when the prompt is empty or
            processing fails.
        """
        try:
            sample = self.dataset[idx]

            # Extract the prompt text.
            prompt = sample.get('prompt', '')

            if not prompt:
                logger.warning(f"Empty prompt at index {idx}")
                return None

            # Tokenize the prompt only. The original author's note says no
            # EOS is appended because this is a prompt; whether special
            # tokens include one depends on the tokenizer.
            encoding = self.tokenizer(
                prompt,
                max_length=self.max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt',
                add_special_tokens=True
            )

            return {
                'input_ids': encoding['input_ids'].squeeze(0),
                'attention_mask': encoding['attention_mask'].squeeze(0),
                'prompt_text': prompt
            }

        except Exception as e:
            # Best effort: a sample that cannot be processed is dropped.
            logger.debug(f"Error processing sample {idx}: {e}")
            return None
149
+
150
+
151
def grpo_collate_fn(batch):
    """Collate GRPO prompt samples into one batch.

    ``None`` entries are discarded first. If nothing is left the function
    returns ``None``; otherwise it returns a dict holding the stacked
    ``input_ids`` and ``attention_mask`` tensors and the list of
    ``prompt_texts``.
    """
    kept = list(filter(lambda item: item is not None, batch))
    if len(kept) == 0:
        return None

    id_rows = []
    mask_rows = []
    texts = []
    for item in kept:
        id_rows.append(item['input_ids'])
        mask_rows.append(item['attention_mask'])
        texts.append(item['prompt_text'])

    collated = {}
    collated['input_ids'] = torch.stack(id_rows)
    collated['attention_mask'] = torch.stack(mask_rows)
    collated['prompt_texts'] = texts
    return collated
164
+
165
+
166
def create_grpo_prompt_dataloader(
    mix_name: str = 'default',
    tokenizer=None,
    batch_size: int = 4,
    num_workers: int = 2,
    max_length: int = 512,
    max_samples: Optional[int] = None,
    shuffle: bool = True
):
    """Build a ``DataLoader`` over a ``GRPOPromptDataset``.

    ``mix_name``, ``tokenizer``, ``max_length`` and ``max_samples`` are
    forwarded to the dataset; ``batch_size``, ``shuffle`` and ``num_workers``
    are forwarded to the loader. The loader always uses ``grpo_collate_fn``,
    pins memory and keeps the last partial batch.
    """
    prompt_dataset = GRPOPromptDataset(
        mix_name=mix_name,
        tokenizer=tokenizer,
        max_length=max_length,
        max_samples=max_samples,
    )

    loader_options = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_workers': num_workers,
        'collate_fn': grpo_collate_fn,
        'pin_memory': True,
        'drop_last': False,
    }
    return DataLoader(prompt_dataset, **loader_options)
grpo_r1_train.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.distributed as dist
4
+ from torch.nn.parallel import DistributedDataParallel as DDP
5
+ from transformers import AutoTokenizer
6
+ from torch.utils.data import DataLoader, Dataset
7
+ import json
8
+ import logging
9
+ from tqdm import tqdm
10
+ import glob
11
+ from datetime import datetime
12
+ import gc
13
+ from model import MultiModalDenseTransformer
14
+ from grpo import GRPOZeroTrainer
15
+
16
+ # ================= DDP 设置 =================
17
def setup_distributed():
    """Initialise distributed training when launched with rank variables.

    When both ``RANK`` and ``WORLD_SIZE`` are present in the environment the
    CUDA device for this process is selected and an NCCL process group is
    created. Otherwise the process runs in single-GPU mode.

    Returns:
        Tuple ``(rank, local_rank, world_size)``; ``(0, 0, 1)`` in
        single-GPU mode.
    """
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

        # Some launchers export RANK/WORLD_SIZE but not LOCAL_RANK; derive the
        # local rank from the global rank instead of failing with KeyError.
        if "LOCAL_RANK" in os.environ:
            local_rank = int(os.environ["LOCAL_RANK"])
        else:
            local_rank = rank % max(torch.cuda.device_count(), 1)

        # Bind the device before creating the NCCL group so the process is
        # already on its own GPU when communication is set up.
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend="nccl")
        print(f"Initialized DDP: Rank {rank}/{world_size}")
        return rank, local_rank, world_size
    else:
        print("Initialized Single GPU Mode")
        return 0, 0, 1
29
+
30
# Distributed state is resolved once, at import time, and shared module-wide.
RANK, LOCAL_RANK, WORLD_SIZE = setup_distributed()
IS_MAIN = (RANK == 0)

# Rank 0 logs at INFO; every other rank only reports warnings and errors.
logging.basicConfig(
    format=f'%(asctime)s - [Rank {RANK}] - %(levelname)s - %(message)s',
    level=(logging.INFO if IS_MAIN else logging.WARNING),
)
logger = logging.getLogger(__name__)
38
+
39
+ # ================= 数据集 =================
40
class MathDataset(Dataset):
    """In-memory dataset backed by a JSON Lines file.

    Every non-blank line of the file at ``path`` is parsed with
    ``json.loads`` and kept, in file order, in ``self.data``.
    """

    def __init__(self, path):
        """Read and parse all records from ``path`` (UTF-8)."""
        with open(path, 'r', encoding='utf-8') as handle:
            records = [json.loads(raw) for raw in handle if raw.strip()]
        self.data = records

    def __len__(self):
        """Return the number of parsed records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the parsed record at position ``idx``."""
        record = self.data[idx]
        return record
53
+
54
def math_collate(batch):
    """Turn a list of records into a dict of parallel lists.

    Each record must provide ``'prompt'`` and ``'ground_truth'``; the result
    maps those two keys to lists in batch order.
    """
    prompts = []
    for record in batch:
        prompts.append(record['prompt'])

    truths = []
    for record in batch:
        truths.append(record['ground_truth'])

    return {'prompt': prompts, 'ground_truth': truths}
59
+
60
+ # ================= 主函数 =================
61
def main():
    """Run the GRPO training loop configured by the in-function ``CONFIG``.

    Steps: set up file logging on rank 0, load the tokenizer, build the actor
    and frozen reference models, create the ``GRPOZeroTrainer``, load weights
    (either resuming from ``resume_from`` or starting from
    ``sft_checkpoint``), build the data loader, then alternate
    ``generate_and_score`` / ``train_step`` until ``max_steps``, writing
    metrics and periodic checkpoints from rank 0.
    """
    # ------------------ Configuration ------------------
    CONFIG = {
        # Base model paths
        'sft_checkpoint': '/root/checkpoints/dcpo_posttrain_round3/step_2600.pt',
        'data_path': '/root/dataset/r1_zero_math.jsonl',
        'save_dir': '/root/checkpoints/r1_zero_reproduction',
        'resume_from': None,  # or a concrete checkpoint path

        # Model hyper-parameters (must match the checkpoint)
        'model_dim': 1536,
        'n_layers': 12,
        'n_heads': 12,
        'n_kv_heads': 4,

        # Training parameters
        'group_size': 4,
        'batch_size': 1,  # prompt batch size
        'learning_rate': 2e-6,
        'max_steps': 190000,
        'max_gen_len': 512,
        'save_interval': 300,

        # Gradient accumulation.
        # Effective update batch = batch_size * group_size * accum_steps,
        # e.g. 1 * 4 * 8 = 32.
        'gradient_accumulation_steps': 8,
        'inner_batch_size': 4  # batch size used inside the PPO update
    }
    # ---------------------------------------------------

    if IS_MAIN:
        os.makedirs(CONFIG['save_dir'], exist_ok=True)
        current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
        log_file = os.path.join(CONFIG['save_dir'], f"train_{current_time}.log")
        file_handler = logging.FileHandler(log_file, encoding='utf-8')
        file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
        logger.addHandler(file_handler)
        logger.info(f"Configuration: {json.dumps(CONFIG, indent=2)}")

    # 1. Load the tokenizer.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # 2. Build the models.
    def create_model():
        return MultiModalDenseTransformer(
            model_dim=CONFIG['model_dim'],
            vocab_size=len(tokenizer),
            n_layers=CONFIG['n_layers'],
            n_heads=CONFIG['n_heads'],
            n_kv_heads=CONFIG['n_kv_heads'],
            max_seq_len=2048,
            use_gradient_checkpointing=True
        )

    device = torch.device(f"cuda:{LOCAL_RANK}")

    logger.info("Initializing Actor Model...")
    actor = create_model().to(device)

    logger.info("Initializing Ref Model...")
    ref = create_model().to(device)
    ref.eval()
    ref.requires_grad_(False)

    # 3. Build the trainer (with the accumulation settings).
    # NOTE(review): the trainer receives the unwrapped actor; the DDP wrapper
    # created further down is only used when saving. Confirm that
    # GRPOZeroTrainer synchronises gradients across ranks itself.
    trainer = GRPOZeroTrainer(
        actor_model=actor,
        ref_model=ref,
        tokenizer=tokenizer,
        learning_rate=CONFIG['learning_rate'],
        group_size=CONFIG['group_size'],
        use_amp=True,
        gradient_accumulation_steps=CONFIG['gradient_accumulation_steps'],
        inner_batch_size=CONFIG['inner_batch_size']
    )

    # 4. Load weights / resume.
    start_step = 0
    samples_seen = 0

    if CONFIG['resume_from']:
        resume_path = CONFIG['resume_from']
        logger.info(f"Resuming from: {resume_path}")
        checkpoint = torch.load(resume_path, map_location='cpu')

        actor.load_state_dict(checkpoint['model_state_dict'])
        # Restore the optimizer state when it is available and compatible.
        if 'optimizer_state_dict' in checkpoint:
            try:
                trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            except Exception as e:
                logger.warning(f"Optimizer load failed (param mismatch?): {e}")

        # The reference model restarts from the same weights as the actor.
        ref.load_state_dict(checkpoint['model_state_dict'])

        start_step = checkpoint.get('step', 0) + 1
        samples_seen = checkpoint.get('samples_seen', start_step * CONFIG['batch_size'] * WORLD_SIZE)

        del checkpoint
        gc.collect()
        torch.cuda.empty_cache()
    else:
        logger.info(f"Loading SFT checkpoint: {CONFIG['sft_checkpoint']}")
        checkpoint = torch.load(CONFIG['sft_checkpoint'], map_location='cpu')
        state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
        # Strip the "module." segment left by a wrapped model.
        new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}

        actor.load_state_dict(new_state_dict)
        ref.load_state_dict(new_state_dict)
        del checkpoint, state_dict, new_state_dict
        gc.collect()
        torch.cuda.empty_cache()

    if WORLD_SIZE > 1:
        actor = DDP(actor, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)

    # 5. Data loading.
    dataset = MathDataset(CONFIG['data_path'])
    if WORLD_SIZE > 1:
        sampler = torch.utils.data.DistributedSampler(
            dataset, num_replicas=WORLD_SIZE, rank=RANK, shuffle=True, seed=42
        )
    else:
        sampler = None

    dataloader = DataLoader(
        dataset, batch_size=CONFIG['batch_size'],
        collate_fn=math_collate, sampler=sampler, shuffle=(sampler is None)
    )

    # 6. Training loop.
    logger.info(f"Starting Training from step {start_step}...")

    samples_per_step = CONFIG['batch_size'] * WORLD_SIZE

    if sampler:
        epoch = samples_seen // len(dataset)
        sampler.set_epoch(epoch)

    data_iter = iter(dataloader)

    # Fast-forward to the resume position. The sampler epoch set above already
    # accounts for the completed passes, so only the batches consumed inside
    # the current pass are skipped; skipping every batch since step 0 would
    # wrap through whole epochs and advance the sampler epoch a second time.
    if samples_seen > 0:
        skip_batches = (samples_seen % len(dataset)) // samples_per_step
        logger.info(f"Skipping {skip_batches} batches...")
        for _ in range(skip_batches):
            try:
                next(data_iter)
            except StopIteration:
                if sampler:
                    sampler.set_epoch(sampler.epoch + 1)
                data_iter = iter(dataloader)
                next(data_iter)

    progress_bar = tqdm(range(start_step, CONFIG['max_steps']), disable=not IS_MAIN, initial=start_step, total=CONFIG['max_steps'])

    # Running state. ``None`` marks "no value yet" so that a genuine 0.0
    # reward or loss does not reset the moving average.
    current_samples = samples_seen
    running_reward = None
    running_loss = None

    for step in progress_bar:
        try:
            try:
                batch = next(data_iter)
            except StopIteration:
                if sampler:
                    epoch = current_samples // len(dataset)
                    sampler.set_epoch(epoch)
                data_iter = iter(dataloader)
                batch = next(data_iter)

            current_samples += samples_per_step

            # Generation phase.
            experience = trainer.generate_and_score(
                batch,
                max_gen_len=CONFIG['max_gen_len']
            )

            # Smoothed reward for the progress bar.
            step_reward = experience['avg_reward']
            if running_reward is None:
                running_reward = step_reward
            else:
                running_reward = 0.95 * running_reward + 0.05 * step_reward

            # Training phase; the result is None while gradients are still
            # being accumulated.
            loss = trainer.train_step(experience)

            # Logging and display.
            status_dict = {"R": f"{running_reward:.3f}"}

            if loss is not None:
                # A weight update happened on this step.
                if running_loss is None:
                    running_loss = loss
                else:
                    running_loss = 0.9 * running_loss + 0.1 * loss
                status_dict["L"] = f"{running_loss:.3f}"

                if IS_MAIN:
                    # Append one metrics record per update.
                    current_lr = trainer.optimizer.param_groups[0]['lr']
                    metrics_data = {
                        "step": step,
                        "reward": float(step_reward),  # reward of this step, unsmoothed
                        "loss": float(loss),
                        "lr": float(current_lr),
                        "samples_seen": current_samples,
                        "timestamp": datetime.now().isoformat()
                    }
                    with open(os.path.join(CONFIG['save_dir'], "metrics.jsonl"), "a", encoding='utf-8') as f:
                        f.write(json.dumps(metrics_data) + "\n")

                    if step % 10 == 0:
                        logger.info(f"Step {step} | Reward: {step_reward:.4f} | Loss: {loss:.4f} | LR: {current_lr:.2e}")
            else:
                # Still accumulating gradients.
                status_dict["State"] = "Acc"

            progress_bar.set_description(f"{' '.join([f'{k}:{v}' for k,v in status_dict.items()])}")

            # Periodic checkpoint (rank 0 only).
            if step > 0 and step % CONFIG['save_interval'] == 0 and IS_MAIN:
                save_path = f"{CONFIG['save_dir']}/step_{step}.pt"
                model_to_save = actor.module if hasattr(actor, 'module') else actor
                torch.save({
                    'step': step,
                    'samples_seen': current_samples,
                    'model_state_dict': model_to_save.state_dict(),
                    'optimizer_state_dict': trainer.optimizer.state_dict(),
                }, save_path)
                logger.info(f"Checkpoint saved: {save_path}")

            # Release per-step references. ``torch.cuda.empty_cache()`` can be
            # added here when GPU memory is very tight.
            del experience
            del batch

        except Exception as e:
            # Best effort: log the failure with its traceback through the
            # logging system (so it also reaches the log file) and move on.
            logger.exception(f"Step {step} Error: {e}")
            continue

    # Final save.
    if IS_MAIN:
        final_path = f"{CONFIG['save_dir']}/final_r1_zero.pt"
        model_to_save = actor.module if hasattr(actor, 'module') else actor
        torch.save({
            'step': CONFIG['max_steps'],
            'model_state_dict': model_to_save.state_dict(),
        }, final_path)
        logger.info("Training Finished.")

    if WORLD_SIZE > 1:
        dist.destroy_process_group()
318
+
319
# Start training only when this file is executed as a script.
if __name__ == "__main__":
    main()
math_verifier.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import math
3
+ import logging
4
+ from difflib import SequenceMatcher
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
class MathReward:
    """Rule-based reward for ``<think>…</think><answer>…</answer>`` completions.

    The reward of a completion combines: a format check, a heuristic quality
    score of the reasoning text, an optional similarity to a reference
    reasoning, the numeric correctness of the answer, and a small bonus when
    the answer value also appears in the reasoning.
    """

    # Integer or decimal with an optional sign. The sign is only consumed when
    # it does not directly follow a digit, so "10-5" reads as 10 and 5 while a
    # leading "-5" keeps its sign.
    _NUMBER_RE = re.compile(r"(?:(?<!\d)[-+])?(?:\d*\.\d+|\d+)")
    # Same number shape immediately followed by a percent sign.
    _PERCENT_RE = re.compile(r"((?:(?<!\d)[-+])?(?:\d*\.\d+|\d+))%")
    # Substrings removed before parsing: whitespace, thousands separators,
    # currency symbols and a few common Chinese units.
    _STRIP_TOKENS = (" ", ",", "¥", "$", "千克", "元", "个", "只", "本", "米", "人")

    def __init__(self, use_reference_comparison=True):
        """
        Args:
            use_reference_comparison: Whether to compare the generated
                reasoning with a reference reasoning when the ground truth
                provides one.
        """
        # <think> is required to come before <answer>.
        self.format_pattern = re.compile(r"<think>(.*?)</think>\s*<answer>(.*?)</answer>", re.DOTALL)
        self.use_reference_comparison = use_reference_comparison

        # Keywords counted by the reasoning-quality heuristic.
        self.reasoning_keywords = [
            '计算', '因为', '所以', '首先', '然后', '接着', '最后', '根据',
            '第一步', '第二步', '第三步', '第', '步', '得到', '等于',
            '加', '减', '乘', '除', '=', '+', '-', '*', '/', '÷', '×'
        ]

    def parse_number(self, text):
        """Parse a numeric value out of ``text``.

        Supports integers, decimals, fractions ("1/5"), percentages ("20%"),
        scientific notation ("1.5e-3") and comma-grouped numbers ("1,000").
        When the text mixes words and numbers, the last number is used.

        Returns:
            The parsed ``float``, or ``None`` when no number is found.
        """
        if not text:
            return None

        clean_text = text.strip()
        for token in self._STRIP_TOKENS:
            clean_text = clean_text.replace(token, "")

        # 1. Percentages, e.g. "20%".
        if "%" in clean_text:
            try:
                return float(clean_text.replace("%", "")) / 100
            except ValueError:
                # Extra text around the value: use the last "<number>%".
                percents = self._PERCENT_RE.findall(clean_text)
                if percents:
                    return float(percents[-1]) / 100

        # 2. Plain fractions, e.g. "1/5" or "42/5".
        if "/" in clean_text:
            parts = clean_text.split("/")
            if len(parts) == 2:
                try:
                    return float(parts[0]) / float(parts[1])
                except (ValueError, ZeroDivisionError):
                    pass  # not a plain "a/b" fraction; fall through

        # 3. Scientific notation, e.g. "1.5e-3". An "e" may also just be a
        #    letter in surrounding words, so a failed conversion falls through
        #    to the generic extraction instead of aborting the parse.
        if "e" in clean_text.lower():
            try:
                return float(clean_text)
            except ValueError:
                pass

        # 4. Generic extraction: take the last number, since the final answer
        #    usually comes last.
        matches = self._NUMBER_RE.findall(clean_text)
        if matches:
            return float(matches[-1])

        return None

    def _as_float(self, value):
        """Coerce a ground-truth value to ``float``; ``None`` if impossible."""
        if value is None:
            return None
        if isinstance(value, str):
            return self.parse_number(value)
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    def check_reasoning_quality(self, think_content):
        """Score the reasoning text heuristically.

        The score adds up: length (0.3 from 100 characters, 0.15 from 50),
        distinct reasoning keywords (0.05 each, capped at 0.3), arithmetic
        expressions (0.2, plus 0.1 for three or more) and step markers (0.1).

        Returns:
            A quality score between 0.0 and 1.0.
        """
        if not think_content:
            return 0.0

        quality_score = 0.0

        # 1. Length.
        length = len(think_content)
        if length >= 100:
            quality_score += 0.3
        elif length >= 50:
            quality_score += 0.15

        # 2. Reasoning keywords.
        keyword_count = sum(1 for kw in self.reasoning_keywords if kw in think_content)
        quality_score += min(keyword_count * 0.05, 0.3)

        # 3. Arithmetic expressions such as "2+3" or "x=5" between digits.
        math_expressions = re.findall(r'\d+\s*[+\-*/×÷=]\s*\d+', think_content)
        if len(math_expressions) > 0:
            quality_score += 0.2
            # Several expressions indicate more detailed reasoning.
            if len(math_expressions) >= 3:
                quality_score += 0.1

        # 4. Explicit step markers.
        has_steps = bool(re.search(r'第\d+步|步骤\d+|^\d+[.、]', think_content, re.MULTILINE))
        if has_steps:
            quality_score += 0.1

        return min(quality_score, 1.0)

    def compute_reasoning_similarity(self, generated_reasoning, reference_reasoning):
        """Return the ``SequenceMatcher`` ratio of the two reasoning texts.

        Returns:
            A similarity between 0.0 and 1.0; 0.0 when either text is empty.
        """
        if not generated_reasoning or not reference_reasoning:
            return 0.0

        return SequenceMatcher(None, generated_reasoning, reference_reasoning).ratio()

    def compute_rewards(self, completions, ground_truths):
        """Compute one reward per completion.

        Args:
            completions: List[str] of full generated texts.
            ground_truths: List[dict], one per completion. Each must contain
                ``'answer_val'``; ``'reasoning'`` is optional and used for the
                reference comparison. When ``'answer_val'`` is ``None`` or not
                numeric, the accuracy term is skipped for that sample.

        Returns:
            List[float] of rewards, in input order.
        """
        rewards = []

        for completion, gt in zip(completions, ground_truths):
            total_reward = 0.0

            # --- 1. Format check ---
            match = self.format_pattern.search(completion)
            if match is None:
                # No <think>…</think><answer>…</answer> structure at all.
                rewards.append(-2.0)
                continue

            think_content = match.group(1).strip()
            answer_content = match.group(2).strip()

            # Base score for a well-formed completion.
            total_reward += 0.6

            # --- 2. Reasoning quality ---
            reasoning_quality = self.check_reasoning_quality(think_content)
            if reasoning_quality < 0.3:
                # Reasoning is too thin.
                total_reward -= 0.5
            else:
                total_reward += reasoning_quality * 1.0  # at most 1.0

            # --- 3. Similarity to the reference reasoning (if provided) ---
            if self.use_reference_comparison and 'reasoning' in gt:
                reference_reasoning = gt['reasoning']
                similarity = self.compute_reasoning_similarity(think_content, reference_reasoning)

                # Up to 0.5; exact agreement is not required because several
                # reasoning paths can be correct.
                if similarity > 0.3:
                    total_reward += similarity * 0.5

            # --- 4. Answer accuracy (dominant term) ---
            pred_val = self.parse_number(answer_content)
            gt_val = self._as_float(gt['answer_val'])

            if pred_val is None:
                # No usable number inside <answer>.
                total_reward -= 1.0
            elif gt_val is not None:
                if math.isclose(pred_val, gt_val, rel_tol=1e-4, abs_tol=1e-4):
                    total_reward += 3.0
                else:
                    # Penalty grows with the relative error.
                    relative_error = abs(pred_val - gt_val) / (abs(gt_val) + 1e-8)
                    if relative_error < 0.1:
                        total_reward -= 0.3
                    elif relative_error < 0.5:
                        total_reward -= 0.8
                    else:
                        total_reward -= 1.5

            # --- 5. Consistency: the answer should appear in the reasoning ---
            reasoning_numbers = self._NUMBER_RE.findall(think_content)
            if reasoning_numbers and pred_val is not None:
                answer_in_reasoning = any(
                    math.isclose(float(num), pred_val, rel_tol=1e-3, abs_tol=1e-3)
                    for num in reasoning_numbers
                )
                if answer_in_reasoning:
                    total_reward += 0.2

            rewards.append(total_reward)

        return rewards

    def compute_metrics(self, completions, ground_truths):
        """Compute aggregate evaluation metrics for analysis.

        Returns:
            dict with the counts ``format_correct``, ``answer_correct``,
            ``answer_close`` and ``total``, the mean ``reasoning_quality_avg``
            over well-formed completions, and the ``*_pct`` percentages
            (0.0 when there are no completions).
        """
        metrics = {
            'format_correct': 0,
            'reasoning_quality_avg': 0.0,
            'answer_correct': 0,
            'answer_close': 0,  # close to, but not exactly, the right answer
            'total': len(completions)
        }

        quality_scores = []

        for completion, gt in zip(completions, ground_truths):
            match = self.format_pattern.search(completion)
            if not match:
                continue

            metrics['format_correct'] += 1

            think_content = match.group(1).strip()
            answer_content = match.group(2).strip()

            # Reasoning quality.
            quality_scores.append(self.check_reasoning_quality(think_content))

            # Answer accuracy.
            pred_val = self.parse_number(answer_content)
            gt_val = self._as_float(gt['answer_val'])

            if pred_val is not None and gt_val is not None:
                if math.isclose(pred_val, gt_val, rel_tol=1e-4, abs_tol=1e-4):
                    metrics['answer_correct'] += 1
                elif math.isclose(pred_val, gt_val, rel_tol=0.1, abs_tol=0.1):
                    metrics['answer_close'] += 1

        if quality_scores:
            metrics['reasoning_quality_avg'] = sum(quality_scores) / len(quality_scores)

        # Percentages; an empty batch yields 0.0 instead of dividing by zero.
        total = metrics['total']
        for key in ('format_correct', 'answer_correct', 'answer_close'):
            metrics[f'{key}_pct'] = metrics[key] / total * 100 if total else 0.0

        return metrics