Spaces:
Running on Zero
Running on Zero
| import torch | |
| import pytest | |
| from unittest.mock import MagicMock | |
| from src.NeuralNetwork.transformer import BasicTransformerBlock | |
| from src.Model.ModelPatcher import ModelPatcher | |
def test_tome_forward_signature():
    """Regression test: a ToMe-patched block must accept ``transformer_options``.

    After ``ModelPatcher.apply_tome`` turns the block into a ``ToMeBlock``,
    calling ``block.forward(x, transformer_options=...)`` used to raise
    ``TypeError: ToMeBlock._forward() takes from 2 to 3 positional arguments
    but 4 were given``. This test applies ToMe through the patcher and calls
    the patched block to make sure that signature error is gone.

    Any other exception raised by the forward call is tolerated and printed
    (the mock model is minimal and may trip later checks). The exceptions are
    TypeErrors, which are re-raised, so only non-TypeError failures are
    ignored.
    """
    # Skip up front, before building anything, when the optional dependency
    # is missing. importorskip also avoids an otherwise-unused `tomesd` name.
    pytest.importorskip("tomesd")

    block = BasicTransformerBlock(dim=64, n_heads=1, d_head=64)

    # Minimal stand-in for the structure ToMe patching expects: a wrapper
    # exposing `diffusion_model`, which owns the transformer block.
    class MockDiffusionModel(torch.nn.Module):
        def __init__(self, block):
            super().__init__()
            self.transformer_block = block
            self.dtype = torch.float32

    class MockModel(torch.nn.Module):
        def __init__(self, block):
            super().__init__()
            self.diffusion_model = MockDiffusionModel(block)

    mock_model = MockModel(block)
    patcher = ModelPatcher(mock_model, torch.device("cpu"), torch.device("cpu"))

    assert patcher.apply_tome(ratio=0.5), "Failed to apply ToMe"
    # Patching swaps the class of the existing block object in place.
    assert block.__class__.__name__ == "ToMeBlock"

    # `_tome_info` is shared between the diffusion model and the patched
    # block. Its "size" entry is normally filled in by the tomesd hook, but
    # this test never runs the diffusion model, so set it by hand. This is
    # kept outside the tolerant try-block below so that a missing
    # `_tome_info` fails the test instead of letting it pass vacuously.
    mock_model.diffusion_model._tome_info["size"] = (64, 64)

    x = torch.randn(1, 16, 64)
    transformer_options = {"some_option": True}

    try:
        block.forward(x, transformer_options=transformer_options)
    except TypeError as e:
        # Match "_forward() takes" rather than "ToMeBlock._forward() takes":
        # only Python 3.10+ includes the qualified name in this message.
        if "_forward() takes" in str(e):
            pytest.fail(f"ToMe fix failed: {e}")
        # Any other TypeError is not tolerated: re-raise it so the test fails.
        raise
    except Exception as e:
        # Other errors (e.g. shape mismatches from the mock setup) are out of
        # scope; only the signature mismatch is under test.
        print(f"Caught expected non-TypeError: {e}")
if __name__ == "__main__":
    # Allow running the test directly, without pytest's collector. Outside
    # pytest, a skip (e.g. tomesd not installed) would otherwise escape as an
    # uncaught exception, so report it instead.
    try:
        test_tome_forward_signature()
    except pytest.skip.Exception as exc:
        print(f"Skipped: {exc}")