|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import gc |
|
|
import unittest |
|
|
|
|
|
import torch |
|
|
|
|
|
from diffusers.hooks import HookRegistry, ModelHook |
|
|
from diffusers.training_utils import free_memory |
|
|
from diffusers.utils.logging import get_logger |
|
|
from diffusers.utils.testing_utils import CaptureLogger, torch_device |
|
|
|
|
|
|
|
|
# Module-level logger. The hook classes in this file emit their debug messages
# through it, and the invocation-order tests capture those messages.
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
class DummyBlock(torch.nn.Module):
    """Two linear projections with a ReLU in between.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Output size of the first projection.
        out_features: Output size of the second projection.
    """

    def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
        super().__init__()

        # Creation order is kept as-is: it fixes both the registration order
        # of the submodules and the order in which their weights are drawn.
        self.proj_in = torch.nn.Linear(in_features, hidden_features)
        self.activation = torch.nn.ReLU()
        self.proj_out = torch.nn.Linear(hidden_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``proj_in``, then the activation, then ``proj_out``."""
        return self.proj_out(self.activation(self.proj_in(x)))
|
|
|
|
|
|
|
|
class DummyModel(torch.nn.Module):
    """Small model: input linear + ReLU, a stack of ``DummyBlock``s, output linear.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width used by the first linear layer and every block.
        out_features: Output size of the final linear layer.
        num_layers: Number of ``DummyBlock`` instances in ``blocks``.
    """

    def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
        super().__init__()

        self.linear_1 = torch.nn.Linear(in_features, hidden_features)
        self.activation = torch.nn.ReLU()

        # Blocks are created after ``linear_1`` and before ``linear_2`` so the
        # parameter initialisation order stays the same.
        self.blocks = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.blocks.append(DummyBlock(hidden_features, hidden_features, hidden_features))

        self.linear_2 = torch.nn.Linear(hidden_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the input through ``linear_1``, the activation, every block in order, then ``linear_2``."""
        hidden = self.activation(self.linear_1(x))
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.linear_2(hidden)
|
|
|
|
|
|
|
|
class AddHook(ModelHook):
    """Test hook that adds a constant to every positional tensor argument.

    No ``__repr__`` is defined on purpose: ``test_hook_registry`` in this file
    expects the registry to render this hook by its bare class name.

    Args:
        value: Amount added to each positional tensor argument.
    """

    def __init__(self, value: int):
        super().__init__()
        self.value = value

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
        """Add ``self.value`` to each tensor in ``args``.

        Non-tensor positional arguments and all keyword arguments are passed
        through unchanged.

        Returns:
            ``(args, kwargs)`` where ``args`` is a tuple.
        """
        logger.debug("AddHook pre_forward")
        # Materialise as a tuple rather than a one-shot generator, so the
        # result can be unpacked, indexed or inspected more than once.
        args = tuple((x + self.value) if torch.is_tensor(x) else x for x in args)
        return args, kwargs

    def post_forward(self, module, output):
        """Log the call and return ``output`` unchanged."""
        logger.debug("AddHook post_forward")
        return output
|
|
|
|
|
|
|
|
class MultiplyHook(ModelHook):
    """Test hook that multiplies every positional tensor argument by a constant.

    Args:
        value: Factor applied to each positional tensor argument.
    """

    def __init__(self, value: int):
        super().__init__()
        self.value = value

    def pre_forward(self, module, *args, **kwargs):
        """Multiply each tensor in ``args`` by ``self.value``.

        Non-tensor positional arguments and all keyword arguments are passed
        through unchanged.

        Returns:
            ``(args, kwargs)`` where ``args`` is a tuple.
        """
        logger.debug("MultiplyHook pre_forward")
        # Materialise as a tuple rather than a one-shot generator, so the
        # result can be unpacked, indexed or inspected more than once.
        args = tuple((x * self.value) if torch.is_tensor(x) else x for x in args)
        return args, kwargs

    def post_forward(self, module, output):
        """Log the call and return ``output`` unchanged."""
        logger.debug("MultiplyHook post_forward")
        return output

    def __repr__(self):
        """Return ``MultiplyHook(value=<value>)``; ``test_hook_registry`` checks this text."""
        return f"MultiplyHook(value={self.value})"
|
|
|
|
|
|
|
|
class StatefulAddHook(ModelHook):
    """Test hook whose added amount grows by one on every forward call.

    The n-th call since construction or the last reset (counting from zero)
    adds ``value + n`` to each positional tensor argument.

    Args:
        value: Base amount added to each positional tensor argument.
    """

    # Marks the hook as stateful; ``test_stateful_hook`` relies on
    # ``registry.reset_stateful_hooks()`` resetting ``increment`` for it.
    _is_stateful = True

    def __init__(self, value: int):
        super().__init__()
        self.value = value
        # Number of ``pre_forward`` calls since construction or the last reset.
        self.increment = 0

    def pre_forward(self, module, *args, **kwargs):
        """Add ``value + increment`` to each tensor in ``args``, then bump ``increment``.

        Non-tensor positional arguments and all keyword arguments are passed
        through unchanged.

        Returns:
            ``(args, kwargs)`` where ``args`` is a tuple.
        """
        logger.debug("StatefulAddHook pre_forward")
        # Capture the amount before incrementing so this call uses the old count.
        add_value = self.value + self.increment
        self.increment += 1
        # Materialise as a tuple rather than a one-shot generator, so the
        # result can be unpacked, indexed or inspected more than once.
        args = tuple((x + add_value) if torch.is_tensor(x) else x for x in args)
        return args, kwargs

    def reset_state(self, module):
        """Reset the call counter to zero."""
        self.increment = 0
|
|
|
|
|
|
|
|
class SkipLayerHook(ModelHook):
    """Test hook that can bypass the wrapped module's forward pass.

    Args:
        skip_layer: When true, ``new_forward`` returns its first positional
            argument instead of calling the original forward.
    """

    def __init__(self, skip_layer: bool):
        super().__init__()
        self.skip_layer = skip_layer

    def pre_forward(self, module, *args, **kwargs):
        """Log the call and pass ``args`` and ``kwargs`` through untouched."""
        logger.debug("SkipLayerHook pre_forward")
        return args, kwargs

    def new_forward(self, module, *args, **kwargs):
        """Call the original forward, or return the first positional argument when skipping."""
        logger.debug("SkipLayerHook new_forward")
        if not self.skip_layer:
            # ``fn_ref`` is not set in this class; presumably the registry
            # attaches it when the hook is registered (verify against HookRegistry).
            return self.fn_ref.original_forward(*args, **kwargs)
        return args[0]

    def post_forward(self, module, output):
        """Log the call and return ``output`` unchanged."""
        logger.debug("SkipLayerHook post_forward")
        return output
|
|
|
|
|
|
|
|
class HookTests(unittest.TestCase):
    """Tests for ``HookRegistry``: registration, removal, state reset, layer skipping and invocation order.

    Every test runs against a fresh ``DummyModel`` built from the class-level
    size attributes below.
    """

    in_features = 4
    hidden_features = 8
    out_features = 4
    num_layers = 2

    def setUp(self):
        """Build the model under test and move it to ``torch_device``."""
        self.model = DummyModel(**self.get_module_parameters())
        self.model.to(torch_device)

    def tearDown(self):
        """Drop the model and release memory between tests."""
        super().tearDown()

        del self.model
        gc.collect()
        free_memory()

    def get_module_parameters(self):
        """Return the keyword arguments used to construct ``DummyModel``."""
        return {
            "in_features": self.in_features,
            "hidden_features": self.hidden_features,
            "out_features": self.out_features,
            "num_layers": self.num_layers,
        }

    def get_generator(self):
        """Seed torch's global RNG with 0 and return the generator ``torch.manual_seed`` hands back."""
        return torch.manual_seed(0)

    # ------------------------------------------------------------------ helpers

    def _get_random_input(self):
        """Return a seeded ``(1, in_features)`` random tensor on ``torch_device``.

        The tensor is sampled on CPU and moved afterwards. The generator from
        ``get_generator`` is a CPU generator, and torch refuses to sample
        directly on an accelerator device with a CPU generator.
        """
        return torch.randn(1, self.in_features, generator=self.get_generator()).to(torch_device)

    def _get_debug_logger(self):
        """Return this module's logger at DEBUG level, restoring the previous level after the test."""
        debug_logger = get_logger(__name__)
        previous_level = debug_logger.level
        debug_logger.setLevel("DEBUG")
        # Without this the DEBUG level would leak into every test that runs later.
        self.addCleanup(debug_logger.setLevel, previous_level)
        return debug_logger

    @staticmethod
    def _strip_whitespace(text):
        """Remove spaces and newlines so log comparisons ignore formatting."""
        return text.replace(" ", "").replace("\n", "")

    def _capture_forward_log(self, debug_logger, inputs):
        """Run one forward pass and return the captured log with spaces and newlines removed."""
        with CaptureLogger(debug_logger) as cap_logger:
            self.model(inputs)
        return self._strip_whitespace(cap_logger.out)

    def _expected_log(self, *messages):
        """Join the expected log messages in order, with spaces and newlines removed."""
        return self._strip_whitespace("".join(messages))

    # -------------------------------------------------------------------- tests

    def test_hook_registry(self):
        """Hooks are tracked in registration order, shown in the repr, and removable by name."""
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(AddHook(1), "add_hook")
        registry.register_hook(MultiplyHook(2), "multiply_hook")

        registry_repr = repr(registry)
        # NOTE(review): whitespace in this file was collapsed by extraction; confirm the
        # indentation inside this literal matches what HookRegistry.__repr__ emits.
        expected_repr = "HookRegistry(\n (0) add_hook - AddHook\n (1) multiply_hook - MultiplyHook(value=2)\n)"

        self.assertEqual(len(registry.hooks), 2)
        self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"])
        self.assertEqual(registry_repr, expected_repr)

        registry.remove_hook("add_hook")

        self.assertEqual(len(registry.hooks), 1)
        self.assertEqual(registry._hook_order, ["multiply_hook"])

    def test_stateful_hook(self):
        """Resetting stateful hooks makes the next forward pass match the very first one."""
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(StatefulAddHook(1), "stateful_add_hook")

        self.assertEqual(registry.hooks["stateful_add_hook"].increment, 0)

        inputs = self._get_random_input()
        num_repeats = 3

        # Each call bumps the hook's counter; only the first output is needed later.
        outputs = [self.model(inputs) for _ in range(num_repeats)]
        output1 = outputs[0]

        self.assertEqual(registry.get_hook("stateful_add_hook").increment, num_repeats)

        registry.reset_stateful_hooks()
        output2 = self.model(inputs)

        self.assertEqual(registry.get_hook("stateful_add_hook").increment, 1)
        self.assertTrue(torch.allclose(output1, output2))

    def test_inference(self):
        """Removing a hook and applying its transform to the input by hand gives the same output."""
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(AddHook(1), "add_hook")
        registry.register_hook(MultiplyHook(2), "multiply_hook")

        inputs = self._get_random_input()
        output1 = self.model(inputs).mean().detach().cpu().item()

        # Multiply by hand; the remaining add hook still adds 1.
        registry.remove_hook("multiply_hook")
        new_inputs = inputs * 2
        output2 = self.model(new_inputs).mean().detach().cpu().item()

        # No hooks left: apply both transforms by hand.
        registry.remove_hook("add_hook")
        new_inputs = inputs * 2 + 1
        output3 = self.model(new_inputs).mean().detach().cpu().item()

        self.assertAlmostEqual(output1, output2, places=5)
        self.assertAlmostEqual(output1, output3, places=5)

    def test_skip_layer_hook(self):
        """Skipping the whole model returns the zero input; not skipping runs the real forward."""
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")

        inputs = torch.zeros(1, 4, device=torch_device)
        output = self.model(inputs).mean().detach().cpu().item()
        self.assertEqual(output, 0.0)

        registry.remove_hook("skip_layer_hook")
        registry.register_hook(SkipLayerHook(skip_layer=False), "skip_layer_hook")
        output = self.model(inputs).mean().detach().cpu().item()
        self.assertNotEqual(output, 0.0)

    def test_skip_layer_internal_block(self):
        """Skipping ``linear_1`` breaks the shapes downstream; skipping a same-width block does not."""
        registry = HookRegistry.check_if_exists_or_initialize(self.model.linear_1)
        inputs = torch.zeros(1, 4, device=torch_device)

        # With linear_1 skipped, the (1, 4) input reaches layers built for hidden_features.
        registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
        with self.assertRaises(RuntimeError) as cm:
            self.model(inputs).mean().detach().cpu().item()
        self.assertIn("mat1 and mat2 shapes cannot be multiplied", str(cm.exception))

        registry.remove_hook("skip_layer_hook")
        output = self.model(inputs).mean().detach().cpu().item()
        self.assertNotEqual(output, 0.0)

        # blocks[1] maps hidden_features -> hidden_features, so skipping it keeps shapes valid.
        registry = HookRegistry.check_if_exists_or_initialize(self.model.blocks[1])
        registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
        output = self.model(inputs).mean().detach().cpu().item()
        self.assertNotEqual(output, 0.0)

    def test_invocation_order_stateful_first(self):
        """Invocation order when the stateful hook is registered first."""
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(StatefulAddHook(1), "add_hook")
        registry.register_hook(AddHook(2), "add_hook_2")
        registry.register_hook(MultiplyHook(3), "multiply_hook")

        inputs = self._get_random_input()
        debug_logger = self._get_debug_logger()

        # StatefulAddHook does not override post_forward, so it logs nothing there.
        output = self._capture_forward_log(debug_logger, inputs)
        expected_invocation_order_log = self._expected_log(
            "MultiplyHook pre_forward\n",
            "AddHook pre_forward\n",
            "StatefulAddHook pre_forward\n",
            "AddHook post_forward\n",
            "MultiplyHook post_forward\n",
        )
        self.assertEqual(output, expected_invocation_order_log)

        # "add_hook" is the stateful hook in this test.
        registry.remove_hook("add_hook")
        output = self._capture_forward_log(debug_logger, inputs)
        expected_invocation_order_log = self._expected_log(
            "MultiplyHook pre_forward\n",
            "AddHook pre_forward\n",
            "AddHook post_forward\n",
            "MultiplyHook post_forward\n",
        )
        self.assertEqual(output, expected_invocation_order_log)

    def test_invocation_order_stateful_middle(self):
        """Invocation order when the stateful hook is registered between the other two."""
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(AddHook(2), "add_hook")
        registry.register_hook(StatefulAddHook(1), "add_hook_2")
        registry.register_hook(MultiplyHook(3), "multiply_hook")

        inputs = self._get_random_input()
        debug_logger = self._get_debug_logger()

        output = self._capture_forward_log(debug_logger, inputs)
        expected_invocation_order_log = self._expected_log(
            "MultiplyHook pre_forward\n",
            "StatefulAddHook pre_forward\n",
            "AddHook pre_forward\n",
            "AddHook post_forward\n",
            "MultiplyHook post_forward\n",
        )
        self.assertEqual(output, expected_invocation_order_log)

        registry.remove_hook("add_hook")
        output = self._capture_forward_log(debug_logger, inputs)
        expected_invocation_order_log = self._expected_log(
            "MultiplyHook pre_forward\n",
            "StatefulAddHook pre_forward\n",
            "MultiplyHook post_forward\n",
        )
        self.assertEqual(output, expected_invocation_order_log)

        registry.remove_hook("add_hook_2")
        output = self._capture_forward_log(debug_logger, inputs)
        expected_invocation_order_log = self._expected_log(
            "MultiplyHook pre_forward\n",
            "MultiplyHook post_forward\n",
        )
        self.assertEqual(output, expected_invocation_order_log)

    def test_invocation_order_stateful_last(self):
        """Invocation order when the stateful hook is registered last."""
        registry = HookRegistry.check_if_exists_or_initialize(self.model)
        registry.register_hook(AddHook(1), "add_hook")
        registry.register_hook(MultiplyHook(2), "multiply_hook")
        registry.register_hook(StatefulAddHook(3), "add_hook_2")

        inputs = self._get_random_input()
        debug_logger = self._get_debug_logger()

        output = self._capture_forward_log(debug_logger, inputs)
        expected_invocation_order_log = self._expected_log(
            "StatefulAddHook pre_forward\n",
            "MultiplyHook pre_forward\n",
            "AddHook pre_forward\n",
            "AddHook post_forward\n",
            "MultiplyHook post_forward\n",
        )
        self.assertEqual(output, expected_invocation_order_log)

        registry.remove_hook("add_hook")
        output = self._capture_forward_log(debug_logger, inputs)
        expected_invocation_order_log = self._expected_log(
            "StatefulAddHook pre_forward\n",
            "MultiplyHook pre_forward\n",
            "MultiplyHook post_forward\n",
        )
        self.assertEqual(output, expected_invocation_order_log)
|
|
|