import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from collections import defaultdict
from typing import Dict, Tuple, Union
from tqdm import tqdm

from model import MultiModalDenseTransformer


class RewardModel(nn.Module):
    """Reward model for RLHF."""

    def __init__(
        self,
        base_model: MultiModalDenseTransformer,
        use_value_head: bool = True
    ):
        super().__init__()
        self.base_model = base_model
        self.use_value_head = use_value_head

        # Scalar reward head on top of the base model's hidden states.
        self.reward_head = nn.Sequential(
            nn.Linear(base_model.model_dim, base_model.model_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(base_model.model_dim // 2, 1)
        )

        if use_value_head:
            # Optional value head, kept for later use as a PPO critic.
            self.value_head = nn.Sequential(
                nn.Linear(base_model.model_dim, base_model.model_dim // 2),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(base_model.model_dim // 2, 1)
            )

    def forward(
        self,
        input_data: Dict,
        return_values: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward pass; returns per-token rewards (and values, if requested)."""
        output = self.base_model(input_data, return_hidden=True)
        hidden_states = output['last_hidden_state']

        rewards = self.reward_head(hidden_states).squeeze(-1)

        if return_values and self.use_value_head:
            values = self.value_head(hidden_states).squeeze(-1)
            return rewards, values
        return rewards


class RewardModelTrainer:
    """Trainer for the reward model."""

    def __init__(
        self,
        reward_model: RewardModel,
        learning_rate: float = 1e-5,
        margin: float = 0.0
    ):
        self.reward_model = reward_model
        self.margin = margin
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.reward_model.to(self.device)

        # Freeze the base model except for its last two transformer layers;
        # the freshly initialized reward/value heads stay trainable by default.
        for param in self.reward_model.base_model.parameters():
            param.requires_grad = False
        for layer in self.reward_model.base_model.layers[-2:]:
            for param in layer.parameters():
                param.requires_grad = True

        self.optimizer = optim.AdamW(
            filter(lambda p: p.requires_grad, self.reward_model.parameters()),
            lr=learning_rate
        )

    def train_step(self, chosen_batch: Dict, rejected_batch: Dict) -> Dict:
        """Single training step with the pairwise Bradley-Terry ranking loss."""
        self.reward_model.train()
        self.optimizer.zero_grad()

        # Score each sequence by the reward at its final token.
        chosen_rewards = self.reward_model(chosen_batch)[:, -1]
        rejected_rewards = self.reward_model(rejected_batch)[:, -1]

        # loss = -log sigmoid(r_chosen - r_rejected - margin)
        loss = -F.logsigmoid(chosen_rewards - rejected_rewards - self.margin).mean()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.reward_model.parameters(), 1.0)
        self.optimizer.step()

        accuracy = (chosen_rewards > rejected_rewards).float().mean().item()
        return {
            'loss': loss.item(),
            'accuracy': accuracy
        }

    def train(
        self,
        dataloader: DataLoader,
        num_epochs: int = 1,
        log_interval: int = 10
    ):
        """Training loop."""
        print(f"Starting reward model training on {self.device}...")

        for epoch in range(num_epochs):
            total_stats = defaultdict(float)
            num_steps = 0

            progress_bar = tqdm(
                dataloader,
                desc=f"Reward Model Epoch {epoch+1}/{num_epochs}"
            )
            for batch_idx, (chosen_ids, rejected_ids) in enumerate(progress_bar):
                chosen_batch = {
                    'segments': [{'type': 'text', 'data': chosen_ids.to(self.device), 'modality_id': 0}]
                }
                rejected_batch = {
                    'segments': [{'type': 'text', 'data': rejected_ids.to(self.device), 'modality_id': 0}]
                }

                stats = self.train_step(chosen_batch, rejected_batch)
                for k, v in stats.items():
                    total_stats[k] += v
                num_steps += 1

                if (batch_idx + 1) % log_interval == 0:
                    avg_stats = {
                        k: v / num_steps for k, v in total_stats.items()
                    }
                    progress_bar.set_postfix(avg_stats)
                    # Reset the step count together with the accumulator so the
                    # next logging window's averages are computed correctly.
                    total_stats = defaultdict(float)
                    num_steps = 0

        print("Reward model training complete!")

    def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
        """Evaluate the reward model."""
        self.reward_model.eval()
        total_stats = defaultdict(float)
        num_batches = 0

        with torch.no_grad():
            for chosen_ids, rejected_ids in dataloader:
                chosen_batch = {
                    'segments': [{'type': 'text', 'data': chosen_ids.to(self.device), 'modality_id': 0}]
                }
                rejected_batch = {
                    'segments': [{'type': 'text', 'data': rejected_ids.to(self.device), 'modality_id': 0}]
                }

                chosen_rewards = self.reward_model(chosen_batch)[:, -1]
                rejected_rewards = self.reward_model(rejected_batch)[:, -1]

                loss = -F.logsigmoid(chosen_rewards - rejected_rewards - self.margin).mean()
                accuracy = (chosen_rewards > rejected_rewards).float().mean().item()

                total_stats['loss'] += loss.item()
                total_stats['accuracy'] += accuracy
                num_batches += 1

        return {k: v / num_batches for k, v in total_stats.items()}

    def save_checkpoint(self, path: str):
        """Save a checkpoint."""
        torch.save({
            'model_state_dict': self.reward_model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, path)

    def load_checkpoint(self, path: str):
        """Load a checkpoint."""
        checkpoint = torch.load(path, map_location=self.device)
        self.reward_model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])