# AutoMR-pangu/automr/trainer.py
import os
import random
from typing import Dict, List, Tuple

import torch
from tqdm import tqdm

from .model import AutoMR
from .config import AutoMRConfig
from .utils import check_answer_match, ensure_dir, save_json
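# Assumed behavior of the .utils helpers, inferred from how they are used below (not verified here):
#   check_answer_match(pred, gold, task_type) -> (is_correct: bool, ..., ...)  # only the first element is used
#   ensure_dir(path)      -> creates the directory if it does not already exist
#   save_json(obj, path)  -> serializes `obj` as JSON to `path`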
class AutoMRTrainer:
"""Trainer for AutoMR using REINFORCE (sync trainer, async model calls)"""
def __init__(self, model: AutoMR, config: AutoMRConfig):
self.model = model
self.config = config
ensure_dir(config.checkpoint_dir)
# Track training progress
self.global_step = 0
self.best_val_reward = -float('inf')
self.patience_counter = 0
# Sliding-window baseline for REINFORCE advantage (variance reduction)
self.baseline = self.config.initial_baseline
self.baseline_momentum = self.config.baseline_momentum
        self.training_history = {
            'train_loss': [],         # per-epoch average loss
            'train_reward': [],       # per-epoch average reward
            'train_reward_step': [],  # per-step batch reward
            'val_reward': [],
            'val_accuracy': [],
            'steps': []               # global step at each validation
        }
def compute_reward_batch(self, problems: List[str], answers: List[str]) -> Tuple[float, float]:
"""
        Compute average reward and accuracy on a batch (async model calls behind the sync wrapper)
Returns: (avg_reward, accuracy)
"""
total_reward = 0.0
correct = 0
total = len(problems)
self.model.strategy_mlp.eval()
self.model.strategy_embeddings.eval()
with torch.no_grad():
# M=1 for evaluation/validation; call async model via sync wrapper
pred_answers, _ = self.model.sample_batch_sync(problems, M=1)
for pred_answer, answer in zip(pred_answers, answers):
is_correct, _, _ = check_answer_match(pred_answer, answer, self.config.task_type)
if is_correct:
correct += 1
total_reward += 1.0
else:
total_reward += -1.0
avg_reward = total_reward / total if total > 0 else 0.0
accuracy = correct / total if total > 0 else 0.0
return avg_reward, accuracy
def validate(self, val_data: List[Dict[str, str]]) -> Tuple[float, float]:
"""
        Evaluate on a randomly sampled batch from the validation set
Returns: (avg_reward, accuracy)
"""
# Sample validation batch
val_batch_size = min(self.config.val_batch_size, len(val_data))
val_batch = random.sample(val_data, val_batch_size)
val_problems = [item['problem'] for item in val_batch]
val_answers = [item['answer'] for item in val_batch]
avg_reward, accuracy = self.compute_reward_batch(val_problems, val_answers)
return avg_reward, accuracy
def train_step(self, batch_problems: List[str], batch_answers: List[str]) -> Tuple[float, float]:
"""
Single training step using REINFORCE
Returns: (loss, avg_reward)
"""
self.model.strategy_mlp.train()
self.model.strategy_embeddings.train()
M = self.config.num_samples_per_query
loss_list = []
rewards_list = []
        # 1. Sample M candidate answers per problem; pred_answers: [B*M], log_probs: [B*M] (sync wrapper over async model)
pred_answers, log_probs = self.model.sample_batch_sync(batch_problems, M)
        # 2. Expand gold answers so each of the M samples is paired with its problem's reference answer
expanded_answers = [answer for answer in batch_answers for _ in range(M)]
# 3. Compute Reward & Loss
for pred_answer, answer, log_prob in zip(pred_answers, expanded_answers, log_probs):
matched, _, _ = check_answer_match(
pred_answer, answer, self.config.task_type
)
reward = 1.0 if matched else -1.0
rewards_list.append(reward)
# Compute batch average reward
avg_reward = sum(rewards_list) / len(rewards_list) if rewards_list else 0.0
# Update sliding baseline: exponential moving average
self.baseline = (
self.baseline_momentum * self.baseline
+ (1.0 - self.baseline_momentum) * avg_reward
)
# Policy Gradient with advantage: -(reward - baseline) * log_prob
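        # Note: subtracting a baseline that is independent of the sampled action leaves the
        # REINFORCE gradient unbiased (E[grad log pi] = 0, so the baseline term vanishes in
        # expectation) while reducing its variance; the EMA of batch rewards above is one simple choice.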
for reward, log_prob in zip(rewards_list, log_probs):
advantage = reward - self.baseline
loss_list.append(-advantage * log_prob)
# 4. Update parameters
self.model.optimizer.zero_grad()
loss = torch.stack(loss_list).mean()
loss.backward()
torch.nn.utils.clip_grad_norm_(
list(self.model.strategy_embeddings.parameters()) +
list(self.model.strategy_mlp.parameters()),
max_norm=self.config.gradient_clip
)
self.model.optimizer.step()
return loss.item(), avg_reward
def should_stop_early(self) -> bool:
"""Check if training should stop early"""
return self.patience_counter >= self.config.early_stopping_patience
    def save_history(self):
        """Persist the training history as JSON in the checkpoint directory."""
history_path = os.path.join(self.config.checkpoint_dir, "training_history.json")
save_json(self.training_history, history_path)
def save_checkpoint(self, epoch: int, is_best: bool = False):
"""Save checkpoint"""
if self.config.save_best_only and not is_best:
return
checkpoint_name = f"checkpoint_epoch_{epoch}_step_{self.global_step}.pt"
if is_best:
checkpoint_name = "best_checkpoint.pt"
checkpoint_path = os.path.join(self.config.checkpoint_dir, checkpoint_name)
self.model.save_checkpoint(checkpoint_path)
# Also save training history
self.save_history()
if is_best:
print(f" 💾 Best checkpoint saved: {checkpoint_path}")
def train(self, train_data: List[Dict[str, str]], val_data: List[Dict[str, str]]):
"""Training loop with validation"""
print(f"\nStarting AutoMR training for {self.config.num_epochs} epochs...")
print(f"Training samples: {len(train_data)}")
print(f"Validation samples: {len(val_data)}")
print(f"Batch size: {self.config.batch_size}")
print(f"Samples per problem: {self.config.num_samples_per_query}")
print(f"Validation every {self.config.val_every_n_steps} steps")
print(f"Early stopping patience: {self.config.early_stopping_patience}\n")
for epoch in range(self.config.num_epochs):
random.shuffle(train_data)
epoch_loss = 0.0
epoch_reward = 0.0
num_batches = 0
batch_indices = list(range(0, len(train_data), self.config.batch_size))
pbar = tqdm(
batch_indices,
desc=f"Epoch {epoch+1}/{self.config.num_epochs}"
)
for i in pbar:
batch = train_data[i : i + self.config.batch_size]
batch_problems = [item['problem'] for item in batch]
batch_answers = [item['answer'] for item in batch]
# Training step (sync)
loss, avg_reward = self.train_step(batch_problems, batch_answers)
epoch_loss += loss
epoch_reward += avg_reward
num_batches += 1
self.global_step += 1
                # Per-step reward; per-epoch averages are appended to 'train_reward' at epoch end
                self.training_history['train_reward_step'].append(avg_reward)
                self.save_history()
pbar.set_postfix({
'loss': f'{loss:.4f}',
'reward': f'{avg_reward:.3f}',
'step': self.global_step
})
# Validation
if self.global_step % self.config.val_every_n_steps == 0:
print(f"\n{'='*80}")
print(f"Validation at Step {self.global_step}")
print(f"{'='*80}")
val_reward, val_accuracy = self.validate(val_data)
print(f"Validation Reward: {val_reward:.4f}")
print(f"Validation Accuracy: {val_accuracy:.2%}")
# Record history
self.training_history['val_reward'].append(val_reward)
self.training_history['val_accuracy'].append(val_accuracy)
self.training_history['steps'].append(self.global_step)
# Check if this is the best model
is_best = val_reward > self.best_val_reward
if is_best:
print(f"✨ New best validation reward: {val_reward:.4f} (previous: {self.best_val_reward:.4f})")
self.best_val_reward = val_reward
self.patience_counter = 0
self.save_checkpoint(epoch + 1, is_best=True)
else:
self.patience_counter += 1
print(f"No improvement. Patience: {self.patience_counter}/{self.config.early_stopping_patience}")
print(f"{'='*80}\n")
# Check early stopping
if self.should_stop_early():
print(f"\n Early stopping triggered after {self.global_step} steps")
print(f"Best validation reward: {self.best_val_reward:.4f}")
return
# End of epoch summary
avg_epoch_loss = epoch_loss / max(num_batches, 1)
avg_epoch_reward = epoch_reward / max(num_batches, 1)
self.training_history['train_loss'].append(avg_epoch_loss)
self.training_history['train_reward'].append(avg_epoch_reward)
print(f"\n{'='*80}")
print(f"Epoch {epoch+1} Summary")
print(f"{'='*80}")
print(f"Average Loss: {avg_epoch_loss:.4f}")
print(f"Average Reward: {avg_epoch_reward:.4f}")
print(f"Best Val Reward: {self.best_val_reward:.4f}")
print(f"{'='*80}\n")
# Save checkpoint at end of epoch
if not self.config.save_best_only:
self.save_checkpoint(epoch + 1)
print("Training completed!")
print(f"Best validation reward achieved: {self.best_val_reward:.4f}")