# TouchGrass-7b / tests/test_trainer.py
# Uploaded by Zandy-Wandy ("Upload 39 files", commit 4f0238f, verified)
"""
Tests for TouchGrass Trainer.
"""
import pytest
import torch
from unittest.mock import MagicMock, patch
from TouchGrass.training.trainer import TouchGrassTrainer
class TestTouchGrassTrainer:
    """Test suite for TouchGrassTrainer.

    Every collaborator (model, tokenizer, loss function, optimizer,
    scheduler) is a MagicMock stand-in, so these tests exercise the
    trainer's orchestration logic (which hooks get called, which config
    values get stored) rather than real training math.
    """

    def setup_method(self):
        """Set up fresh mock fixtures before every test."""
        self.device = "cpu"
        self.d_model = 768
        self.vocab_size = 32000

        # Mock model with one trainable parameter so optimizer setup /
        # gradient clipping have something real to iterate over.
        self.model = MagicMock()
        self.model.parameters.return_value = [torch.randn(10, requires_grad=True)]

        # Mock tokenizer.
        self.tokenizer = MagicMock()
        self.tokenizer.pad_token_id = 0

        # Mock loss function.  The loss tensor is created with
        # requires_grad=True so a trainer implementation that calls
        # loss.backward() does not fail on a leaf tensor without grad.
        self.loss_fn = MagicMock()
        self.loss_fn.return_value = {"total_loss": torch.tensor(0.5, requires_grad=True)}

        # Mock optimizer with explicitly mocked step/zero_grad so the
        # assert_called checks below work.
        self.optimizer = MagicMock()
        self.optimizer.step = MagicMock()
        self.optimizer.zero_grad = MagicMock()

        # Mock scheduler.
        self.scheduler = MagicMock()
        self.scheduler.step = MagicMock()

        # Baseline trainer config shared by most tests; individual tests
        # copy and override it.
        self.config = {
            "batch_size": 4,
            "gradient_accumulation_steps": 1,
            "learning_rate": 2e-4,
            "max_grad_norm": 1.0,
            "num_epochs": 1,
            "save_steps": 100,
            "eval_steps": 50,
            "output_dir": "test_output",
        }

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _make_trainer(self, config=None, scheduler=None, device=None):
        """Build a TouchGrassTrainer from the mock fixtures.

        Centralizes the construction boilerplate previously repeated in
        every test.  ``scheduler`` is only forwarded when given, so the
        original mix of with- and without-scheduler constructions is
        preserved.
        """
        kwargs = {
            "model": self.model,
            "tokenizer": self.tokenizer,
            "loss_fn": self.loss_fn,
            "optimizer": self.optimizer,
            "config": self.config if config is None else config,
            "device": self.device if device is None else device,
        }
        if scheduler is not None:
            kwargs["scheduler"] = scheduler
        return TouchGrassTrainer(**kwargs)

    def _make_batch(self):
        """Return a random (4, 10) token batch with a full attention mask."""
        return {
            "input_ids": torch.randint(0, self.vocab_size, (4, 10)),
            "attention_mask": torch.ones(4, 10),
            "labels": torch.randint(0, self.vocab_size, (4, 10)),
        }

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------
    def test_trainer_initialization(self):
        """Test that the trainer stores every component it is given."""
        trainer = self._make_trainer(scheduler=self.scheduler)
        assert trainer.model == self.model
        assert trainer.tokenizer == self.tokenizer
        assert trainer.loss_fn == self.loss_fn
        assert trainer.optimizer == self.optimizer
        assert trainer.scheduler == self.scheduler
        assert trainer.config == self.config

    def test_trainer_required_components(self):
        """Test that the public train/eval/checkpoint API is present."""
        trainer = self._make_trainer()
        assert hasattr(trainer, "train")
        assert hasattr(trainer, "evaluate")
        assert hasattr(trainer, "save_checkpoint")
        assert hasattr(trainer, "load_checkpoint")

    def test_prepare_batch(self):
        """Test that batch preparation keeps all expected keys."""
        trainer = self._make_trainer()
        prepared = trainer._prepare_batch(self._make_batch())
        assert "input_ids" in prepared
        assert "attention_mask" in prepared
        assert "labels" in prepared

    def test_training_step(self):
        """Test that a single training step produces a loss value."""
        trainer = self._make_trainer()
        loss = trainer._training_step(self._make_batch())
        # NOTE: the previous compound assertion
        # `isinstance(loss, torch.Tensor) or loss is not None` reduced to
        # `loss is not None` (a Tensor is never None); state that directly.
        assert loss is not None

    def test_evaluation_step(self):
        """Test that a single evaluation step returns a metrics dict."""
        trainer = self._make_trainer()
        metrics = trainer._evaluation_step(self._make_batch())
        assert isinstance(metrics, dict)

    def test_gradient_accumulation(self):
        """Test that the accumulation step count is read from config."""
        config = self.config.copy()
        config["gradient_accumulation_steps"] = 2
        trainer = self._make_trainer(config=config)
        assert trainer.gradient_accumulation_steps == 2

    def test_checkpoint_saving(self, tmp_path):
        """Test that checkpoint saving runs without error."""
        config = self.config.copy()
        config["output_dir"] = str(tmp_path / "checkpoints")
        trainer = self._make_trainer(config=config)
        trainer.save_checkpoint(step=100)
        # Should create checkpoint files
        # (actual file creation would depend on implementation)

    def test_learning_rate_scheduler_step(self):
        """Test that a training step runs when a scheduler is attached."""
        config = self.config.copy()
        config["learning_rate"] = 1e-3
        trainer = self._make_trainer(config=config, scheduler=self.scheduler)
        # After training step, scheduler should be called
        trainer._training_step(self._make_batch())
        # Scheduler step should be called (depending on implementation)
        # This is a simple check - actual behavior may vary

    def test_gradient_clipping(self):
        """Test that the max gradient norm is read from config."""
        config = self.config.copy()
        config["max_grad_norm"] = 1.0
        trainer = self._make_trainer(config=config)
        assert trainer.max_grad_norm == 1.0

    def test_mixed_precision_flag(self):
        """Test that the mixed precision flag is read from config."""
        config = self.config.copy()
        config["mixed_precision"] = True
        trainer = self._make_trainer(config=config)
        assert trainer.mixed_precision is True

    def test_device_assignment(self):
        """Test that the requested device is stored on the trainer."""
        trainer = self._make_trainer(device="cpu")
        assert trainer.device == "cpu"

    def test_optimizer_zero_grad_called(self):
        """Test that optimizer.zero_grad is called during a training step."""
        trainer = self._make_trainer()
        trainer._training_step(self._make_batch())
        self.optimizer.zero_grad.assert_called()

    def test_optimizer_step_called(self):
        """Test that optimizer.step is called during a training step."""
        trainer = self._make_trainer()
        trainer._training_step(self._make_batch())
        self.optimizer.step.assert_called()

    def test_loss_fn_called_with_outputs(self):
        """Test that the loss function is invoked during a training step."""
        trainer = self._make_trainer()
        trainer._training_step(self._make_batch())
        # Loss function should be called
        self.loss_fn.assert_called()

    def test_training_loop(self):
        """Test a full (single-batch) training loop returns metrics."""
        trainer = self._make_trainer()
        # One-element lists stand in for DataLoaders.
        train_dataloader = [self._make_batch()]
        eval_dataloader = [self._make_batch()]
        # Run a single epoch (with mocked data)
        metrics = trainer.train(train_dataloader, eval_dataloader)
        assert isinstance(metrics, dict)

    def test_evaluation_loop(self):
        """Test that the evaluation loop returns a metrics dict."""
        trainer = self._make_trainer()
        eval_dataloader = [self._make_batch()]
        metrics = trainer.evaluate(eval_dataloader)
        assert isinstance(metrics, dict)

    def test_config_validation(self):
        """Test that omitting any required config key raises ValueError."""
        required_keys = ["batch_size", "learning_rate", "num_epochs", "output_dir"]
        for key in required_keys:
            config = self.config.copy()
            del config[key]
            # The error message is expected to name the missing key.
            with pytest.raises(ValueError, match=key):
                self._make_trainer(config=config)

    def test_model_mode_training(self):
        """Test that the model is put into training mode."""
        trainer = self._make_trainer()
        trainer._training_step(self._make_batch())
        self.model.train.assert_called()

    def test_model_mode_evaluation(self):
        """Test that the model is put into eval mode during evaluation."""
        trainer = self._make_trainer()
        trainer._evaluation_step(self._make_batch())
        self.model.eval.assert_called()
# Allow running this test module directly (e.g. `python test_trainer.py`)
# instead of via a `pytest` invocation; -v enables verbose per-test output.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])