# NOTE: web-page residue captured with this file (not part of the module):
#   "zz1358m's picture" / "Upload folder using huggingface_hub" /
#   commit 7a92ec5 (verified)
import sys
from typing import Dict, Any
import copy
import importlib
from pathlib import Path
from datetime import datetime
from omegaconf.dictconfig import DictConfig
import numpy as np
import torch
def sample_indices_from_attention_mask_3d(attention_mask_3d: torch.Tensor) -> torch.Tensor:
    """Sample one index along the last axis for every (batch, position) pair.

    Each row of the last axis is treated as unnormalized sampling weights.
    A small constant (1e-5) is added to every entry so that a row of all
    zeros can still be sampled from (it becomes a uniform draw).

    Args:
        attention_mask_3d: 3-D tensor of shape (batch_size, seq_length, r).

    Returns:
        Tensor of shape (batch_size, seq_length, 1) holding the sampled
        index in ``[0, r)`` for each position.
    """
    batch_size, seq_length, r = attention_mask_3d.shape
    # reshape (rather than view) also accepts non-contiguous inputs such as
    # transposed or sliced masks; view would raise on those.
    weights_flat = attention_mask_3d.reshape(-1, r) + 1e-5  # avoid all-zero rows
    row_totals = weights_flat.sum(dim=1, keepdim=True)
    probs_flat = weights_flat / row_totals  # normalize to probabilities
    # One draw per flattened position -> (batch_size * seq_length, 1)
    indices_flat = torch.multinomial(probs_flat, num_samples=1)
    return indices_flat.reshape(batch_size, seq_length, 1)
def batch_masked_mean(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Return the mask-weighted mean of ``x`` over all elements.

    The result is ``sum(x * mask) / sum(mask)`` as a 0-dim tensor. The
    in-code shape notes of the original describe both inputs as [B, T].
    No guard is applied for a mask that sums to zero.
    """
    weighted_total = (x * mask).sum()
    mask_total = torch.sum(mask)
    return weighted_total / mask_total
def get_position_ids_from_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    """Derive position ids from an attention mask.

    The running sum of the mask along dim 1, minus one, gives a zero-based
    position for each entry; values that would be negative (entries before
    the first non-zero mask value) are clamped to 0.

    The original in-code notes describe input and output as [B, T].
    """
    running_count = attention_mask.cumsum(dim=1)
    zero_based = running_count - 1
    return zero_based.clamp_min(0)
def swap(a, b):
    """Return the two arguments in reversed order as the tuple ``(b, a)``."""
    reversed_pair = (b, a)
    return reversed_pair
def get_timestamp():
    """Return the current time from ``datetime.now()`` as ``YYYYMMDD-HHMMSS``."""
    now = datetime.now()
    return now.strftime("%Y%m%d-%H%M%S")
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"pkg.module.Name"`` to the named object.

    The text after the last dot is the attribute name; everything before it
    is imported as a module.

    Args:
        string: Dotted path of the form ``<module>.<attribute>``.
        reload: When true, the module is reloaded before the attribute is
            looked up.

    Returns:
        The attribute fetched from the imported module.
    """
    module_path, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_path))
    module_obj = importlib.import_module(module_path, package=None)
    return getattr(module_obj, attr_name)
def instantiate_from_config(config, extra_kwargs=None):
    """Build an object from a config mapping that names its class in ``target``.

    ``config["target"]`` is a dotted path resolved with ``get_obj_from_str``;
    every other entry is passed to the resolved callable as a keyword
    argument. Values that are ``DictConfig`` instances containing their own
    ``target`` key are instantiated recursively first.

    Args:
        config: Mapping (anything ``dict()`` accepts) with a ``target`` key.
        extra_kwargs: Optional mapping of additional keyword arguments;
            these override same-named entries from ``config``. Defaults to
            no extra arguments.

    Returns:
        The result of calling the resolved target with the collected kwargs.

    Raises:
        ValueError: If ``config`` has no ``target`` key.
    """
    config_dict = dict(config)
    if "target" not in config_dict:
        raise ValueError(f"target not found in {config}")

    # Work on a deep copy so the caller's config is never modified.
    target_kwargs = copy.deepcopy(config_dict)
    target_kwargs.pop("target")
    for k, v in target_kwargs.items():
        # Only values are replaced (no keys added/removed), so mutating
        # while iterating is safe here.
        if isinstance(v, DictConfig) and "target" in v.keys():
            target_kwargs[k] = instantiate_from_config(v)

    # A None sentinel replaces the former shared mutable default (dict()).
    if extra_kwargs is not None:
        target_kwargs.update(extra_kwargs)
    return get_obj_from_str(config_dict["target"])(**target_kwargs)
def dict_apply(x, func):
    """Apply ``func`` to every non-dict value of a (possibly nested) dict.

    Values that are themselves ``dict`` instances are processed recursively;
    all other values are replaced by ``func(value)``. A new plain ``dict``
    with the same keys is returned and the input is not modified.
    """
    return {
        name: dict_apply(item, func) if isinstance(item, dict) else func(item)
        for name, item in x.items()
    }
def dict_to_device(d: Dict[str, Any], device):
    """Move every top-level tensor value of ``d`` to ``device``, in place.

    Non-tensor values (including nested dicts) are left untouched. The same
    dict object that was passed in is returned.
    """
    moved = {
        name: value.to(device=device)
        for name, value in d.items()
        if isinstance(value, torch.Tensor)
    }
    # All keys already exist, so update() rewrites values without reordering.
    d.update(moved)
    return d
def list_subdirs(path: Path):
    """Return the direct children of ``path`` that are not regular files.

    Entries are taken from ``path.glob("*")`` and kept whenever
    ``is_file()`` is false.
    """
    subdirs = []
    for entry in path.glob("*"):
        if entry.is_file():
            continue
        subdirs.append(entry)
    return subdirs
def is_debug_mode():
    """Report whether a trace function is currently installed via ``sys``.

    Returns ``False`` when ``sys`` has no ``gettrace`` attribute; otherwise
    returns whether ``sys.gettrace()`` yields something other than ``None``.
    """
    if not hasattr(sys, "gettrace"):
        return False
    active_tracer = sys.gettrace()
    return active_tracer is not None
def get_clones(module, N):
    """Return a ``torch.nn.ModuleList`` holding ``N`` deep copies of ``module``."""
    clones = torch.nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones
def get_metric_statistics(values, replication_times):
    """Return the mean and a 1.96-sigma interval half-width over axis 0.

    Args:
        values: Array-like of metric values; statistics are taken along
            axis 0.
        replication_times: Count used in the ``sqrt`` denominator of the
            interval.

    Returns:
        Tuple ``(mean, conf_interval)`` where ``conf_interval`` is
        ``1.96 * std / sqrt(replication_times)`` and ``std`` is NumPy's
        default (population) standard deviation.
    """
    center = np.mean(values, axis=0)
    spread = np.std(values, axis=0)
    half_width = 1.96 * spread / np.sqrt(replication_times)
    return center, half_width