| import torch | |
| import platform | |
class AverageMeter(object):
    """Computes and stores the latest value and a running weighted average.

    Record an observation by calling the instance: ``meter(val, n)`` registers
    ``val`` as having been seen ``n`` times. The attributes ``val``, ``sum``,
    ``count`` and ``avg`` are updated in place.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def __call__(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: the latest observed value. It must support ``val * n`` and
                addition to ``self.sum`` (presumably a number or a scalar
                tensor; confirm against callers).
            n: weight, i.e. how many samples ``val`` stands for.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. the first call uses n=0) would otherwise
        # raise ZeroDivisionError. In that case keep the previous average.
        if self.count != 0:
            self.avg = self.sum / self.count
def getPlatform():
    """Return a short OS tag: 'mac' on macOS, otherwise platform.system() as-is."""
    system_name = platform.system()
    # macOS identifies itself as 'Darwin'; normalise that to 'mac'.
    return 'mac' if system_name == 'Darwin' else system_name
def hasGPU(plt: str) -> bool:
    """Report whether torch can use a GPU backend on the given platform.

    Args:
        plt: platform tag as produced by getPlatform(). 'mac' checks Apple's
            MPS backend; any other value checks CUDA.

    Returns:
        True if the selected backend reports itself available.
    """
    if plt == 'mac':
        # torch.backends.mps is absent in PyTorch builds older than 1.12;
        # treat that as "no GPU" instead of raising AttributeError.
        mps = getattr(torch.backends, 'mps', None)
        return mps is not None and mps.is_available()
    return torch.cuda.is_available()
def getDevice(plt: str):
    """Return the torch accelerator device for a platform tag.

    'mac' maps to the MPS device; every other value maps to CUDA. This
    function does not check whether the device is actually available.
    """
    backend = 'mps' if plt == 'mac' else 'cuda'
    return torch.device(backend)
def disableWarnings():
    """Install 'ignore' filters for a fixed set of noisy third-party warnings.

    Two filters match UserWarnings raised from specific transformers/trl
    modules. The third matches the torch checkpoint use_reentrant message
    by its text.
    """
    import warnings

    ignore_specs = (
        {"category": UserWarning, "module": "transformers.utils.generic"},
        {"category": UserWarning, "module": "trl.trainer.ppo_config"},
        {"message": "torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly"},
    )
    # Same order as the individual calls would have been made.
    for spec in ignore_specs:
        warnings.filterwarnings("ignore", **spec)