"""Smoke test for ``LambdaLayer`` and ``RLambdaLayer`` from ``lambda_networks``.

Runs one random input through both layers, prints the two output shapes, and
backpropagates a scalar derived from the ``RLambdaLayer`` output to make sure
the backward pass runs.
"""

import torch
from lambda_networks import LambdaLayer
from lambda_networks import RLambdaLayer


def main():
    """Build both layers, run a forward pass through each, and backprop ``rlayer``."""
    # The layers are constructed here rather than at module level so that
    # importing this file does not allocate parameters or consume the RNG.
    layer = LambdaLayer(dim=8, dim_out=8, r=23, dim_k=16, heads=4, dim_u=4)
    rlayer = RLambdaLayer(
        dim=8, dim_out=8, r=23, dim_k=16, heads=4, dim_u=4, recurrence=3
    )

    # NOTE(review): presumably an NCHW input; axis 1 has size 8, matching
    # dim=8 above. Confirm against the lambda_networks layer contract.
    x = torch.randn(1, 8, 64, 64, requires_grad=True)
    y = layer(x)
    z = rlayer(x)
    print(y.shape, z.shape)

    # Only z feeds the backward pass; y is checked for its shape alone.
    z.sum().backward()


if __name__ == "__main__":
    main()