File size: 2,569 Bytes
b701455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68

import torch
import pytest
from unittest.mock import MagicMock
from src.NeuralNetwork.transformer import BasicTransformerBlock
from src.Model.ModelPatcher import ModelPatcher

def test_tome_forward_signature():
    """Regression test: a ToMe-patched block must accept ``transformer_options``.

    ``ModelPatcher.apply_tome`` is expected to turn ``BasicTransformerBlock``
    into tomesd's ``ToMeBlock``. Calling
    ``block.forward(x, transformer_options=...)`` on the patched block used to
    raise ``TypeError: ToMeBlock._forward() takes from 2 to 3 positional
    arguments but 4 were given``. This test fails if that signature mismatch
    comes back. Other runtime errors caused by the minimal mock model are
    tolerated, because only the call signature is under test.

    Skipped when ``tomesd`` is not installed.
    """
    # Skip before building anything: without tomesd the test cannot run at all.
    pytest.importorskip("tomesd")

    block = BasicTransformerBlock(dim=64, n_heads=1, d_head=64)

    # Minimal module tree: model.diffusion_model holds the block and exposes a
    # dtype, which is the structure the patcher/tomesd are expected to walk.
    class MockDiffusionModel(torch.nn.Module):
        def __init__(self, block):
            super().__init__()
            self.transformer_block = block
            self.dtype = torch.float32

    class MockModel(torch.nn.Module):
        def __init__(self, block):
            super().__init__()
            self.diffusion_model = MockDiffusionModel(block)

    mock_model = MockModel(block)

    # Both devices are CPU so the test runs without a GPU.
    patcher = ModelPatcher(mock_model, torch.device("cpu"), torch.device("cpu"))

    success = patcher.apply_tome(ratio=0.5)
    assert success, "Failed to apply ToMe"

    # The same block object should now report ToMeBlock as its class
    # (i.e. it was patched in place).
    assert block.__class__.__name__ == "ToMeBlock"

    x = torch.randn(1, 16, 64)
    transformer_options = {"some_option": True}

    # _tome_info is presumably attached to the diffusion model by tomesd's
    # apply_patch, and its "size" entry is normally filled by a forward hook.
    # We call the block directly, so set it by hand. This is deliberately done
    # OUTSIDE the try below: if _tome_info is missing, the test must fail
    # loudly instead of the error being swallowed and the test passing
    # without ever calling forward().
    mock_model.diffusion_model._tome_info["size"] = (64, 64)

    # Keep the try body to the single call under test.
    try:
        block.forward(x, transformer_options=transformer_options)
    except TypeError as e:
        # Match on the method name only: pre-3.10 interpreters report
        # "_forward() takes ..." without the "ToMeBlock." qualname prefix.
        if "_forward() takes" in str(e):
            pytest.fail(f"ToMe fix failed: {e}")
        # Any other TypeError is not the regression we guard, but it is still
        # unexpected; re-raise it with its original traceback.
        raise
    except Exception as e:
        # Other failures (e.g. shape mismatches from the mock model) are
        # tolerated: this test only guards the forward() call signature.
        print(f"Caught expected non-TypeError: {e}")

if __name__ == "__main__":
    # Run through pytest rather than calling the test directly, so that
    # pytest.skip / pytest.fail are reported as test outcomes instead of
    # escaping as raw exceptions; propagate pytest's exit code.
    raise SystemExit(pytest.main([__file__]))