| |
| |
|
|
| import os |
| import sys |
| |
# Put the directory one level above this file on the import path so the
# project-level imports that follow can be resolved when run from here.
parent_dir = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir)
)
sys.path.append(parent_dir)
|
|
| import torch |
| from alias_free_activation.cuda import activation1d |
| from activations import Snake |
|
|
|
|
def test_load_fused_kernels():
    """Check that the fused kernel loader can be imported and run.

    The loader call lives inside the ``try`` so the ``ImportError`` handler
    is reachable; with only a ``print`` in the guarded body the check could
    never fail.

    Raises:
        ImportError: Re-raised after printing the ``[Fail]`` line when the
            loader (or anything it imports) cannot be imported. Any other
            exception raised by the loader propagates unchanged.
    """
    try:
        # Same loader entry point the ``__main__`` guard of this file uses.
        # NOTE(review): ``load.load()`` is defined outside this file;
        # presumably it builds/loads the fused CUDA extension - confirm.
        from alias_free_activation.cuda import load

        load.load()
    except ImportError:
        print("[Fail] load_fused_kernels")
        raise
    else:
        print("[Success] load_fused_kernels")
|
|
|
|
def test_anti_alias_activation():
    """Compare fused and non-fused ``Activation1d`` outputs on random input.

    A random ``(10, 10, 200)`` tensor is created on the CUDA device and fed
    through two ``Activation1d(activation=Snake(10))`` modules, one built with
    ``fused=True`` and one with ``fused=False``. The mean absolute difference
    between the two outputs must not exceed ``1e-3``.

    A report with the mean difference and the last row of each output is
    printed in both the success and the failure case.

    Raises:
        AssertionError: If the mean absolute difference exceeds the
            tolerance (or is NaN), so that test runners register the
            failure instead of only seeing a printed message.
    """
    tolerance = 1e-3
    data = torch.rand((10, 10, 200), device="cuda")

    def run(fused):
        # Each variant gets its own freshly constructed Snake(10) activation.
        module = activation1d.Activation1d(
            activation=Snake(10), fused=fused
        ).cuda()
        return module(data)

    # Check the fused (CUDA) path against the plain torch path.
    fused_activation_output = run(True)
    torch_activation_output = run(False)

    # Averaging the last dimension repeatedly over equally sized groups is
    # the mean over all elements, so a single reduction is sufficient.
    diff = (fused_activation_output - torch_activation_output).abs().mean()

    fused_values = fused_activation_output[-1][-1][:].tolist()
    torch_values = torch_activation_output[-1][-1][:].tolist()

    if diff <= tolerance:
        print(
            f"\n[Success] test_fused_anti_alias_activation"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_values}"
            f"\n > torch_values={torch_values}"
        )
        return

    print(
        f"\n[Fail] test_fused_anti_alias_activation"
        f"\n > mean_difference={diff}, "
        f"\n > fused_values={fused_values}, "
        f"\n > torch_values={torch_values}"
    )
    # Printing alone lets the test pass; fail loudly like the sibling test.
    raise AssertionError(
        "fused and torch Activation1d outputs differ: "
        f"mean_difference={diff} exceeds tolerance={tolerance}"
    )
|
|
|
|
if __name__ == "__main__":
    from alias_free_activation.cuda import load

    # Script entry point: invoke the project loader first, then run each
    # check in the order it is defined in this file.
    # NOTE(review): load.load() presumably builds/loads the fused kernels -
    # its behavior is defined outside this file.
    load.load()
    for check in (test_load_fused_kernels, test_anti_alias_activation):
        check()
|
|