Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | |
| # | |
| # This source code is licensed under the BSD license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # CREDITS: Initially suggested by Jason Ramapuram, see | |
| # https://github.com/facebookresearch/xformers/issues/203 | |
| import pickle | |
| from copy import deepcopy | |
| import pytest | |
| from torch import nn | |
| from xformers.factory import xFormer, xFormerConfig | |
# Block configurations passed to ``xFormerConfig``: a single two-layer
# encoder stack (model width 768, "pre" residual norm style) whose attention
# is 12-head scaled dot-product with rotary embeddings and whose feedforward
# is a GELU "MLP" with a hidden-layer multiplier of 4.
test_config = [
    dict(
        reversible=False,
        block_type="encoder",
        num_layers=2,
        dim_model=768,
        residual_norm_style="pre",
        multi_head_config=dict(
            num_heads=12,
            residual_dropout=0.1,
            use_rotary_embeddings=True,
            attention=dict(
                name="scaled_dot_product",
                dropout=0.1,
                causal=False,
            ),
        ),
        feedforward_config=dict(
            name="MLP",
            dropout=0.1,
            activation="gelu",
            hidden_layer_multiplier=4,
        ),
    )
]
class ViT(nn.Module):
    """Small ``nn.Module`` that wraps an xFormer built from ``test_config``.

    Used to check that a module containing an xFormer can be pickled and
    deep-copied.
    """

    def __init__(self, mlp):
        """Build the wrapped xFormer.

        Args:
            mlp: feedforward name (e.g. ``"MLP"``) that replaces
                ``feedforward_config["name"]`` of the first block config.
        """
        super().__init__()
        # Edit a private copy so the module-level ``test_config`` is never
        # mutated. Otherwise one instance's MLP choice would leak into the
        # shared config and into every later use of it (e.g. other
        # parametrized test cases).
        config = deepcopy(test_config)
        config[0]["feedforward_config"]["name"] = mlp
        xformer_config = xFormerConfig(config)
        self.xformer = xFormer.from_config(xformer_config)
# Feedforward names that ``test_pickling`` is run against, one case each.
MLPs = ["MLP"]


@pytest.mark.parametrize("mlp", MLPs)
def test_pickling(mlp):
    """Check that a module holding an xFormer can be pickled and deep-copied.

    ``mlp`` is supplied by ``parametrize`` from ``MLPs``. Without the
    decorator, pytest would treat ``mlp`` as a fixture name and fail during
    setup because no such fixture exists.
    """
    test = ViT(mlp)
    # Only successful serialization / copying is checked; results are unused.
    _ = pickle.dumps(test)
    _ = deepcopy(test)