Spaces:
Running
Running
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Run ESM-2 (esm2_t48_15B_UR50D) inference with fairscale FSDP.

FSDP is a distributed wrapper, so a single-process NCCL world
(world_size=1, rank=0) is initialized first. The model is then wrapped
with FullyShardedDataParallel using mixed precision and CPU offloading,
and a small batch of example sequences is run through it.
"""
import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.wrap import enable_wrap, wrap

import esm

# Initialize the distributed world with a single process. The NCCL backend
# needs a CUDA device.
url = "tcp://localhost:23456"  # rendezvous address; this port must be free
torch.distributed.init_process_group(
    backend="nccl", init_method=url, world_size=1, rank=0
)

# Download the model checkpoint plus its regression data (presumably the
# contact-prediction weights used by return_contacts=True) from the hub.
# NOTE: `_download_model_and_regression_data` is a private esm.pretrained
# helper (leading underscore) and may change between esm releases.
model_name = "esm2_t48_15B_UR50D"
model_data, regression_data = esm.pretrained._download_model_and_regression_data(
    model_name
)
# Initialize the model with the FSDP wrapper.
# fairscale's wrap() only wraps a module while an enable_wrap() context is
# active; outside that context it returns the module unchanged. Every wrap()
# call below must therefore stay indented under the `with`.
fsdp_params = dict(
    mixed_precision=True,  # FP16 compute with an FP32 master copy of weights
    flatten_parameters=True,  # one flat parameter tensor per FSDP unit
    state_dict_device=torch.device("cpu"),  # reduce GPU mem usage
    cpu_offload=True,  # enable cpu offloading (keep params on CPU)
)
with enable_wrap(wrapper_cls=FSDP, **fsdp_params):
    model, vocab = esm.pretrained.load_model_and_alphabet_core(
        model_name, model_data, regression_data
    )
    batch_converter = vocab.get_batch_converter()
    model.eval()

    # Wrap each child of the model's "layers" container in its own FSDP unit,
    # then wrap the whole model around them. The nesting presumably lets FSDP
    # gather/offload parameters layer by layer instead of for all 15B at once.
    for name, child in model.named_children():
        if name != "layers":
            continue
        for layer_name, layer in child.named_children():
            setattr(child, layer_name, wrap(layer))
    model = wrap(model)
# Example inputs as (label, sequence) pairs. "<mask>" marks a masked
# position; protein3 is written with space-separated residues.
data = [
    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    (
        "protein2 with mask",
        "KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE",
    ),
    ("protein3", "K A <mask> I S Q"),
]
batch_labels, batch_strs, batch_tokens = batch_converter(data)
batch_tokens = batch_tokens.cuda()  # inputs must live on the CUDA device

try:
    with torch.no_grad():
        # esm2_t48_15B_UR50D has 48 layers, so layer 48 is the last one.
        # Update this if model_name changes. return_contacts=True also asks
        # the model for predicted contact maps.
        results = model(batch_tokens, repr_layers=[48], return_contacts=True)
    print(results)
finally:
    # Release the NCCL process group created by init_process_group() at the
    # top of the script, even if inference fails.
    torch.distributed.destroy_process_group()