# Tests for trlx utility helpers (optimizer/scheduler getters, HF attribute
# getters, and running-moment statistics).
import unittest

import accelerate
import pytest
import torch
import transformers

import trlx.utils as utils
import trlx.utils.modeling as modeling_utils

try:
    import bitsandbytes

    HAS_BNB = True
except ImportError:
    HAS_BNB = False
| # Test general utils | |
# Test general utils
# NOTE(review): the original paste had no parametrize decorator even though the
# function takes `optimizer_name`, which makes pytest error with a missing
# fixture. "adamw" is guaranteed valid — the same function hard-checks it below.
@pytest.mark.parametrize("optimizer_name", ["adamw"])
def test_optimizer_class_getters(optimizer_name: str):
    """Check that an optimizer name string resolves to an optimizer class."""
    try:
        _class = utils.get_optimizer_class(optimizer_name)
    except Exception as e:
        # pytest.fail instead of `assert False`: asserts are stripped under -O.
        pytest.fail("Failed to get optimizer class with error: " + str(e))

    # Hard-check for one of the optimizers
    _class = utils.get_optimizer_class("adamw")
    assert _class == torch.optim.AdamW

    # bitsandbytes-backed optimizers are only registered when bnb is installed.
    if HAS_BNB:
        _bnb_class = utils.get_optimizer_class("adamw_8bit_bnb")
        assert _bnb_class == bitsandbytes.optim.AdamW8bit
# NOTE(review): restored the lost parametrize decorator (the bare parameter
# would otherwise be treated as a missing pytest fixture). "cosine_annealing"
# is guaranteed valid — the same function hard-checks it below.
@pytest.mark.parametrize("scheduler_name", ["cosine_annealing"])
def test_scheduler_class_getters(scheduler_name: str):
    """Check that a scheduler name string resolves to an LR-scheduler class."""
    try:
        _class = utils.get_scheduler_class(scheduler_name)
    except Exception as e:
        # pytest.fail instead of `assert False`: asserts are stripped under -O.
        pytest.fail("Failed to get scheduler class with error: " + str(e))

    # Hard-check for one of the schedulers
    _class = utils.get_scheduler_class("cosine_annealing")
    assert _class == torch.optim.lr_scheduler.CosineAnnealingLR
| # Test modeling utils | |
# Test modeling utils
# NOTE(review): restored a parametrize decorator for `model_name`; without it
# pytest reports a missing fixture. "gpt2" is a small, standard hub config —
# confirm against the project's original parametrized model list.
@pytest.mark.parametrize("model_name", ["gpt2"])
def test_hf_attr_getters(model_name: str):
    """Check that the HF architecture/config attribute getters run cleanly.

    Builds an empty-weight model from the pretrained config (no checkpoint
    download) and verifies every getter can locate its attribute.
    """
    # init_empty_weights avoids allocating real parameters for the test.
    with accelerate.init_empty_weights():
        config = transformers.AutoConfig.from_pretrained(model_name)
        arch = transformers.AutoModelForCausalLM.from_config(config)

    # Getters that operate on the model object.
    arch_getters = [
        modeling_utils.hf_get_decoder,
        modeling_utils.hf_get_decoder_final_norm,
        modeling_utils.hf_get_decoder_blocks,
        modeling_utils.hf_get_lm_head,
    ]
    for get in arch_getters:
        try:
            get(arch)
        except Exception as e:
            pytest.fail("Failed to get model attribute with error: " + str(e))

    # Getters that operate on the config object.
    config_getters = [
        modeling_utils.hf_get_hidden_size,
        modeling_utils.hf_get_num_hidden_layers,
    ]
    for get in config_getters:
        try:
            get(config)
        except Exception as e:
            pytest.fail("Failed to get config attribute with error: " + str(e))
class TestStatistics(unittest.TestCase):
    """Tests for the streaming mean/std tracker `RunningMoments`."""

    # BUG FIX: setUpClass must be a classmethod — unittest invokes it as
    # cls.setUpClass(); without the decorator the plain function is called
    # with no arguments and raises TypeError before any test runs.
    @classmethod
    def setUpClass(cls):
        # One shared tracker plus sample batches with varied distributions.
        cls.m = modeling_utils.RunningMoments()
        cls.a1 = torch.arange(100, dtype=float)
        cls.a2 = torch.ones(100, dtype=float)
        cls.a3 = torch.exp(torch.arange(10, dtype=float))
        cls.a4 = torch.tensor([-10, -1, 0, 1, 10], dtype=float)

    def test_running_moments(self):
        # Per-batch: update() is expected to return the batch's own statistics
        # at index 1 (std) — compare against torch's unbiased std.
        assert torch.isclose(self.m.update(self.a1)[1], self.a1.std(unbiased=True), atol=1e-6)
        assert torch.isclose(self.m.update(self.a2)[1], self.a2.std(unbiased=True), atol=1e-6)
        assert torch.isclose(self.m.update(self.a3)[1], self.a3.std(unbiased=True), atol=1e-6)
        assert torch.isclose(self.m.update(self.a4)[1], self.a4.std(unbiased=True), atol=1e-6)

        # Aggregate: the tracker's accumulated moments must match the
        # statistics of all batches concatenated together.
        a = torch.hstack((self.a1, self.a2, self.a3, self.a4))
        assert torch.isclose(self.m.mean, a.mean(), atol=1e-6)
        assert torch.isclose(self.m.std, a.std(unbiased=True), atol=1e-6)