| import torch | |
# Tensor type strings in the format produced by ``Tensor.type()`` for CUDA
# tensors (presumably compared against ``Tensor.type()`` by the tests --
# confirm against callers).
HALF = 'torch.cuda.HalfTensor'
FLOAT = 'torch.cuda.FloatTensor'

# Floating-point dtypes, half first then float.
DTYPES = [torch.half, torch.float]

# Lookup tables keyed by an input dtype (float first, then half), each giving
# a type string. Going by the names: ALWAYS_HALF maps every input to the half
# type, ALWAYS_FLOAT maps every input to the float type, and MATCH_INPUT maps
# each input dtype to its own type string.
ALWAYS_HALF = dict.fromkeys((torch.float, torch.half), HALF)
ALWAYS_FLOAT = dict.fromkeys((torch.float, torch.half), FLOAT)
MATCH_INPUT = dict(zip((torch.float, torch.half), (FLOAT, HALF)))
def common_init(test_case, *, default_tensor_type=None):
    """Set shared problem sizes on a test case and set torch's default tensor type.

    Args:
        test_case: Any object that accepts attribute assignment (presumably a
            ``unittest.TestCase`` instance called from ``setUp`` -- confirm
            against callers). The attributes ``h``, ``b``, ``c``, ``k`` and
            ``t`` are assigned on it.
        default_tensor_type: Tensor type passed to
            ``torch.set_default_tensor_type``. ``None`` (the default) means
            ``torch.cuda.FloatTensor``, which matches the previous hard-coded
            behavior. Pass e.g. ``torch.FloatTensor`` to keep newly created
            tensors on the CPU.

    Side effects:
        Changes torch's process-global default tensor type. This affects
        every tensor created afterwards without an explicit dtype/device.
    """
    # Single-letter sizes read by the tests. Their meanings are not visible
    # here; presumably hidden size, batch, channels, kernel size and time
    # steps -- confirm against the callers.
    test_case.h = 64
    test_case.b = 16
    test_case.c = 16
    test_case.k = 3
    test_case.t = 10
    if default_tensor_type is None:
        # Resolved at call time, so defining/importing this helper never
        # touches the CUDA tensor types.
        default_tensor_type = torch.cuda.FloatTensor
    # NOTE(review): torch.set_default_tensor_type is deprecated since
    # PyTorch 2.1 in favor of torch.set_default_dtype /
    # torch.set_default_device; kept to preserve existing behavior.
    torch.set_default_tensor_type(default_tensor_type)