import torch
import omegaconf
from omegaconf import open_dict

def print_model_info(model):
    """Prints trainable model parameters and their total count."""
    total_params = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            total_params += param.numel()
            # Skip listing the (large) pretrained encoder's parameters
            if "lm_encoder." not in name:
                print(name, param.data.size())
    print("\nTotal Params: {:.2f} (in millions)".format(total_params / 10**6))

def enough_memory():
    """Returns True if the first CUDA device has more than 40 GB of memory."""
    if torch.cuda.is_available():
        memory_in_gb = torch.cuda.get_device_properties(0).total_memory // (1024**3)
        if memory_in_gb > 40:
            return True
    return False
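
# Usage sketch (illustrative): gate memory-hungry settings on GPU size, e.g.
#
#   batch_size = 32 if enough_memory() else 8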

def get_sequence_mask(sequence_len):
    """Returns Sequence Mask.

    sequence_len: Tensor of size (B,) with entries indicating length of seq.
    """
    batch_size = sequence_len.size()[0]
    max_len = torch.max(sequence_len)
    tmp = torch.arange(max_len, device=sequence_len.device).expand(batch_size, max_len)
    return tmp < sequence_len.unsqueeze(1)
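
# Illustration (doctest-style comment, added for clarity): lengths [2, 3]
# yield a (batch_size, max_len) boolean mask with True at valid positions:
#
#   >>> get_sequence_mask(torch.tensor([2, 3]))
#   tensor([[ True,  True, False],
#           [ True,  True,  True]])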

def get_span_mask(start_ids, end_ids, max_len):
    """Returns a (B, max_len) float mask with 1s at positions in [start, end], inclusive."""
    tmp = (
        torch.arange(max_len, device=start_ids.device)
        .unsqueeze(0)
        .expand(start_ids.shape[0], -1)
    )
    batch_start_ids = start_ids.unsqueeze(1).expand_as(tmp)
    batch_end_ids = end_ids.unsqueeze(1).expand_as(tmp)
    mask = (tmp >= batch_start_ids).float() * (tmp <= batch_end_ids).float()
    return mask
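
# Illustration (added, not in the original file): a span from index 1 to 2,
# inclusive on both ends, within max_len=4 gives a 0/1 float mask:
#
#   >>> get_span_mask(torch.tensor([1]), torch.tensor([2]), 4)
#   tensor([[0., 1., 1., 0.]])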

def check_nan_grad(model):
    """Prints the names of parameters whose gradients contain NaNs."""
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        num_nan = torch.sum(torch.isnan(param.grad.data))
        if num_nan:
            print(name)
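
# Usage sketch (illustrative): call between loss.backward() and
# optimizer.step() to spot parameters whose gradients went NaN:
#
#   loss.backward()
#   check_nan_grad(model)  # prints offending parameter names, if any
#   optimizer.step()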

def get_l2_norm(model, debug=False):
    """Sums per-parameter L2 norms (not the norm of the full parameter vector)
    over parameters that have gradients, for both weights and gradients."""
    total_l2_norm = {"param": 0, "grad": 0}
    param_norm_list = []
    for name, param in model.named_parameters():
        if param.grad is not None:
            param_norm = torch.norm(param.data, p=2)
            if torch.isnan(param_norm):
                print("NaN parameter:", name)
            param_norm_list.append((name, param_norm.item()))
            total_l2_norm["param"] += param_norm.item()
            total_l2_norm["grad"] += torch.norm(param.grad.data, p=2).item()
    if debug:
        print("Summation of L2 norm: %.3f" % total_l2_norm["param"])
        # Sort param list by L2 norm and report the top 5
        sorted_param_list = sorted(param_norm_list, key=lambda x: x[1], reverse=True)
        topk_list = sorted_param_list[:5]
        for name, param_norm in topk_list:
            print(
                "Name: %s\tNorm (%%): %.3f\tNorm: %.3f"
                % (name, param_norm * 100 / total_l2_norm["param"], param_norm)
            )
    return total_l2_norm
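
# Usage sketch (illustrative; the norm values shown are made up): call after
# backward() so that param.grad is populated:
#
#   model = torch.nn.Linear(10, 5)
#   out = model(torch.randn(4, 10))
#   torch.nn.functional.mse_loss(out, torch.zeros(4, 5)).backward()
#   norms = get_l2_norm(model, debug=True)  # e.g. {"param": 1.73, "grad": 0.91}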

def fill_missing_configs(dict1, dict2):
    """Recursively adds keys from dict2 that are missing in dict1; values
    already present in dict1 take precedence."""
    with open_dict(dict1):
        for key, value in dict2.items():
            if key not in dict1:
                dict1[key] = value
                print(f"Added {key} to config, with value {value}")
            elif isinstance(value, omegaconf.dictconfig.DictConfig):
                dict1[key] = fill_missing_configs(dict1[key], value)
    return dict1
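
# Example (a sketch with made-up keys, not the original config schema):
# values already in dict1 win; missing keys, including nested ones, are
# copied over from dict2:
#
#   from omegaconf import OmegaConf
#   base = OmegaConf.create({"lr": 1e-3, "model": {"layers": 2}})
#   defaults = OmegaConf.create(
#       {"lr": 1e-4, "seed": 0, "model": {"layers": 4, "dropout": 0.1}}
#   )
#   merged = fill_missing_configs(base, defaults)
#   # merged.lr == 1e-3, merged.seed == 0, merged.model.dropout == 0.1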