"""
Tests for TouchGrass Trainer.
"""

import pytest
import torch
from unittest.mock import MagicMock, patch

from TouchGrass.training.trainer import TouchGrassTrainer


class TestTouchGrassTrainer:
    """Test suite for TouchGrassTrainer.

    The model, tokenizer, loss function, optimizer and scheduler are all
    ``MagicMock`` stand-ins, so these tests exercise the trainer's control
    flow (which collaborators it calls, which config values it reads) rather
    than real numerical training.
    """

    def setup_method(self):
        """Build fresh mocks and a base config before every test."""
        self.device = "cpu"
        self.d_model = 768
        self.vocab_size = 32000

        # Mock model. A real ``nn.Module``'s ``to``/``train``/``eval`` return
        # the module itself; mirror that so a trainer that rebinds
        # ``self.model = model.to(device)`` still holds the very mock the
        # assertions below inspect (otherwise it would hold
        # ``model.to.return_value``, a different MagicMock).
        self.model = MagicMock()
        self.model.to.return_value = self.model
        self.model.train.return_value = self.model
        self.model.eval.return_value = self.model
        self.model.parameters.return_value = [torch.randn(10, requires_grad=True)]
        # MagicMock instances cannot be pickled; hand back a plain dict so
        # checkpoint code that serializes state dicts is able to do so.
        self.model.state_dict.return_value = {}

        # Mock tokenizer
        self.tokenizer = MagicMock()
        self.tokenizer.pad_token_id = 0

        # Mock loss function. ``requires_grad=True`` allows the trainer to call
        # ``loss.backward()``; on a tensor without it PyTorch raises
        # "element 0 of tensors does not require grad".
        self.loss_fn = MagicMock()
        self.loss_fn.return_value = {
            "total_loss": torch.tensor(0.5, requires_grad=True)
        }

        # Mock optimizer and scheduler. MagicMock auto-creates ``step`` and
        # ``zero_grad`` children and records calls to them.
        self.optimizer = MagicMock()
        self.optimizer.state_dict.return_value = {}
        self.scheduler = MagicMock()
        self.scheduler.state_dict.return_value = {}

        # Create trainer config
        self.config = {
            "batch_size": 4,
            "gradient_accumulation_steps": 1,
            "learning_rate": 2e-4,
            "max_grad_norm": 1.0,
            "num_epochs": 1,
            "save_steps": 100,
            "eval_steps": 50,
            "output_dir": "test_output",
        }

    # ------------------------------------------------------------------ #
    # Helpers (leading underscore: not collected by pytest)
    # ------------------------------------------------------------------ #

    def _make_trainer(self, config=None, **overrides):
        """Construct a TouchGrassTrainer from the shared mocks.

        Args:
            config: Trainer config; defaults to the base ``self.config``.
                Compared with ``is None`` so an explicitly empty dict is
                still passed through.
            **overrides: Extra or replacement constructor keyword arguments,
                e.g. ``scheduler=self.scheduler`` or ``device="cpu"``.
        """
        kwargs = {
            "model": self.model,
            "tokenizer": self.tokenizer,
            "loss_fn": self.loss_fn,
            "optimizer": self.optimizer,
            "config": self.config if config is None else config,
            "device": self.device,
        }
        kwargs.update(overrides)
        return TouchGrassTrainer(**kwargs)

    def _make_batch(self, batch_size=4, seq_len=10):
        """Return a random batch with input_ids, attention_mask and labels."""
        shape = (batch_size, seq_len)
        return {
            "input_ids": torch.randint(0, self.vocab_size, shape),
            "attention_mask": torch.ones(*shape),
            "labels": torch.randint(0, self.vocab_size, shape),
        }

    # ------------------------------------------------------------------ #
    # Construction
    # ------------------------------------------------------------------ #

    def test_trainer_initialization(self):
        """Test trainer initialization."""
        trainer = self._make_trainer(scheduler=self.scheduler)

        assert trainer.model is self.model
        assert trainer.tokenizer is self.tokenizer
        assert trainer.loss_fn is self.loss_fn
        assert trainer.optimizer is self.optimizer
        assert trainer.scheduler is self.scheduler
        assert trainer.config == self.config

    def test_trainer_required_components(self):
        """Test that all required components are present."""
        trainer = self._make_trainer()

        for attr in ("train", "evaluate", "save_checkpoint", "load_checkpoint"):
            assert hasattr(trainer, attr), attr

    def test_prepare_batch(self):
        """Test batch preparation."""
        trainer = self._make_trainer()

        prepared = trainer._prepare_batch(self._make_batch())

        assert "input_ids" in prepared
        assert "attention_mask" in prepared
        assert "labels" in prepared

    # ------------------------------------------------------------------ #
    # Single steps
    # ------------------------------------------------------------------ #

    def test_training_step(self):
        """Test single training step."""
        trainer = self._make_trainer()

        loss = trainer._training_step(self._make_batch())

        # The original ``isinstance(loss, Tensor) or loss is not None`` was
        # logically just this check.
        assert loss is not None

    def test_evaluation_step(self):
        """Test single evaluation step."""
        trainer = self._make_trainer()

        metrics = trainer._evaluation_step(self._make_batch())

        assert isinstance(metrics, dict)

    # ------------------------------------------------------------------ #
    # Config plumbing
    # ------------------------------------------------------------------ #

    def test_gradient_accumulation(self):
        """Test gradient accumulation."""
        config = self.config.copy()
        config["gradient_accumulation_steps"] = 2

        trainer = self._make_trainer(config=config)

        assert trainer.gradient_accumulation_steps == 2

    def test_checkpoint_saving(self, tmp_path):
        """Smoke test: save_checkpoint runs without raising.

        Which files are written depends on the trainer implementation, so
        nothing about the output directory is asserted here.
        """
        config = self.config.copy()
        config["output_dir"] = str(tmp_path / "checkpoints")

        trainer = self._make_trainer(config=config)
        trainer.save_checkpoint(step=100)

    def test_learning_rate_scheduler_step(self):
        """Smoke test: a training step with a scheduler attached runs.

        NOTE(review): whether ``scheduler.step`` fires per training step is
        implementation-dependent; add ``self.scheduler.step.assert_called()``
        once that contract is confirmed.
        """
        config = self.config.copy()
        config["learning_rate"] = 1e-3

        trainer = self._make_trainer(config=config, scheduler=self.scheduler)
        trainer._training_step(self._make_batch())

    def test_gradient_clipping(self):
        """Test gradient clipping."""
        config = self.config.copy()
        # Use a value that differs from the base config (1.0) so the test
        # proves the trainer reads the config rather than a hard default.
        config["max_grad_norm"] = 0.5

        trainer = self._make_trainer(config=config)

        assert trainer.max_grad_norm == 0.5

    def test_mixed_precision_flag(self):
        """Test mixed precision training flag."""
        config = self.config.copy()
        config["mixed_precision"] = True

        trainer = self._make_trainer(config=config)

        assert trainer.mixed_precision is True

    def test_device_assignment(self):
        """Test that model and data are moved to correct device."""
        trainer = self._make_trainer(device="cpu")

        assert trainer.device == "cpu"

    # ------------------------------------------------------------------ #
    # Collaborator interaction
    # ------------------------------------------------------------------ #

    def test_optimizer_zero_grad_called(self):
        """Test that optimizer.zero_grad is called."""
        trainer = self._make_trainer()

        trainer._training_step(self._make_batch())

        self.optimizer.zero_grad.assert_called()

    def test_optimizer_step_called(self):
        """Test that optimizer.step is called."""
        trainer = self._make_trainer()

        trainer._training_step(self._make_batch())

        self.optimizer.step.assert_called()

    def test_loss_fn_called_with_outputs(self):
        """Test that loss function is called with model outputs."""
        trainer = self._make_trainer()

        trainer._training_step(self._make_batch())

        self.loss_fn.assert_called()

    # ------------------------------------------------------------------ #
    # Loops
    # ------------------------------------------------------------------ #

    def test_training_loop(self):
        """Test full training loop (simplified)."""
        trainer = self._make_trainer()

        # Plain lists of batches stand in for DataLoaders.
        train_dataloader = [self._make_batch()]
        eval_dataloader = [self._make_batch()]

        metrics = trainer.train(train_dataloader, eval_dataloader)

        assert isinstance(metrics, dict)

    def test_evaluation_loop(self):
        """Test evaluation loop."""
        trainer = self._make_trainer()

        metrics = trainer.evaluate([self._make_batch()])

        assert isinstance(metrics, dict)

    # ------------------------------------------------------------------ #
    # Validation and model mode
    # ------------------------------------------------------------------ #

    def test_config_validation(self):
        """Test that config has required keys."""
        required_keys = ["batch_size", "learning_rate", "num_epochs", "output_dir"]

        for key in required_keys:
            config = self.config.copy()
            del config[key]
            # ``match`` is a regex search; the key names are plain identifiers.
            with pytest.raises(ValueError, match=key):
                self._make_trainer(config=config)

    def test_model_mode_training(self):
        """Test that model is set to training mode."""
        trainer = self._make_trainer()

        trainer._training_step(self._make_batch())

        self.model.train.assert_called()

    def test_model_mode_evaluation(self):
        """Test that model is set to eval mode during evaluation."""
        trainer = self._make_trainer()

        trainer._evaluation_step(self._make_batch())

        self.model.eval.assert_called()


if __name__ == "__main__":
    pytest.main([__file__, "-v"])