|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
from omegaconf import DictConfig, OmegaConf |
|
|
|
|
|
|
|
|
class CustomSimpleNamespace:
    """
    A simple namespace class that supports both attribute-style and dictionary-style access.

    Wraps the given dict (without copying it) so that ``ns.key`` and ``ns["key"]`` both
    read ``d["key"]``. Nested dicts are returned as plain dicts, not wrapped.
    """

    def __init__(self, d):
        # The wrapped mapping; all lookups are delegated to it.
        self._d = d

    def __getattr__(self, attr):
        """Return ``self._d[attr]``; raise AttributeError if the key is missing."""
        # __getattr__ only runs after normal lookup fails. If ``_d`` itself is missing
        # (copy/pickle build the object without __init__ and then probe attributes such
        # as __setstate__), evaluating ``self._d`` below would re-enter this method
        # forever. Fail fast with AttributeError instead.
        if attr == "_d":
            raise AttributeError(f"'CustomSimpleNamespace' object has no attribute '{attr}'")
        try:
            return self._d[attr]
        except KeyError:
            raise AttributeError(f"'CustomSimpleNamespace' object has no attribute '{attr}'") from None

    def __getitem__(self, key):
        """Return ``self._d[key]``; raises KeyError if the key is missing."""
        return self._d[key]

    def __contains__(self, key):
        """Support ``key in ns``; without this, ``in`` falls back to integer indexing."""
        return key in self._d

    def __repr__(self):
        return f"{type(self).__name__}({self._d!r})"
|
|
|
|
|
|
|
|
def maybe_convert_to_namespace(config):
    """
    Wrap a config mapping in a CustomSimpleNamespace when possible.

    An OmegaConf DictConfig is first turned into a plain container with interpolations
    resolved. If the resulting (or original) value is a ``dict``, it is wrapped in a
    CustomSimpleNamespace, which supports both attribute-style and dictionary-style
    access. Any other value is returned unchanged.

    Note: DictConfig must be converted because it is not compatible with torch.compile.
    """
    container = (
        OmegaConf.to_container(config, resolve=True)
        if isinstance(config, DictConfig)
        else config
    )

    # Only plain dicts get wrapped; everything else passes through untouched.
    if not isinstance(container, dict):
        return container
    return CustomSimpleNamespace(container)
|
|
|
|
|
|
|
|
def random_dropout(embeddings, drop_rate):
    r"""
    Function to perform random dropout for embeddings.

    Each sample along the first dimension is independently kept with probability
    ``1 - drop_rate`` or zeroed out entirely. One keep/drop decision is shared across
    all remaining dimensions of that sample.

    Args:
        embeddings (torch.Tensor): Input embeddings; dimension 0 indexes samples.
        drop_rate (float): Probability of dropping (zeroing) each sample; must lie in
            [0, 1] for ``torch.bernoulli``.

    Returns:
        torch.Tensor: ``embeddings`` multiplied by a per-sample 0/1 mask. Kept samples
        are NOT rescaled by ``1 / (1 - drop_rate)``.
    """
    num_samples = embeddings.shape[0]

    # (num_samples, 1, ..., 1) so the mask broadcasts over every non-sample dimension.
    mask_shape = (num_samples,) + (1,) * (embeddings.ndim - 1)

    # Build the probability tensor directly with embeddings' dtype AND device. This
    # avoids creating it on CPU and copying it to the accelerator on every call. The
    # Bernoulli draw therefore uses the generator of embeddings' device.
    keep_prob = embeddings.new_ones(mask_shape) * (1 - drop_rate)
    keep_mask = torch.bernoulli(keep_prob)

    return embeddings * keep_mask
|
|
|