# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from torch import nn
from einops import rearrange

class BottleneckBlock(nn.Module):
    """Inverted-bottleneck MLP block: expands from `thin` to `wide` features,
    applies a GELU non-linearity, then projects back down to `thin`."""

    def __init__(self, thin, wide):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(thin, wide),
            nn.GELU(),
            nn.Linear(wide, thin),
        )

    def forward(self, x):
        return self.block(x)
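
# Shape sketch (illustrative sizes, not from the original file):
# BottleneckBlock(thin=1024, wide=4096) maps (..., 1024) -> (..., 1024),
# expanding to 4096 features around the GELU. BottleneckMLP below uses it
# as the residual branch of a pre-norm update: x = x + block(norm(x)).
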
class StandardMLP(nn.Module):
    """Plain MLP that alternates LayerNorm and Linear layers.

    Args:
        dim_in: Input dimensionality.
        dim_out: Output dimensionality.
        widths: Widths of the hidden layers.
    """

    def __init__(self, dim_in, dim_out, widths):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.widths = widths
        self.linear_in = nn.Linear(self.dim_in, self.widths[0])
        self.linear_out = nn.Linear(self.widths[-1], self.dim_out)
        layers = []
        layer_norms = []
        for i in range(len(self.widths) - 1):
            layers.append(nn.Linear(self.widths[i], self.widths[i + 1]))
            layer_norms.append(nn.LayerNorm(self.widths[i + 1]))
        self.layers = nn.ModuleList(layers)
        self.layer_norms = nn.ModuleList(layer_norms)

    def forward(self, x):
        # If x is an image, apply the MLP point-wise to each token/pixel
        if x.ndim == 4:
            _, _, H, W = x.shape
            x = rearrange(x, 'b d h w -> b (h w) d')
            x_is_image = True
        else:
            x_is_image = False

        z = self.linear_in(x)
        for layer, norm in zip(self.layers, self.layer_norms):
            z = norm(z)
            z = layer(z)
        out = self.linear_out(z)

        # If x was an image, rearrange back to image shape
        if x_is_image:
            out = rearrange(out, 'b (h w) d -> b d h w', h=H, w=W)
        return out
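
# Usage sketch (assumed, illustrative sizes; requires `import torch`):
#   mlp = StandardMLP(dim_in=64, dim_out=64, widths=[1024, 1024])
#   mlp(torch.randn(2, 196, 64)).shape     # (2, 196, 64), token input
#   mlp(torch.randn(2, 64, 14, 14)).shape  # (2, 64, 14, 14), image input
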
class BottleneckMLP(nn.Module):
    """MLP composed of pre-norm residual bottleneck blocks, following
    "Scaling MLPs: A Tale of Inductive Bias" (https://arxiv.org/abs/2306.13575).

    Args:
        dim_in: Input dimensionality.
        dim_out: Output dimensionality.
        block_dims: List of [wide, thin] dimension pairs, one per block.
    """

    def __init__(self, dim_in, dim_out, block_dims):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.block_dims = block_dims
        self.linear_in = nn.Linear(self.dim_in, self.block_dims[0][1])
        self.linear_out = nn.Linear(self.block_dims[-1][1], self.dim_out)
        blocks = []
        layernorms = []
        for block_dim in self.block_dims:
            wide, thin = block_dim
            blocks.append(BottleneckBlock(thin=thin, wide=wide))
            layernorms.append(nn.LayerNorm(thin))
        self.blocks = nn.ModuleList(blocks)
        self.layernorms = nn.ModuleList(layernorms)

    def forward(self, x):
        # If x is an image, apply the MLP point-wise to each token/pixel
        if x.ndim == 4:
            _, _, H, W = x.shape
            x = rearrange(x, 'b d h w -> b (h w) d')
            x_is_image = True
        else:
            x_is_image = False

        x = self.linear_in(x)
        for block, norm in zip(self.blocks, self.layernorms):
            # Pre-norm residual update
            x = x + block(norm(x))
        out = self.linear_out(x)

        # If x was an image, rearrange back to image shape
        if x_is_image:
            out = rearrange(out, 'b (h w) d -> b d h w', h=H, w=W)
        return out
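
# Usage sketch (assumed, illustrative sizes): two blocks with 4x expansion
#   mlp = BottleneckMLP(dim_in=64, dim_out=10, block_dims=[[4096, 1024]] * 2)
#   mlp(torch.randn(2, 196, 64)).shape  # (2, 196, 10)
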
def build_mlp(model_id: str = "BottleneckMLP/B_6-Wi_1024",
              dim_in: Optional[int] = None,
              dim_out: Optional[int] = None,
              **kwargs) -> nn.Module:
    """Constructs an MLP model from a model ID string, see
    "Scaling MLPs: A Tale of Inductive Bias" (https://arxiv.org/abs/2306.13575).

    Args:
        model_id: Model ID string, e.g. "BottleneckMLP/B_6-Wi_1024" for a
            bottleneck MLP with 6 blocks of width 1024.
            See https://arxiv.org/abs/2306.13575 for options and details.
        dim_in: Input dimensionality. If None, defaults to the MLP width.
        dim_out: Output dimensionality. If None, defaults to the MLP width.

    Returns:
        MLP model.
    """
    model, architecture = model_id.split("/")
    assert model in ["BottleneckMLP", "MLP"], f"Model {model} not supported."

    sep = architecture.split("-")
    num_blocks = int(sep[0].split("_")[1])  # e.g. "B_6" -> 6 blocks
    thin = int(sep[1].split("_")[1])        # e.g. "Wi_1024" -> width 1024

    # If dim_in and dim_out are not specified, use the MLP width
    dim_in = dim_in or thin
    dim_out = dim_out or thin

    # An optional third component overrides the default 4x expansion factor
    if len(sep) == 3:
        expansion_factor = int(sep[2].split("_")[1])
    else:
        expansion_factor = 4

    if model == "BottleneckMLP":
        blocks = [[expansion_factor * thin, thin] for _ in range(num_blocks)]
        return BottleneckMLP(
            dim_in=dim_in,
            dim_out=dim_out,
            block_dims=blocks,
        )
    elif model == "MLP":
        widths = [thin for _ in range(num_blocks)]
        return StandardMLP(
            dim_in=dim_in,
            dim_out=dim_out,
            widths=widths,
        )
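
# A minimal smoke test: an illustrative addition, not part of the original
# file. It builds both model variants and checks output shapes for token and
# image inputs; all sizes below are arbitrary examples.
if __name__ == "__main__":
    import torch

    # Bottleneck MLP: 6 residual blocks of width 1024, default 4x expansion
    model = build_mlp("BottleneckMLP/B_6-Wi_1024", dim_in=32, dim_out=10)
    tokens = torch.randn(2, 196, 32)    # (batch, tokens, dim)
    image = torch.randn(2, 32, 14, 14)  # (batch, dim, H, W)
    assert model(tokens).shape == (2, 196, 10)
    assert model(image).shape == (2, 10, 14, 14)

    # Standard MLP with the same depth/width naming scheme
    model = build_mlp("MLP/B_6-Wi_1024", dim_in=32, dim_out=10)
    assert model(tokens).shape == (2, 196, 10)
    print("All shape checks passed.")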