from typing import Callable

import torch.nn as nn
def layer_mapping(layer: str) -> Callable[..., nn.Module]:
    """Map a layer-name string to the corresponding torch.nn constructor."""
    layer = layer.lower()
    mappings = {
        # Normalization
        "batch": nn.BatchNorm2d,
        "instance": nn.InstanceNorm2d,
        "layer": lambda c: nn.GroupNorm(1, c),  # nn.LayerNorm,
        # Identity
        "none": nn.Identity,
        "linear": nn.Identity,
        # Activations
        "relu": nn.ReLU,
        "tanh": nn.Tanh,
        "sigmoid": nn.Sigmoid,
        "softmax": lambda: nn.Softmax(dim=1),
    }
    try:
        return mappings[layer]
    except KeyError as e:
        available = list(mappings.keys())
        raise ValueError(
            f"'{layer}' not found in the existing mapping. "
            f"Available options: {available}. "
            "Update model.model_utils.layer_mapping() to add new entries."
        ) from e
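
# Minimal usage sketch (illustrative, not part of the original file): the
# function returns a constructor, so call the result to build the layer.
if __name__ == "__main__":
    norm = layer_mapping("batch")(64)   # nn.BatchNorm2d(64)
    act = layer_mapping("relu")()       # nn.ReLU()
    ln = layer_mapping("layer")(128)    # nn.GroupNorm(1, 128)
    print(norm, act, ln)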