| | """
|
| | Tests for TouchGrass Trainer.
|
| | """
|
| |
|
| | import pytest
|
| | import torch
|
| | from unittest.mock import MagicMock, patch
|
| |
|
| | from TouchGrass.training.trainer import TouchGrassTrainer
|
| |
|
| |
|
class TestTouchGrassTrainer:
    """Test suite for TouchGrassTrainer."""

    def setup_method(self):
        """Set up test fixtures."""
        self.device = "cpu"
        self.d_model = 768
        self.vocab_size = 32000

        self.model = MagicMock()
        self.model.parameters.return_value = [torch.randn(10, requires_grad=True)]

        self.tokenizer = MagicMock()
        self.tokenizer.pad_token_id = 0

        self.loss_fn = MagicMock()
        self.loss_fn.return_value = {"total_loss": torch.tensor(0.5)}

        self.optimizer = MagicMock()
        self.optimizer.step = MagicMock()
        self.optimizer.zero_grad = MagicMock()

        self.scheduler = MagicMock()
        self.scheduler.step = MagicMock()

        self.config = {
            "batch_size": 4,
            "gradient_accumulation_steps": 1,
            "learning_rate": 2e-4,
            "max_grad_norm": 1.0,
            "num_epochs": 1,
            "save_steps": 100,
            "eval_steps": 50,
            "output_dir": "test_output",
        }
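
    # Convenience helper (illustrative sketch; the tests below build their batches inline):
    # documents the batch schema the trainer is assumed to consume
    # (input_ids, attention_mask, labels).
    def _make_batch(self, batch_size: int = 4, seq_len: int = 10) -> dict:
        """Build a random dummy batch with the keys the trainer is expected to consume."""
        return {
            "input_ids": torch.randint(0, self.vocab_size, (batch_size, seq_len)),
            "attention_mask": torch.ones(batch_size, seq_len),
            "labels": torch.randint(0, self.vocab_size, (batch_size, seq_len)),
        }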

    def test_trainer_initialization(self):
        """Test trainer initialization."""
        trainer = TouchGrassTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            loss_fn=self.loss_fn,
            optimizer=self.optimizer,
            scheduler=self.scheduler,
            config=self.config,
            device=self.device,
        )

        assert trainer.model == self.model
        assert trainer.tokenizer == self.tokenizer
        assert trainer.loss_fn == self.loss_fn
        assert trainer.optimizer == self.optimizer
        assert trainer.scheduler == self.scheduler
        assert trainer.config == self.config

    def test_trainer_required_components(self):
        """Test that all required components are present."""
        trainer = TouchGrassTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            loss_fn=self.loss_fn,
            optimizer=self.optimizer,
            config=self.config,
            device=self.device,
        )

        assert hasattr(trainer, "train")
        assert hasattr(trainer, "evaluate")
        assert hasattr(trainer, "save_checkpoint")
        assert hasattr(trainer, "load_checkpoint")

    def test_prepare_batch(self):
        """Test batch preparation."""
        trainer = TouchGrassTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            loss_fn=self.loss_fn,
            optimizer=self.optimizer,
            config=self.config,
            device=self.device,
        )

        batch = {
            "input_ids": torch.randint(0, self.vocab_size, (4, 10)),
            "attention_mask": torch.ones(4, 10),
            "labels": torch.randint(0, self.vocab_size, (4, 10)),
        }

        prepared = trainer._prepare_batch(batch)
        assert "input_ids" in prepared
        assert "attention_mask" in prepared
        assert "labels" in prepared
    def test_training_step(self):
        """Test single training step."""
        trainer = TouchGrassTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            loss_fn=self.loss_fn,
            optimizer=self.optimizer,
            config=self.config,
            device=self.device,
        )

        batch = {
            "input_ids": torch.randint(0, self.vocab_size, (4, 10)),
            "attention_mask": torch.ones(4, 10),
            "labels": torch.randint(0, self.vocab_size, (4, 10)),
        }

        loss = trainer._training_step(batch)
        # The step should produce a loss value; the exact type (tensor or float) is not enforced.
        assert loss is not None
    def test_evaluation_step(self):
        """Test single evaluation step."""
        trainer = TouchGrassTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            loss_fn=self.loss_fn,
            optimizer=self.optimizer,
            config=self.config,
            device=self.device,
        )

        batch = {
            "input_ids": torch.randint(0, self.vocab_size, (4, 10)),
            "attention_mask": torch.ones(4, 10),
            "labels": torch.randint(0, self.vocab_size, (4, 10)),
        }

        metrics = trainer._evaluation_step(batch)
        assert isinstance(metrics, dict)

    def test_gradient_accumulation(self):
        """Test gradient accumulation."""
        config = self.config.copy()
        config["gradient_accumulation_steps"] = 2

        trainer = TouchGrassTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            loss_fn=self.loss_fn,
            optimizer=self.optimizer,
            config=config,
            device=self.device,
        )

        assert trainer.gradient_accumulation_steps == 2
    def test_checkpoint_saving(self, tmp_path):
        """Test checkpoint saving."""
        config = self.config.copy()
        config["output_dir"] = str(tmp_path / "checkpoints")

        trainer = TouchGrassTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            loss_fn=self.loss_fn,
            optimizer=self.optimizer,
            config=config,
            device=self.device,
        )

        # Only verifies that saving completes without raising; checkpoint contents
        # are not inspected here.
        trainer.save_checkpoint(step=100)
    def test_learning_rate_scheduler_step(self):
        """Test that scheduler is stepped correctly."""
        config = self.config.copy()
        config["learning_rate"] = 1e-3

        trainer = TouchGrassTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            loss_fn=self.loss_fn,
            optimizer=self.optimizer,
            scheduler=self.scheduler,
            config=config,
            device=self.device,
        )

        batch = {
            "input_ids": torch.randint(0, self.vocab_size, (4, 10)),
            "attention_mask": torch.ones(4, 10),
            "labels": torch.randint(0, self.vocab_size, (4, 10)),
        }
        trainer._training_step(batch)

        # Assumes the trainer advances the LR scheduler as part of each training step,
        # mirroring the optimizer.step checks below.
        self.scheduler.step.assert_called()
    def test_gradient_clipping(self):
        """Test gradient clipping."""
        config = self.config.copy()
        config["max_grad_norm"] = 1.0

        trainer = TouchGrassTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            loss_fn=self.loss_fn,
            optimizer=self.optimizer,
            config=config,
            device=self.device,
        )

        assert trainer.max_grad_norm == 1.0

    def test_mixed_precision_flag(self):
        """Test mixed precision training flag."""
        config = self.config.copy()
        config["mixed_precision"] = True

        trainer = TouchGrassTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            loss_fn=self.loss_fn,
            optimizer=self.optimizer,
            config=config,
            device=self.device,
        )

        assert trainer.mixed_precision is True

    def test_device_assignment(self):
        """Test that model and data are moved to correct device."""
        trainer = TouchGrassTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            loss_fn=self.loss_fn,
            optimizer=self.optimizer,
            config=self.config,
            device="cpu",
        )

        assert trainer.device == "cpu"

    def test_optimizer_zero_grad_called(self):
        """Test that optimizer.zero_grad is called."""
        trainer = TouchGrassTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            loss_fn=self.loss_fn,
            optimizer=self.optimizer,
            config=self.config,
            device=self.device,
        )

        batch = {
            "input_ids": torch.randint(0, self.vocab_size, (4, 10)),
            "attention_mask": torch.ones(4, 10),
            "labels": torch.randint(0, self.vocab_size, (4, 10)),
        }
        trainer._training_step(batch)

        self.optimizer.zero_grad.assert_called()

    def test_optimizer_step_called(self):
        """Test that optimizer.step is called."""
        trainer = TouchGrassTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            loss_fn=self.loss_fn,
            optimizer=self.optimizer,
            config=self.config,
            device=self.device,
        )

        batch = {
            "input_ids": torch.randint(0, self.vocab_size, (4, 10)),
            "attention_mask": torch.ones(4, 10),
            "labels": torch.randint(0, self.vocab_size, (4, 10)),
        }
        trainer._training_step(batch)

        self.optimizer.step.assert_called()

    def test_loss_fn_called_with_outputs(self):
        """Test that loss function is called with model outputs."""
        trainer = TouchGrassTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            loss_fn=self.loss_fn,
            optimizer=self.optimizer,
            config=self.config,
            device=self.device,
        )

        batch = {
            "input_ids": torch.randint(0, self.vocab_size, (4, 10)),
            "attention_mask": torch.ones(4, 10),
            "labels": torch.randint(0, self.vocab_size, (4, 10)),
        }
        trainer._training_step(batch)

        self.loss_fn.assert_called()
    def test_training_loop(self):
        """Test full training loop (simplified)."""
        trainer = TouchGrassTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            loss_fn=self.loss_fn,
            optimizer=self.optimizer,
            config=self.config,
            device=self.device,
        )

        train_dataloader = [{
            "input_ids": torch.randint(0, self.vocab_size, (4, 10)),
            "attention_mask": torch.ones(4, 10),
            "labels": torch.randint(0, self.vocab_size, (4, 10)),
        }]
        eval_dataloader = [{
            "input_ids": torch.randint(0, self.vocab_size, (4, 10)),
            "attention_mask": torch.ones(4, 10),
            "labels": torch.randint(0, self.vocab_size, (4, 10)),
        }]

        metrics = trainer.train(train_dataloader, eval_dataloader)
        assert isinstance(metrics, dict)

    def test_evaluation_loop(self):
        """Test evaluation loop."""
        trainer = TouchGrassTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            loss_fn=self.loss_fn,
            optimizer=self.optimizer,
            config=self.config,
            device=self.device,
        )

        eval_dataloader = [{
            "input_ids": torch.randint(0, self.vocab_size, (4, 10)),
            "attention_mask": torch.ones(4, 10),
            "labels": torch.randint(0, self.vocab_size, (4, 10)),
        }]

        metrics = trainer.evaluate(eval_dataloader)
        assert isinstance(metrics, dict)
    def test_config_validation(self):
        """Test that config has required keys."""
        required_keys = ["batch_size", "learning_rate", "num_epochs", "output_dir"]

        for key in required_keys:
            config = self.config.copy()
            del config[key]
            with pytest.raises(ValueError, match=key):
                TouchGrassTrainer(
                    model=self.model,
                    tokenizer=self.tokenizer,
                    loss_fn=self.loss_fn,
                    optimizer=self.optimizer,
                    config=config,
                    device=self.device,
                )
    def test_model_mode_training(self):
        """Test that model is set to training mode."""
        trainer = TouchGrassTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            loss_fn=self.loss_fn,
            optimizer=self.optimizer,
            config=self.config,
            device=self.device,
        )

        batch = {
            "input_ids": torch.randint(0, self.vocab_size, (4, 10)),
            "attention_mask": torch.ones(4, 10),
            "labels": torch.randint(0, self.vocab_size, (4, 10)),
        }
        trainer._training_step(batch)

        self.model.train.assert_called()

    def test_model_mode_evaluation(self):
        """Test that model is set to eval mode during evaluation."""
        trainer = TouchGrassTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            loss_fn=self.loss_fn,
            optimizer=self.optimizer,
            config=self.config,
            device=self.device,
        )

        batch = {
            "input_ids": torch.randint(0, self.vocab_size, (4, 10)),
            "attention_mask": torch.ones(4, 10),
            "labels": torch.randint(0, self.vocab_size, (4, 10)),
        }
        trainer._evaluation_step(batch)

        self.model.eval.assert_called()
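

# Integration-style smoke sketch (illustrative only, hence skipped): exercises the trainer
# with a tiny real model and optimizer instead of MagicMocks. It relies only on the
# constructor signature and the {"total_loss": tensor} loss contract already used by the
# mocked tests above; enable and adjust once the real trainer API is confirmed.
@pytest.mark.skip(reason="Illustrative sketch; depends on assumptions about the trainer API.")
def test_integration_smoke_sketch(tmp_path):
    model = torch.nn.Linear(8, 8)

    tokenizer = MagicMock()
    tokenizer.pad_token_id = 0

    def loss_fn(*args, **kwargs):
        # Differentiable dummy loss so backward() and optimizer.step() have real work to do.
        return {"total_loss": sum(p.pow(2).sum() for p in model.parameters())}

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    config = {
        "batch_size": 2,
        "gradient_accumulation_steps": 1,
        "learning_rate": 1e-3,
        "max_grad_norm": 1.0,
        "num_epochs": 1,
        "save_steps": 100,
        "eval_steps": 50,
        "output_dir": str(tmp_path / "checkpoints"),
    }

    trainer = TouchGrassTrainer(
        model=model,
        tokenizer=tokenizer,
        loss_fn=loss_fn,
        optimizer=optimizer,
        config=config,
        device="cpu",
    )

    batch = {
        "input_ids": torch.randint(0, 32000, (2, 10)),
        "attention_mask": torch.ones(2, 10),
        "labels": torch.randint(0, 32000, (2, 10)),
    }
    loss = trainer._training_step(batch)
    assert loss is not None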


if __name__ == "__main__":
    pytest.main([__file__, "-v"])