import sys
import copy
import importlib
from pathlib import Path
from datetime import datetime
from typing import Dict, Any

import numpy as np
import torch
from omegaconf import DictConfig


def sample_indices_from_attention_mask_3d(attention_mask_3d: torch.Tensor) -> torch.Tensor:
    """Sample one index per (batch, seq) position, weighted by the mask values.

    A small epsilon keeps all-zero rows from producing a degenerate
    distribution, which would make torch.multinomial fail.
    """
    batch_size, seq_length, r = attention_mask_3d.shape
    x_flat = attention_mask_3d.view(-1, r) + 1e-5

    # Normalize each row into a probability distribution.
    sum_ones = x_flat.sum(dim=1, keepdim=True)
    probs_flat = x_flat / sum_ones

    # Draw a single index per row according to the row's probabilities.
    indices_flat = torch.multinomial(probs_flat, num_samples=1)

    return indices_flat.view(batch_size, seq_length, 1)

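# Illustrative usage (shapes are hypothetical, not fixed by this module):
#
#     mask = torch.rand(2, 5, 4)                        # (batch, seq, r)
#     idx = sample_indices_from_attention_mask_3d(mask)
#     idx.shape                                         # torch.Size([2, 5, 1])
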
def batch_masked_mean(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mean of the elements of x selected by mask, reduced to a single scalar."""
    return torch.sum(x * mask) / mask.sum()

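# Worked example (made-up values): with x = [1, 2, 3, 4] and mask = [1, 1, 0, 0],
# the result is (1 + 2) / 2 = 1.5 -- one scalar over all masked elements in the
# batch, not a per-sample mean.
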
def get_position_ids_from_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    """Derive position ids from a (batch, seq) attention mask.

    Each position id counts the attended tokens seen so far (cumsum - 1),
    clamped at zero so left-padding does not produce negative positions.
    """
    position_ids = torch.clamp_min(torch.cumsum(attention_mask, dim=1) - 1, 0)
    return position_ids

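# For example, a left-padded mask [0, 0, 1, 1, 1] gives cumsum - 1 =
# [-1, -1, 0, 1, 2], clamped to [0, 0, 0, 1, 2], so the first real token
# always starts at position 0.
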
def swap(a, b):
    return b, a

def get_timestamp():
    return datetime.now().strftime("%Y%m%d-%H%M%S")

def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as "package.module.ClassName" to the object."""
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)

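# Illustrative usage: "torch.nn.Linear" splits into the module path "torch.nn"
# and the attribute "Linear", i.e.
#
#     cls = get_obj_from_str("torch.nn.Linear")
#     layer = cls(8, 4)
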
def instantiate_from_config(config, extra_kwargs=None):
    """Instantiate the class named by config["target"], passing the remaining keys as kwargs.

    Nested DictConfig values that carry their own "target" are instantiated
    recursively; extra_kwargs override or extend the config-provided kwargs.
    """
    config_dict = dict(config)
    if "target" not in config_dict:
        raise ValueError(f"target not found in {config}")

    target_kwargs = copy.deepcopy(config_dict)
    target_kwargs.pop("target")

    for k, v in target_kwargs.items():
        if isinstance(v, DictConfig) and "target" in v.keys():
            target_kwargs[k] = instantiate_from_config(v)
    if extra_kwargs:
        target_kwargs.update(extra_kwargs)

    return get_obj_from_str(config_dict["target"])(**target_kwargs)

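# Illustrative config (the class and kwargs here are hypothetical examples;
# any key other than "target" is forwarded to the constructor):
#
#     from omegaconf import OmegaConf
#     cfg = OmegaConf.create({"target": "torch.nn.Linear",
#                             "in_features": 8, "out_features": 4})
#     layer = instantiate_from_config(cfg)
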
def dict_apply(x, func):
    """Apply func to every leaf value of a nested dict, returning a new dict."""
    result = dict()
    for key, value in x.items():
        if isinstance(value, dict):
            result[key] = dict_apply(value, func)
        else:
            result[key] = func(value)
    return result

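# Illustrative usage (keys and shapes are made up), e.g. detaching every
# tensor in a nested batch:
#
#     batch = {"obs": {"rgb": torch.zeros(3)}, "action": torch.ones(2)}
#     detached = dict_apply(batch, lambda t: t.detach())
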
def dict_to_device(d: Dict[str, Any], device):
    """Move every tensor value of d to device, mutating d in place."""
    for k, v in d.items():
        if isinstance(v, torch.Tensor):
            d[k] = v.to(device=device)
    return d

def list_subdirs(path: Path):
    return [d for d in path.glob("*") if d.is_dir()]

def is_debug_mode():
    """True when running under a debugger that installs a trace function."""
    return hasattr(sys, "gettrace") and sys.gettrace() is not None

def get_clones(module, N):
    return torch.nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

def get_metric_statistics(values, replication_times):
    """Mean and 95% confidence half-width across replicated runs (axis 0)."""
    mean = np.mean(values, axis=0)
    std = np.std(values, axis=0)
    conf_interval = 1.96 * std / np.sqrt(replication_times)
    return mean, conf_interval
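
# Illustrative usage (numbers are made up): values stacked over replications,
# one row per run.
#
#     values = np.stack([np.array([0.9, 0.1]), np.array([0.8, 0.2])])
#     mean, conf = get_metric_statistics(values, replication_times=2)
#     # mean -> [0.85, 0.15]; conf is the half-width of a 95% interval.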