from typing import Callable

import torch.nn as nn


def layer_mapping(layer: str) -> Callable[..., nn.Module]:
    """Return the module factory registered under the name *layer*.

    The lookup is case-insensitive. Note that a *factory* (a class or a
    lambda) is returned, not an instantiated module; the caller is
    responsible for calling it with the appropriate arguments:

    * ``"batch"``, ``"instance"``, ``"none"``, ``"linear"``, ``"relu"``,
      ``"tanh"``, ``"sigmoid"`` return the ``torch.nn`` class itself.
    * ``"layer"`` returns a one-argument callable ``c -> nn.GroupNorm(1, c)``.
    * ``"softmax"`` returns a zero-argument callable building
      ``nn.Softmax(dim=1)``.

    Args:
        layer: Name of the layer/activation to look up.

    Returns:
        A callable that constructs the corresponding ``nn.Module``.

    Raises:
        ValueError: If *layer* is not one of the registered names. The
            original ``KeyError`` is chained as the cause.
    """
    layer = layer.lower()
    mappings = {
        # Normalization
        "batch": nn.BatchNorm2d,
        "instance": nn.InstanceNorm2d,
        # A single-group GroupNorm is used here in place of nn.LayerNorm.
        "layer": lambda c: nn.GroupNorm(1, c),
        # Identity
        "none": nn.Identity,
        "linear": nn.Identity,
        # Activations
        "relu": nn.ReLU,
        "tanh": nn.Tanh,
        "sigmoid": nn.Sigmoid,
        "softmax": lambda: nn.Softmax(dim=1),
    }

    try:
        factory = mappings[layer]
    except KeyError as e:
        available = list(mappings)
        # Adjacent string literals are concatenated verbatim, so each piece
        # carries its own trailing separator to keep the message readable.
        raise ValueError(
            f"{layer} not found in the existing mapping. "
            f"Existing available: {available}. "
            f"Update model.model_utils.layer_mapping()"
        ) from e
    return factory