import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from collections import defaultdict
from typing import Dict, Tuple, Union
from tqdm import tqdm

from model import MultiModalDenseTransformer
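

# Reward modelling components for RLHF: RewardModel wraps the base
# MultiModalDenseTransformer (model.py) with scalar reward/value heads, and
# RewardModelTrainer fits it on human preference pairs.
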
class RewardModel(nn.Module):
    """Reward model for RLHF."""

    def __init__(
        self,
        base_model: MultiModalDenseTransformer,
        use_value_head: bool = True
    ):
        super().__init__()
        self.base_model = base_model
        self.use_value_head = use_value_head
        # Scalar reward head applied to every hidden-state position.
        self.reward_head = nn.Sequential(
            nn.Linear(base_model.model_dim, base_model.model_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(base_model.model_dim // 2, 1)
        )
        if use_value_head:
            # Optional value head (e.g. for use as a critic during RL fine-tuning).
            self.value_head = nn.Sequential(
                nn.Linear(base_model.model_dim, base_model.model_dim // 2),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(base_model.model_dim // 2, 1)
            )

    def forward(
        self,
        input_data: Dict,
        return_values: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward pass: returns per-token rewards (and values if requested)."""
        output = self.base_model(input_data, return_hidden=True)
        hidden_states = output['last_hidden_state']  # [batch, seq_len, model_dim]
        rewards = self.reward_head(hidden_states).squeeze(-1)  # [batch, seq_len]
        if return_values and self.use_value_head:
            values = self.value_head(hidden_states).squeeze(-1)
            return rewards, values
        return rewards
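

# The trainer below fits the reward model on (chosen, rejected) preference
# pairs with the standard pairwise ranking loss used in RLHF:
#     loss = -log sigmoid(r_chosen - r_rejected - margin)
# i.e. the preferred response's reward should exceed the rejected response's
# reward by at least `margin`. Accuracy is the fraction of pairs ranked
# correctly.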
class RewardModelTrainer:
    """Trainer for the reward model."""

    def __init__(
        self,
        reward_model: RewardModel,
        learning_rate: float = 1e-5,
        margin: float = 0.0
    ):
        self.reward_model = reward_model
        self.margin = margin
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.reward_model.to(self.device)
        # Freeze the base model except its last two layers; the reward head
        # (and value head, if present) stay trainable by default.
        for param in self.reward_model.base_model.parameters():
            param.requires_grad = False
        for layer in self.reward_model.base_model.layers[-2:]:
            for param in layer.parameters():
                param.requires_grad = True
        # Optimize only the parameters that still require gradients.
        self.optimizer = optim.AdamW(
            filter(lambda p: p.requires_grad, self.reward_model.parameters()),
            lr=learning_rate
        )

    def train_step(self, chosen_batch: Dict, rejected_batch: Dict) -> Dict:
        """Single training step on one batch of preference pairs."""
        self.reward_model.train()
        self.optimizer.zero_grad()
        # Score each sequence by the reward at its final token position.
        chosen_rewards = self.reward_model(chosen_batch)[:, -1]
        rejected_rewards = self.reward_model(rejected_batch)[:, -1]
        # Pairwise ranking loss: -log sigmoid(r_chosen - r_rejected - margin).
        loss = -F.logsigmoid(chosen_rewards - rejected_rewards - self.margin).mean()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.reward_model.parameters(), 1.0)
        self.optimizer.step()
        accuracy = (chosen_rewards > rejected_rewards).float().mean().item()
        return {
            'loss': loss.item(),
            'accuracy': accuracy
        }

    def train(
        self,
        dataloader: DataLoader,
        num_epochs: int = 1,
        log_interval: int = 10
    ):
        """Training loop over (chosen_ids, rejected_ids) preference pairs."""
        print(f"Starting reward model training on {self.device}...")
        for epoch in range(num_epochs):
            total_stats = defaultdict(float)
            num_steps = 0
            progress_bar = tqdm(
                dataloader,
                desc=f"Reward Model Epoch {epoch+1}/{num_epochs}"
            )
            for batch_idx, (chosen_ids, rejected_ids) in enumerate(progress_bar):
                # Wrap raw token ids in the segment format the base model expects.
                chosen_batch = {
                    'segments': [{'type': 'text', 'data': chosen_ids.to(self.device), 'modality_id': 0}]
                }
                rejected_batch = {
                    'segments': [{'type': 'text', 'data': rejected_ids.to(self.device), 'modality_id': 0}]
                }
                stats = self.train_step(chosen_batch, rejected_batch)
                for k, v in stats.items():
                    total_stats[k] += v
                num_steps += 1
                if (batch_idx + 1) % log_interval == 0:
                    avg_stats = {k: v / num_steps for k, v in total_stats.items()}
                    progress_bar.set_postfix(avg_stats)
                    # Reset the window so the next averages cover only the new steps.
                    total_stats = defaultdict(float)
                    num_steps = 0
        print("Reward model training complete!")

    def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
        """Evaluate the reward model on a held-out preference dataset."""
        self.reward_model.eval()
        total_stats = defaultdict(float)
        num_batches = 0
        with torch.no_grad():
            for chosen_ids, rejected_ids in dataloader:
                chosen_batch = {
                    'segments': [{'type': 'text', 'data': chosen_ids.to(self.device), 'modality_id': 0}]
                }
                rejected_batch = {
                    'segments': [{'type': 'text', 'data': rejected_ids.to(self.device), 'modality_id': 0}]
                }
                chosen_rewards = self.reward_model(chosen_batch)[:, -1]
                rejected_rewards = self.reward_model(rejected_batch)[:, -1]
                loss = -F.logsigmoid(chosen_rewards - rejected_rewards - self.margin).mean()
                accuracy = (chosen_rewards > rejected_rewards).float().mean().item()
                total_stats['loss'] += loss.item()
                total_stats['accuracy'] += accuracy
                num_batches += 1
        return {k: v / num_batches for k, v in total_stats.items()}

    def save_checkpoint(self, path: str):
        """Save model and optimizer state to `path`."""
        torch.save({
            'model_state_dict': self.reward_model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, path)

    def load_checkpoint(self, path: str):
        """Load model and optimizer state from `path`."""
        checkpoint = torch.load(path, map_location=self.device)
        self.reward_model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
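

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only). The MultiModalDenseTransformer
    # constructor arguments live in model.py and are not reproduced here;
    # replace the ellipsis with a real configuration before running.
    from torch.utils.data import TensorDataset

    base_model = MultiModalDenseTransformer(...)  # see model.py for arguments
    reward_model = RewardModel(base_model, use_value_head=True)
    trainer = RewardModelTrainer(reward_model, learning_rate=1e-5, margin=0.1)

    # Toy preference data: each item is a (chosen_ids, rejected_ids) pair of
    # token-id tensors; the DataLoader batches them into [batch, seq_len]
    # tensors, matching what train() and evaluate() expect.
    chosen = torch.randint(0, 1000, (32, 16))
    rejected = torch.randint(0, 1000, (32, 16))
    loader = DataLoader(TensorDataset(chosen, rejected), batch_size=8)

    trainer.train(loader, num_epochs=1)
    print(trainer.evaluate(loader))
    trainer.save_checkpoint("reward_model.pt")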