Spaces:
Running on Zero
Running on Zero
File size: 2,569 Bytes
b701455 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import torch
import pytest
from unittest.mock import MagicMock
from src.NeuralNetwork.transformer import BasicTransformerBlock
from src.Model.ModelPatcher import ModelPatcher
def test_tome_forward_signature():
    """Check that a ToMe-patched ``BasicTransformerBlock`` accepts ``transformer_options``.

    Regression test for::

        TypeError: ToMeBlock._forward() takes from 2 to 3 positional arguments but 4 were given

    Applies ToMe through ``ModelPatcher.apply_tome`` on a minimal mock model and
    asserts the block's class was swapped to ``ToMeBlock``. It then calls
    ``block.forward`` with ``transformer_options``. Only that specific signature
    ``TypeError`` fails the test. Any other ``TypeError`` is re-raised. Non-TypeError
    exceptions from the forward pass (e.g. caused by the mock setup) are tolerated
    and printed.

    Skipped when ``tomesd`` is not installed.
    """
    # Skip before constructing anything if the optional dependency is missing.
    pytest.importorskip("tomesd")

    block = BasicTransformerBlock(dim=64, n_heads=1, d_head=64)

    # Minimal model structure for tomesd: an outer module exposing
    # ``diffusion_model``, which holds the transformer block and a dtype.
    class MockDiffusionModel(torch.nn.Module):
        def __init__(self, block):
            super().__init__()
            self.transformer_block = block
            self.dtype = torch.float32

    class MockModel(torch.nn.Module):
        def __init__(self, block):
            super().__init__()
            self.diffusion_model = MockDiffusionModel(block)

    mock_model = MockModel(block)

    # ModelPatcher expects the model to have certain attributes.
    patcher = ModelPatcher(mock_model, torch.device("cpu"), torch.device("cpu"))
    success = patcher.apply_tome(ratio=0.5)
    assert success, "Failed to apply ToMe"
    # Patching swaps the block's class in place.
    assert block.__class__.__name__ == "ToMeBlock"

    # Patching is expected to store ``_tome_info`` on ``diffusion_model``.
    # ToMeBlock._forward reads it through its own reference to the same dict.
    # "size" is normally filled in by the tomesd hook. Only the block is called
    # here, so set it manually. This is done OUTSIDE the try below so a missing
    # ``_tome_info`` fails the test instead of being swallowed by the broad
    # ``except``.
    mock_model.diffusion_model._tome_info["size"] = (64, 64)

    x = torch.randn(1, 16, 64)
    transformer_options = {"some_option": True}

    # This must NOT raise:
    #   TypeError: ToMeBlock._forward() takes from 2 to 3 positional arguments but 4 were given
    try:
        block.forward(x, transformer_options=transformer_options)
    except TypeError as e:
        if "ToMeBlock._forward() takes" in str(e):
            pytest.fail(f"ToMe fix failed: {e}")
        # Any other TypeError is not tolerated; re-raise with its original traceback.
        raise
    except Exception as e:
        # Other errors (e.g. shape mismatches caused by the mocks) are tolerated:
        # this test only checks that the signature mismatch is gone.
        print(f"Caught expected non-TypeError: {e}")
if __name__ == "__main__":
    # Allow running this check directly with ``python <file>``, without pytest.
    test_tome_forward_signature()
|