|
|
from typing import Callable |
|
|
|
|
|
import math |
|
|
import warnings |
|
|
|
|
|
import torch |
|
|
from torch import nn, Tensor |
|
|
|
|
|
def named_apply(
    fn: Callable,
    module: nn.Module,
    name: str = "",
    depth_first: bool = True,
    include_root: bool = False,
) -> nn.Module:
    """Recursively invoke ``fn(module=..., name=...)`` on a module tree.

    ``fn`` is called with keyword arguments ``module`` and ``name``, where
    ``name`` is the dotted path of the submodule relative to the root.
    Children are always visited; the root itself is visited only when
    ``include_root`` is True. With ``depth_first`` the parent is visited
    after its children (post-order); otherwise before them (pre-order).

    Returns the (possibly modified in place) ``module``.
    """
    # Pre-order visit of this node.
    if include_root and not depth_first:
        fn(module=module, name=name)

    for child_name, child in module.named_children():
        # Build the dotted path; the root's own name may be empty.
        qualified = f"{name}.{child_name}" if name else child_name
        named_apply(
            fn=fn,
            module=child,
            name=qualified,
            depth_first=depth_first,
            # Children are always applied regardless of the caller's flag.
            include_root=True,
        )

    # Post-order visit of this node.
    if include_root and depth_first:
        fn(module=module, name=name)

    return module
|
|
|
|
|
|
|
|
def _no_grad_trunc_normal_(tensor, mean, std, a, b): |
|
|
|
|
|
|
|
|
def norm_cdf(x): |
|
|
|
|
|
return (1. + math.erf(x / math.sqrt(2.))) / 2. |
|
|
|
|
|
if (mean < a - 2 * std) or (mean > b + 2 * std): |
|
|
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " |
|
|
"The distribution of values may be incorrect.", |
|
|
stacklevel=2) |
|
|
|
|
|
with torch.no_grad(): |
|
|
|
|
|
|
|
|
|
|
|
l = norm_cdf((a - mean) / std) |
|
|
u = norm_cdf((b - mean) / std) |
|
|
|
|
|
|
|
|
|
|
|
tensor.uniform_(2 * l - 1, 2 * u - 1) |
|
|
|
|
|
|
|
|
|
|
|
tensor.erfinv_() |
|
|
|
|
|
|
|
|
tensor.mul_(std * math.sqrt(2.)) |
|
|
tensor.add_(mean) |
|
|
|
|
|
|
|
|
tensor.clamp_(min=a, max=b) |
|
|
return tensor |
|
|
|
|
|
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Fill ``tensor`` with values from a truncated normal distribution.

    Values are drawn from N(mean, std^2) and restricted to the interval
    [a, b] (values outside are redrawn via inverse-CDF sampling rather
    than clipped from an unbounded draw). Operates in place and returns
    ``tensor``.
    """
    return _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)