szxllm committed
Commit 5c5e75b · verified · 1 Parent(s): 121c049

Update reward_model.py

Files changed (1)
  1. reward_model.py +185 -188
reward_model.py CHANGED
@@ -1,189 +1,186 @@
- """
- Reward model - for RLHF
- """
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.optim as optim
- from torch.utils.data import DataLoader
- from collections import defaultdict
- from typing import Dict, Tuple, Union, Optional
- from tqdm import tqdm
- from model import MultiModalDenseTransformer
-
- class RewardModel(nn.Module):
-     """Reward model - for RLHF"""
-     def __init__(
-         self,
-         base_model: MultiModalDenseTransformer,
-         use_value_head: bool = True
-     ):
-         super().__init__()
-         self.base_model = base_model
-         self.use_value_head = use_value_head
-
-         self.reward_head = nn.Sequential(
-             nn.Linear(base_model.model_dim, base_model.model_dim // 2),
-             nn.ReLU(),
-             nn.Dropout(0.1),
-             nn.Linear(base_model.model_dim // 2, 1)
-         )
-
-         if use_value_head:
-             self.value_head = nn.Sequential(
-                 nn.Linear(base_model.model_dim, base_model.model_dim // 2),
-                 nn.ReLU(),
-                 nn.Dropout(0.1),
-                 nn.Linear(base_model.model_dim // 2, 1)
-             )
-
-     def forward(
-         self,
-         input_data: Dict,
-         return_values: bool = False
-     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-         """Forward pass"""
-         output = self.base_model(input_data, return_hidden=True)
-         hidden_states = output['last_hidden_state']
-
-         rewards = self.reward_head(hidden_states).squeeze(-1)
-
-         if return_values and self.use_value_head:
-             values = self.value_head(hidden_states).squeeze(-1)
-             return rewards, values
-
-         return rewards
-
- class RewardModelTrainer:
-     """Reward model trainer"""
-     def __init__(
-         self,
-         reward_model: RewardModel,
-         learning_rate: float = 1e-5,
-         margin: float = 0.0
-     ):
-         self.reward_model = reward_model
-         self.margin = margin
-
-         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-         self.reward_model.to(self.device)
-
-         for param in self.reward_model.base_model.parameters():
-             param.requires_grad = False
-
-         for layer in self.reward_model.base_model.layers[-2:]:
-             for param in layer.parameters():
-                 param.requires_grad = True
-
-         trainable_params = list(self.reward_model.reward_head.parameters())
-         if self.reward_model.use_value_head:
-             trainable_params += list(self.reward_model.value_head.parameters())
-
-         self.optimizer = optim.AdamW(
-             filter(lambda p: p.requires_grad, self.reward_model.parameters()),
-             lr=learning_rate
-         )
-
-     def train_step(self, chosen_batch: Dict, rejected_batch: Dict) -> Dict:
-         """Single training step"""
-         self.reward_model.train()
-         self.optimizer.zero_grad()
-
-         chosen_rewards = self.reward_model(chosen_batch)[:, -1]
-         rejected_rewards = self.reward_model(rejected_batch)[:, -1]
-
-         loss = -F.logsigmoid(chosen_rewards - rejected_rewards - self.margin).mean()
-
-         loss.backward()
-         torch.nn.utils.clip_grad_norm_(self.reward_model.parameters(), 1.0)
-         self.optimizer.step()
-
-         accuracy = (chosen_rewards > rejected_rewards).float().mean().item()
-
-         return {
-             'loss': loss.item(),
-             'accuracy': accuracy
-         }
-
-     def train(
-         self,
-         dataloader: DataLoader,
-         num_epochs: int = 1,
-         log_interval: int = 10
-     ):
-         """Training loop"""
-         print(f"Starting reward model training on {self.device}...")
-
-         for epoch in range(num_epochs):
-             total_stats = defaultdict(float)
-             num_steps = 0
-             progress_bar = tqdm(
-                 dataloader,
-                 desc=f"Reward Model Epoch {epoch+1}/{num_epochs}"
-             )
-
-             for batch_idx, (chosen_ids, rejected_ids) in enumerate(progress_bar):
-                 chosen_batch = {
-                     'segments': [{'type': 'text', 'data': chosen_ids.to(self.device), 'modality_id': 0}]
-                 }
-
-                 rejected_batch = {
-                     'segments': [{'type': 'text', 'data': rejected_ids.to(self.device), 'modality_id': 0}]
-                 }
-
-                 stats = self.train_step(chosen_batch, rejected_batch)
-
-                 for k, v in stats.items():
-                     total_stats[k] += v
-                 num_steps += 1
-
-                 if (batch_idx + 1) % log_interval == 0:
-                     avg_stats = {
-                         k: v / num_steps
-                         for k, v in total_stats.items()
-                     }
-                     progress_bar.set_postfix(avg_stats)
-                     total_stats = defaultdict(float)
-
-         print("Reward model training complete!")
-
-     def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
-         """Evaluate the reward model"""
-         self.reward_model.eval()
-         total_stats = defaultdict(float)
-         num_batches = 0
-
-         with torch.no_grad():
-             for chosen_ids, rejected_ids in dataloader:
-                 chosen_batch = {
-                     'segments': [{'type': 'text', 'data': chosen_ids.to(self.device), 'modality_id': 0}]
-                 }
-
-                 rejected_batch = {
-                     'segments': [{'type': 'text', 'data': rejected_ids.to(self.device), 'modality_id': 0}]
-                 }
-
-                 chosen_rewards = self.reward_model(chosen_batch)[:, -1]
-                 rejected_rewards = self.reward_model(rejected_batch)[:, -1]
-
-                 loss = -F.logsigmoid(chosen_rewards - rejected_rewards - self.margin).mean()
-                 accuracy = (chosen_rewards > rejected_rewards).float().mean().item()
-
-                 total_stats['loss'] += loss.item()
-                 total_stats['accuracy'] += accuracy
-                 num_batches += 1
-
-         return {k: v / num_batches for k, v in total_stats.items()}
-
-     def save_checkpoint(self, path: str):
-         """Save checkpoint"""
-         torch.save({
-             'model_state_dict': self.reward_model.state_dict(),
-             'optimizer_state_dict': self.optimizer.state_dict(),
-         }, path)
-
-     def load_checkpoint(self, path: str):
-         """Load checkpoint"""
-         checkpoint = torch.load(path, map_location=self.device)
-         self.reward_model.load_state_dict(checkpoint['model_state_dict'])
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.optim as optim
+ from torch.utils.data import DataLoader
+ from collections import defaultdict
+ from typing import Dict, Tuple, Union, Optional
+ from tqdm import tqdm
+ from model import MultiModalDenseTransformer
+
+ class RewardModel(nn.Module):
+     """Reward model - for RLHF"""
+     def __init__(
+         self,
+         base_model: MultiModalDenseTransformer,
+         use_value_head: bool = True
+     ):
+         super().__init__()
+         self.base_model = base_model
+         self.use_value_head = use_value_head
+
+         self.reward_head = nn.Sequential(
+             nn.Linear(base_model.model_dim, base_model.model_dim // 2),
+             nn.ReLU(),
+             nn.Dropout(0.1),
+             nn.Linear(base_model.model_dim // 2, 1)
+         )
+
+         if use_value_head:
+             self.value_head = nn.Sequential(
+                 nn.Linear(base_model.model_dim, base_model.model_dim // 2),
+                 nn.ReLU(),
+                 nn.Dropout(0.1),
+                 nn.Linear(base_model.model_dim // 2, 1)
+             )
+
+     def forward(
+         self,
+         input_data: Dict,
+         return_values: bool = False
+     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+         """Forward pass"""
+         output = self.base_model(input_data, return_hidden=True)
+         hidden_states = output['last_hidden_state']
+
+         rewards = self.reward_head(hidden_states).squeeze(-1)
+
+         if return_values and self.use_value_head:
+             values = self.value_head(hidden_states).squeeze(-1)
+             return rewards, values
+
+         return rewards
+
+ class RewardModelTrainer:
+     """Reward model trainer"""
+     def __init__(
+         self,
+         reward_model: RewardModel,
+         learning_rate: float = 1e-5,
+         margin: float = 0.0
+     ):
+         self.reward_model = reward_model
+         self.margin = margin
+
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         self.reward_model.to(self.device)
+
+         for param in self.reward_model.base_model.parameters():
+             param.requires_grad = False
+
+         for layer in self.reward_model.base_model.layers[-2:]:
+             for param in layer.parameters():
+                 param.requires_grad = True
+
+         trainable_params = list(self.reward_model.reward_head.parameters())
+         if self.reward_model.use_value_head:
+             trainable_params += list(self.reward_model.value_head.parameters())
+
+         self.optimizer = optim.AdamW(
+             filter(lambda p: p.requires_grad, self.reward_model.parameters()),
+             lr=learning_rate
+         )
+
+     def train_step(self, chosen_batch: Dict, rejected_batch: Dict) -> Dict:
+         """Single training step"""
+         self.reward_model.train()
+         self.optimizer.zero_grad()
+
+         chosen_rewards = self.reward_model(chosen_batch)[:, -1]
+         rejected_rewards = self.reward_model(rejected_batch)[:, -1]
+
+         loss = -F.logsigmoid(chosen_rewards - rejected_rewards - self.margin).mean()
+
+         loss.backward()
+         torch.nn.utils.clip_grad_norm_(self.reward_model.parameters(), 1.0)
+         self.optimizer.step()
+
+         accuracy = (chosen_rewards > rejected_rewards).float().mean().item()
+
+         return {
+             'loss': loss.item(),
+             'accuracy': accuracy
+         }
+
+     def train(
+         self,
+         dataloader: DataLoader,
+         num_epochs: int = 1,
+         log_interval: int = 10
+     ):
+         """Training loop"""
+         print(f"Starting reward model training on {self.device}...")
+
+         for epoch in range(num_epochs):
+             total_stats = defaultdict(float)
+             num_steps = 0
+             progress_bar = tqdm(
+                 dataloader,
+                 desc=f"Reward Model Epoch {epoch+1}/{num_epochs}"
+             )
+
+             for batch_idx, (chosen_ids, rejected_ids) in enumerate(progress_bar):
+                 chosen_batch = {
+                     'segments': [{'type': 'text', 'data': chosen_ids.to(self.device), 'modality_id': 0}]
+                 }
+
+                 rejected_batch = {
+                     'segments': [{'type': 'text', 'data': rejected_ids.to(self.device), 'modality_id': 0}]
+                 }
+
+                 stats = self.train_step(chosen_batch, rejected_batch)
+
+                 for k, v in stats.items():
+                     total_stats[k] += v
+                 num_steps += 1
+
+                 if (batch_idx + 1) % log_interval == 0:
+                     avg_stats = {
+                         k: v / num_steps
+                         for k, v in total_stats.items()
+                     }
+                     progress_bar.set_postfix(avg_stats)
+                     total_stats = defaultdict(float)
+
+         print("Reward model training complete!")
+
+     def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
+         """Evaluate the reward model"""
+         self.reward_model.eval()
+         total_stats = defaultdict(float)
+         num_batches = 0
+
+         with torch.no_grad():
+             for chosen_ids, rejected_ids in dataloader:
+                 chosen_batch = {
+                     'segments': [{'type': 'text', 'data': chosen_ids.to(self.device), 'modality_id': 0}]
+                 }
+
+                 rejected_batch = {
+                     'segments': [{'type': 'text', 'data': rejected_ids.to(self.device), 'modality_id': 0}]
+                 }
+
+                 chosen_rewards = self.reward_model(chosen_batch)[:, -1]
+                 rejected_rewards = self.reward_model(rejected_batch)[:, -1]
+
+                 loss = -F.logsigmoid(chosen_rewards - rejected_rewards - self.margin).mean()
+                 accuracy = (chosen_rewards > rejected_rewards).float().mean().item()
+
+                 total_stats['loss'] += loss.item()
+                 total_stats['accuracy'] += accuracy
+                 num_batches += 1
+
+         return {k: v / num_batches for k, v in total_stats.items()}
+
+     def save_checkpoint(self, path: str):
+         """Save checkpoint"""
+         torch.save({
+             'model_state_dict': self.reward_model.state_dict(),
+             'optimizer_state_dict': self.optimizer.state_dict(),
+         }, path)
+
+     def load_checkpoint(self, path: str):
+         """Load checkpoint"""
+         checkpoint = torch.load(path, map_location=self.device)
+         self.reward_model.load_state_dict(checkpoint['model_state_dict'])
          self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
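
Note: both the old and new versions of `train_step` optimize the same pairwise (Bradley-Terry style) preference objective, `loss = -logsigmoid(r_chosen - r_rejected - margin)`, computed on the reward at the last token position of each sequence. Below is a minimal, self-contained sketch of that objective on dummy reward tensors; the values are illustrative only and not taken from the repository.

```python
import torch
import torch.nn.functional as F

# Illustrative last-token rewards for a batch of 4 preference pairs.
chosen_rewards = torch.tensor([1.2, 0.3, 0.8, -0.1])
rejected_rewards = torch.tensor([0.4, 0.5, -0.2, -0.9])
margin = 0.0

# Pairwise ranking loss: push each chosen reward above the rejected one by at least `margin`.
loss = -F.logsigmoid(chosen_rewards - rejected_rewards - margin).mean()

# Fraction of pairs already ranked correctly, matching the `accuracy` logged by train_step.
accuracy = (chosen_rewards > rejected_rewards).float().mean().item()

print(f"loss={loss.item():.4f}  accuracy={accuracy:.2f}")
```

In `train` and `evaluate`, the chosen/rejected token-id tensors are wrapped as single text segments (`{'segments': [{'type': 'text', 'data': ids, 'modality_id': 0}]}`) before being scored, which is the input format `MultiModalDenseTransformer` is given in this file.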