ius / model /module_mapping.py
pgatoula's picture
Sync from GitHub via hub-sync
99ec8a2 verified
import torch.nn as nn
def layer_mapping(layer: str) -> nn.Module:
layer = layer.lower()
mappings = {
# Normalization
"batch": nn.BatchNorm2d,
"instance": nn.InstanceNorm2d,
"layer": lambda c: nn.GroupNorm(1, c), # nn.LayerNorm,
# Identity
"none": nn.Identity,
"linear": nn.Identity,
# Activations
"relu": nn.ReLU,
"tanh": nn.Tanh,
"sigmoid": nn.Sigmoid,
"softmax": lambda: nn.Softmax(dim=1),
}
try:
return mappings[layer]
except KeyError as e:
available = list(mappings.keys())
raise ValueError(
f"{layer} not found in the existing mapping."
f"Existing available: {available}"
f"Update model.model_utils.layer_mapping()"
) from e