import random
import torch
from typing import List, Dict, Tuple
from tqdm import tqdm
import os

from .model import AutoMR
from .config import AutoMRConfig
from .utils import check_answer_match, ensure_dir, save_json


class AutoMRTrainer:
    """Trainer for AutoMR using REINFORCE (sync trainer, async model calls)"""

    def __init__(self, model: AutoMR, config: AutoMRConfig):
        self.model = model
        self.config = config
        ensure_dir(config.checkpoint_dir)

        self.global_step = 0
        self.best_val_reward = -float('inf')
        self.patience_counter = 0

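        # Moving-average reward baseline used for variance reduction in the REINFORCE update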
        self.baseline = self.config.initial_baseline
        self.baseline_momentum = self.config.baseline_momentum
        self.training_history = {
            'train_loss': [],
            'train_reward': [],
            'val_reward': [],
            'val_accuracy': [],
            'steps': []
        }

    def compute_reward_batch(self, problems: List[str], answers: List[str]) -> Tuple[float, float]:
        """
        Compute the average reward and accuracy on a batch (model calls run asynchronously).

        Returns: (avg_reward, accuracy)
        """
        total_reward = 0.0
        correct = 0
        total = len(problems)

        self.model.strategy_mlp.eval()
        self.model.strategy_embeddings.eval()

        with torch.no_grad():
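            # Draw a single sample per problem (M=1) for evaluation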
            pred_answers, _ = self.model.sample_batch_sync(problems, M=1)

            for pred_answer, answer in zip(pred_answers, answers):
                is_correct, _, _ = check_answer_match(pred_answer, answer, self.config.task_type)
                if is_correct:
                    correct += 1
                    total_reward += 1.0
                else:
                    total_reward -= 1.0

        avg_reward = total_reward / total if total > 0 else 0.0
        accuracy = correct / total if total > 0 else 0.0

        return avg_reward, accuracy

    def validate(self, val_data: List[Dict[str, str]]) -> Tuple[float, float]:
        """
        Run validation on a randomly sampled subset of the validation set.

        Returns: (avg_reward, accuracy)
        """
        val_batch_size = min(self.config.val_batch_size, len(val_data))
        val_batch = random.sample(val_data, val_batch_size)

        val_problems = [item['problem'] for item in val_batch]
        val_answers = [item['answer'] for item in val_batch]

        avg_reward, accuracy = self.compute_reward_batch(val_problems, val_answers)

        return avg_reward, accuracy

    def train_step(self, batch_problems: List[str], batch_answers: List[str]) -> Tuple[float, float]:
        """
        Single training step using REINFORCE.

        Returns: (loss, avg_reward)
        """
        self.model.strategy_mlp.train()
        self.model.strategy_embeddings.train()

        M = self.config.num_samples_per_query
        loss_list = []
        rewards_list = []

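        # Draw M samples per problem; log_probs holds the log-probability of each sampled strategy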
        pred_answers, log_probs = self.model.sample_batch_sync(batch_problems, M)

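        # Repeat each gold answer M times so it lines up with the M samples drawn per problem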
        expanded_answers = [answer for answer in batch_answers for _ in range(M)]

        for pred_answer, answer, log_prob in zip(pred_answers, expanded_answers, log_probs):
            matched, _, _ = check_answer_match(
                pred_answer, answer, self.config.task_type
            )
            reward = 1.0 if matched else -1.0
            rewards_list.append(reward)

        avg_reward = sum(rewards_list) / len(rewards_list) if rewards_list else 0.0

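        # Update the exponential-moving-average baseline with the batch's average reward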
        self.baseline = (
            self.baseline_momentum * self.baseline
            + (1.0 - self.baseline_momentum) * avg_reward
        )

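        # REINFORCE loss: negative advantage-weighted log-probability of each sample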
        for reward, log_prob in zip(rewards_list, log_probs):
            advantage = reward - self.baseline
            loss_list.append(-advantage * log_prob)

        self.model.optimizer.zero_grad()
        loss = torch.stack(loss_list).mean()
        loss.backward()
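        # Clip gradients of the trainable strategy-selection parameters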
        torch.nn.utils.clip_grad_norm_(
            list(self.model.strategy_embeddings.parameters()) +
            list(self.model.strategy_mlp.parameters()),
            max_norm=self.config.gradient_clip
        )
        self.model.optimizer.step()

        return loss.item(), avg_reward

    def should_stop_early(self) -> bool:
        """Check if training should stop early"""
        return self.patience_counter >= self.config.early_stopping_patience

    def save_history(self):
        """Write the accumulated training history to training_history.json in the checkpoint directory"""
        history_path = os.path.join(self.config.checkpoint_dir, "training_history.json")
        save_json(self.training_history, history_path)

    def save_checkpoint(self, epoch: int, is_best: bool = False):
        """Save a model checkpoint together with the training history"""
        if self.config.save_best_only and not is_best:
            return

        checkpoint_name = f"checkpoint_epoch_{epoch}_step_{self.global_step}.pt"
        if is_best:
            checkpoint_name = "best_checkpoint.pt"

        checkpoint_path = os.path.join(self.config.checkpoint_dir, checkpoint_name)
        self.model.save_checkpoint(checkpoint_path)

        self.save_history()

        if is_best:
            print(f" 💾 Best checkpoint saved: {checkpoint_path}")

    def train(self, train_data: List[Dict[str, str]], val_data: List[Dict[str, str]]):
        """Training loop with validation"""
        print(f"\nStarting AutoMR training for {self.config.num_epochs} epochs...")
        print(f"Training samples: {len(train_data)}")
        print(f"Validation samples: {len(val_data)}")
        print(f"Batch size: {self.config.batch_size}")
        print(f"Samples per problem: {self.config.num_samples_per_query}")
        print(f"Validation every {self.config.val_every_n_steps} steps")
        print(f"Early stopping patience: {self.config.early_stopping_patience}\n")

        for epoch in range(self.config.num_epochs):
            random.shuffle(train_data)
            epoch_loss = 0.0
            epoch_reward = 0.0
            num_batches = 0

            batch_indices = list(range(0, len(train_data), self.config.batch_size))

            pbar = tqdm(
                batch_indices,
                desc=f"Epoch {epoch+1}/{self.config.num_epochs}"
            )

            for i in pbar:
                batch = train_data[i : i + self.config.batch_size]
                batch_problems = [item['problem'] for item in batch]
                batch_answers = [item['answer'] for item in batch]

                loss, avg_reward = self.train_step(batch_problems, batch_answers)

                epoch_loss += loss
                epoch_reward += avg_reward
                num_batches += 1
                self.global_step += 1

                self.training_history['train_reward'].append(avg_reward)
                self.save_history()

                pbar.set_postfix({
                    'loss': f'{loss:.4f}',
                    'reward': f'{avg_reward:.3f}',
                    'step': self.global_step
                })

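                # Periodic validation: track the best checkpoint and the early-stopping patience counter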
                if self.global_step % self.config.val_every_n_steps == 0:
                    print(f"\n{'='*80}")
                    print(f"Validation at Step {self.global_step}")
                    print(f"{'='*80}")

                    val_reward, val_accuracy = self.validate(val_data)

                    print(f"Validation Reward: {val_reward:.4f}")
                    print(f"Validation Accuracy: {val_accuracy:.2%}")

                    self.training_history['val_reward'].append(val_reward)
                    self.training_history['val_accuracy'].append(val_accuracy)
                    self.training_history['steps'].append(self.global_step)

                    is_best = val_reward > self.best_val_reward
                    if is_best:
                        print(f"✨ New best validation reward: {val_reward:.4f} (previous: {self.best_val_reward:.4f})")
                        self.best_val_reward = val_reward
                        self.patience_counter = 0
                        self.save_checkpoint(epoch + 1, is_best=True)
                    else:
                        self.patience_counter += 1
                        print(f"No improvement. Patience: {self.patience_counter}/{self.config.early_stopping_patience}")

                    print(f"{'='*80}\n")

                if self.should_stop_early():
                    print(f"\nEarly stopping triggered after {self.global_step} steps")
                    print(f"Best validation reward: {self.best_val_reward:.4f}")
                    return

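            # Epoch-level averages over all batches completed this epoch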
            avg_epoch_loss = epoch_loss / max(num_batches, 1)
            avg_epoch_reward = epoch_reward / max(num_batches, 1)

            self.training_history['train_loss'].append(avg_epoch_loss)
            self.training_history['train_reward'].append(avg_epoch_reward)

            print(f"\n{'='*80}")
            print(f"Epoch {epoch+1} Summary")
            print(f"{'='*80}")
            print(f"Average Loss: {avg_epoch_loss:.4f}")
            print(f"Average Reward: {avg_epoch_reward:.4f}")
            print(f"Best Val Reward: {self.best_val_reward:.4f}")
            print(f"{'='*80}\n")

            if not self.config.save_best_only:
                self.save_checkpoint(epoch + 1)

        print("Training completed!")
        print(f"Best validation reward achieved: {self.best_val_reward:.4f}")