# MultiModal / reward_model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from collections import defaultdict
from typing import Dict, Tuple, Union
from tqdm import tqdm
from model import MultiModalDenseTransformer
class RewardModel(nn.Module):
"""奖励模型 - 用于RLHF"""
def __init__(
self,
base_model: MultiModalDenseTransformer,
use_value_head: bool = True
):
super().__init__()
self.base_model = base_model
self.use_value_head = use_value_head
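        # Scalar reward head: projects each position's hidden state to a single
        # per-token reward; the trainer below scores a sequence by its last token.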
self.reward_head = nn.Sequential(
nn.Linear(base_model.model_dim, base_model.model_dim // 2),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(base_model.model_dim // 2, 1)
)
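        # Optional value head (e.g. for PPO-style RLHF): same architecture as the
        # reward head, producing per-token value estimates alongside the rewards.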
if use_value_head:
self.value_head = nn.Sequential(
nn.Linear(base_model.model_dim, base_model.model_dim // 2),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(base_model.model_dim // 2, 1)
)
def forward(
self,
input_data: Dict,
return_values: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""前向传播"""
output = self.base_model(input_data, return_hidden=True)
hidden_states = output['last_hidden_state']
rewards = self.reward_head(hidden_states).squeeze(-1)
if return_values and self.use_value_head:
values = self.value_head(hidden_states).squeeze(-1)
return rewards, values
return rewards
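
# Example (hedged sketch): scoring sequences with the reward model. This assumes
# MultiModalDenseTransformer exposes `model_dim` and accepts the segment-dict
# input format used by the trainer below; the variable names are illustrative.
#
#   base = MultiModalDenseTransformer(...)            # project-specific config
#   rm = RewardModel(base, use_value_head=True)
#   batch = {'segments': [{'type': 'text', 'data': token_ids, 'modality_id': 0}]}
#   rewards, values = rm(batch, return_values=True)   # each of shape (batch, seq_len)
#   sequence_scores = rewards[:, -1]                   # last-token reward per sequence
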
class RewardModelTrainer:
"""奖励模型训练器"""
def __init__(
self,
reward_model: RewardModel,
learning_rate: float = 1e-5,
margin: float = 0.0
):
self.reward_model = reward_model
self.margin = margin
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.reward_model.to(self.device)
        # Freeze the base model, then unfreeze its last two transformer layers
        # so the reward signal can fine-tune the top of the backbone.
        for param in self.reward_model.base_model.parameters():
            param.requires_grad = False
        for layer in self.reward_model.base_model.layers[-2:]:
            for param in layer.parameters():
                param.requires_grad = True
        # Optimize everything still trainable: the last two base layers plus the
        # reward head (and the value head, if enabled).
        self.optimizer = optim.AdamW(
            filter(lambda p: p.requires_grad, self.reward_model.parameters()),
            lr=learning_rate
        )
def train_step(self, chosen_batch: Dict, rejected_batch: Dict) -> Dict:
"""单步训练"""
self.reward_model.train()
self.optimizer.zero_grad()
chosen_rewards = self.reward_model(chosen_batch)[:, -1]
rejected_rewards = self.reward_model(rejected_batch)[:, -1]
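        # Pairwise (Bradley-Terry style) ranking loss: push the chosen sequence's
        # last-token reward above the rejected one's by at least `margin`:
        #   loss = -log(sigmoid(r_chosen - r_rejected - margin))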
loss = -F.logsigmoid(chosen_rewards - rejected_rewards - self.margin).mean()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.reward_model.parameters(), 1.0)
self.optimizer.step()
accuracy = (chosen_rewards > rejected_rewards).float().mean().item()
return {
'loss': loss.item(),
'accuracy': accuracy
}
def train(
self,
dataloader: DataLoader,
num_epochs: int = 1,
log_interval: int = 10
):
"""训练循环"""
print(f"Starting reward model training on {self.device}...")
for epoch in range(num_epochs):
total_stats = defaultdict(float)
num_steps = 0
progress_bar = tqdm(
dataloader,
desc=f"Reward Model Epoch {epoch+1}/{num_epochs}"
)
for batch_idx, (chosen_ids, rejected_ids) in enumerate(progress_bar):
chosen_batch = {
'segments': [{'type': 'text', 'data': chosen_ids.to(self.device), 'modality_id': 0}]
}
rejected_batch = {
'segments': [{'type': 'text', 'data': rejected_ids.to(self.device), 'modality_id': 0}]
}
stats = self.train_step(chosen_batch, rejected_batch)
for k, v in stats.items():
total_stats[k] += v
num_steps += 1
if (batch_idx + 1) % log_interval == 0:
avg_stats = {
k: v / num_steps
for k, v in total_stats.items()
}
progress_bar.set_postfix(avg_stats)
total_stats = defaultdict(float)
print("Reward model training complete!")
def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
"""评估奖励模型"""
self.reward_model.eval()
total_stats = defaultdict(float)
num_batches = 0
with torch.no_grad():
for chosen_ids, rejected_ids in dataloader:
chosen_batch = {
'segments': [{'type': 'text', 'data': chosen_ids.to(self.device), 'modality_id': 0}]
}
rejected_batch = {
'segments': [{'type': 'text', 'data': rejected_ids.to(self.device), 'modality_id': 0}]
}
chosen_rewards = self.reward_model(chosen_batch)[:, -1]
rejected_rewards = self.reward_model(rejected_batch)[:, -1]
loss = -F.logsigmoid(chosen_rewards - rejected_rewards - self.margin).mean()
accuracy = (chosen_rewards > rejected_rewards).float().mean().item()
total_stats['loss'] += loss.item()
total_stats['accuracy'] += accuracy
num_batches += 1
return {k: v / num_batches for k, v in total_stats.items()}
def save_checkpoint(self, path: str):
"""保存检查点"""
torch.save({
'model_state_dict': self.reward_model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
}, path)
def load_checkpoint(self, path: str):
"""加载检查点"""
checkpoint = torch.load(path, map_location=self.device)
self.reward_model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
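
if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch only). The real
    # MultiModalDenseTransformer configuration is project-specific, so a tiny
    # stand-in module with the interface RewardModel/RewardModelTrainer assume
    # (a `model_dim` attribute, a `layers` ModuleList, and a forward that takes
    # the segment dict and returns {'last_hidden_state': ...}) is used instead.
    class _TinyBase(nn.Module):
        def __init__(self, vocab_size: int = 100, model_dim: int = 32, num_layers: int = 4):
            super().__init__()
            self.model_dim = model_dim
            self.embed = nn.Embedding(vocab_size, model_dim)
            self.layers = nn.ModuleList(
                nn.TransformerEncoderLayer(model_dim, nhead=4, batch_first=True)
                for _ in range(num_layers)
            )

        def forward(self, input_data: Dict, return_hidden: bool = False) -> Dict:
            x = self.embed(input_data['segments'][0]['data'])
            for layer in self.layers:
                x = layer(x)
            return {'last_hidden_state': x}

    torch.manual_seed(0)
    trainer = RewardModelTrainer(RewardModel(_TinyBase()), learning_rate=1e-4, margin=0.1)

    def _wrap(ids: torch.Tensor) -> Dict:
        return {'segments': [{'type': 'text', 'data': ids.to(trainer.device), 'modality_id': 0}]}

    # One pairwise training step on random token ids.
    chosen_ids = torch.randint(0, 100, (2, 8))
    rejected_ids = torch.randint(0, 100, (2, 8))
    print(trainer.train_step(_wrap(chosen_ids), _wrap(rejected_ids)))

    # Evaluation expects batches of (chosen_ids, rejected_ids) LongTensor pairs.
    pair_loader = DataLoader(
        torch.utils.data.TensorDataset(
            torch.randint(0, 100, (8, 8)), torch.randint(0, 100, (8, 8))
        ),
        batch_size=2,
    )
    print(trainer.evaluate(pair_loader))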