Spaces:
Running
Running
| import time | |
| import torch | |
| from hpc_rll.origin.gae import gae, gae_data | |
| from hpc_rll.rl_utils.gae import GAE | |
| from testbase import mean_relative_error, times | |
# The routines in this script move their tensors and the HPC GAE module to the
# GPU, so fail fast when no CUDA device is visible. An explicit raise is used
# instead of `assert` because assert statements are removed under `python -O`.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required to run the GAE test, but no CUDA device is available")
# When True, inputs and the HPC module are moved to the GPU before use.
use_cuda = True

# Problem size handed to GAE(T, B); the tests build a (T + 1, B) value tensor
# and a (T, B) reward tensor from these.
# NOTE(review): presumably T is the trajectory length and B the batch size -- confirm.
T = 1024
B = 64
def gae_val():
    """Print the mean relative error between the reference and HPC GAE outputs.

    Both implementations receive the same random ``value`` (T + 1, B) and
    ``reward`` (T, B) tensors. Their results are flattened, copied to the host
    as NumPy arrays and compared with ``mean_relative_error``.
    """

    def as_flat_numpy(tensor):
        # Flatten, then copy to host memory so the comparison runs on NumPy arrays.
        return torch.flatten(tensor).cpu().detach().numpy()

    # Keep the draw order (value, then reward) so the random stream is unchanged.
    value, reward = torch.randn(T + 1, B), torch.randn(T, B)
    hpc_gae = GAE(T, B)
    if use_cuda:
        value, reward, hpc_gae = value.cuda(), reward.cuda(), hpc_gae.cuda()

    reference_adv = gae(gae_data(value, reward))
    accelerated_adv = hpc_gae(value, reward)
    if use_cuda:
        # Wait for queued GPU work to finish before reading the results back.
        torch.cuda.synchronize()

    error = mean_relative_error(
        as_flat_numpy(reference_adv),
        as_flat_numpy(accelerated_adv),
    )
    print("gae mean_relative_error: " + str(error))
def _report_epoch_times(message, fn):
    """Call ``fn`` once per epoch (``times`` epochs) and print each call's elapsed time.

    ``message`` is a ``str.format`` template that receives the epoch index and
    the elapsed seconds, in that order.
    """
    for epoch in range(times):
        # perf_counter is monotonic and high resolution, unlike time.time(),
        # which can jump when the system clock is adjusted.
        start = time.perf_counter()
        fn()
        if use_cuda:
            # CUDA work is queued asynchronously; wait so the full cost is measured.
            torch.cuda.synchronize()
        print(message.format(epoch, time.perf_counter() - start))


def gae_perf():
    """Print per-epoch timings for the reference GAE and then the HPC GAE.

    Both implementations are timed on the same random ``value`` (T + 1, B) and
    ``reward`` (T, B) tensors, ``times`` epochs each. Results are discarded;
    only the elapsed time of each call is reported.
    """
    value = torch.randn(T + 1, B)
    reward = torch.randn(T, B)
    hpc_gae = GAE(T, B)

    if use_cuda:
        value = value.cuda()
        reward = reward.cuda()
        hpc_gae = hpc_gae.cuda()
        # Drain any setup work still queued on the device so it is not
        # charged to the first timed epoch.
        torch.cuda.synchronize()

    _report_epoch_times(
        'epoch: {}, original gae cost time: {}',
        lambda: gae(gae_data(value, reward)),
    )
    _report_epoch_times(
        'epoch: {}, hpc gae cost time: {}',
        lambda: hpc_gae(value, reward),
    )
if __name__ == '__main__':
    # Script entry point: report the problem size, then run each stage under
    # its banner -- validation first, performance second.
    print("target problem: T = {}, B = {}".format(T, B))
    for banner, stage in (
        ("================run gae validation test================", gae_val),
        ("================run gae performance test================", gae_perf),
    ):
        print(banner)
        stage()