---

library_name: transformers
license: mit
tags:
  - biology
  - protein-language-model
  - esm3
  - multimodal-protein-model
---


# FastPLMs ESM3 Small

FastPLMs ESM3 Small is a Hugging Face-compatible implementation of Biohub's open ESM3 small model. It loads through `AutoModel`, runs sequence-only inference by default, and accepts ESM3's additional tensor tracks as ordinary keyword arguments.

This repository includes the Biohub ESM MIT license in `LICENSE`.

## Use With Transformers

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "Synthyra/ESM3_small",
    trust_remote_code=True,
    dtype=torch.bfloat16,
    device_map="cuda",
).eval()

sequences = ["MKTAYIAKQRQISFVKSHFSRQDILDLWIYHTQGYFP"]
tokens = model.tokenize_sequences(sequences, device=model.device)

with torch.inference_mode():
    output = model(**tokens)

print(output.logits.shape)  # sequence logits, (batch_size, seq_len, 64)
print(output.last_hidden_state.shape)  # ESM3 embeddings, (batch_size, seq_len, hidden_size)
print(output.function_logits.shape)  # function logits, (batch_size, seq_len, 8, 260)
```

You can also call sequence inference directly:

```python
output = model.forward_sequence(["MKTAYIAKQRQISFVKSHFSRQDILDLWIYHTQGYFP"])
```

## Experimental Test-Time Training

Test-time training (TTT) is disabled by default: no LoRA adapters are injected during normal
`forward_sequence`, `forward`, or `embed_dataset` calls. Calling `model.ttt(...)`
opts in to experimental masked-LM adaptation of the ESM3 sequence track through
local LoRA weights. It can improve predictions for some difficult proteins, but it adds
test-time compute and can degrade already confident predictions.

```python
metrics = model.ttt(
    seq="MKTAYIAKQRQISFVKSHFSRQDILDLWIYHTQGYFP",
    ttt_config={"steps": 3, "ags": 1, "batch_size": 1},
)
model.ttt_reset()
print(metrics["losses"])
```
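For readers unfamiliar with LoRA, the sketch below illustrates the general idea behind adapters that can be switched on and reset. It is a generic, pure-Python illustration and not the FastPLMs implementation: the `LoRALinear` class, its `reset` method, the initial values, and the hand-set `lora_b` update are all hypothetical stand-ins. The point it shows is that a low-rank update whose `B` matrix is zero leaves the frozen layer's output unchanged, so zeroing it again undoes an adaptation.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(matrix: Matrix, vector: Vector) -> Vector:
    """Multiply a (rows, cols) matrix by a length-cols vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


class LoRALinear:
    """Frozen weight W plus a trainable low-rank update scale * B @ A (hypothetical class)."""

    def __init__(self, weight: Matrix, rank: int, scale: float = 1.0) -> None:
        out_features, in_features = len(weight), len(weight[0])
        self.weight = weight  # frozen base weight, shape (out, in)
        # A gets a small non-zero init; B starts at zero, so B @ A == 0.
        self.lora_a = [[0.01] * in_features for _ in range(rank)]
        self.lora_b = [[0.0] * rank for _ in range(out_features)]
        self.scale = scale

    def forward(self, x: Vector) -> Vector:
        base = matvec(self.weight, x)
        delta = matvec(self.lora_b, matvec(self.lora_a, x))
        return [b + self.scale * d for b, d in zip(base, delta)]

    def reset(self) -> None:
        """Zero B again, which restores the frozen layer's behaviour."""
        self.lora_b = [[0.0] * len(row) for row in self.lora_b]


layer = LoRALinear(weight=[[1.0, 0.0], [0.0, 2.0]], rank=1)
x = [3.0, 4.0]

before = layer.forward(x)        # identical to the frozen layer
layer.lora_b = [[1.0], [-1.0]]   # stand-in for a few adaptation steps
adapted = layer.forward(x)       # output now differs from the frozen layer
layer.reset()
after = layer.forward(x)         # back to the frozen layer's output

print(before, adapted, after)
```

In a real model the update to `B` and `A` would come from gradient steps on a masked-LM loss rather than being set by hand.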

Independently of TTT, you can switch the attention backend between SDPA and Flex Attention after loading:

```python
model.attn_backend = "flex"
output = model.forward_sequence(["MKTAYIAKQRQISFVKSHFSRQDILDLWIYHTQGYFP"])
model.attn_backend = "sdpa"
```

## Embed Entire Datasets

To embed a list of protein sequences, call `embed_dataset`. Sequences are deduplicated, sorted by length, optionally truncated, and embedded in batches.

```python
embedding_dict = model.embed_dataset(
    sequences=[
        "MALWMRLLPLLALLALWGPDPAAA",
        "MKTAYIAKQRQISFVKSHFSRQDILDLWIYHTQGYFP",
    ],
    batch_size=2,
    max_len=512,
    full_embeddings=False,
    embed_dtype=torch.float32,
    pooling_types=["mean", "cls"],
    save=True,
    save_path="esm3_embeddings.pth",
)

# embedding_dict maps sequence strings to pooled tensors.
print(embedding_dict["MALWMRLLPLLALLALWGPDPAAA"].shape)
```
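The preprocessing described above (deduplicate, sort by length, optionally truncate, batch) can be pictured with the following pure-Python sketch. It is not the code `embed_dataset` runs: `prepare_batches` is a hypothetical helper, and the longest-first sort order and the point at which truncation happens are assumptions made for illustration.

```python
from typing import Iterator, List, Optional


def prepare_batches(
    sequences: List[str],
    batch_size: int,
    max_len: Optional[int] = None,
) -> Iterator[List[str]]:
    """Yield batches of unique sequences (hypothetical helper)."""
    # Deduplicate while keeping first-seen order.
    unique = list(dict.fromkeys(sequences))
    # Sort by length so sequences in a batch need little padding.
    # Longest-first is an assumption; the source only says "sorted by length".
    unique.sort(key=len, reverse=True)
    # Optionally truncate each sequence to at most max_len residues.
    if max_len is not None:
        unique = [seq[:max_len] for seq in unique]
    # Slice the sorted list into consecutive batches.
    for start in range(0, len(unique), batch_size):
        yield unique[start:start + batch_size]


batches = list(
    prepare_batches(
        [
            "MALWMRLLPLLALLALWGPDPAAA",
            "MKTAYIAKQRQISFVKSHFSRQDILDLWIYHTQGYFP",
            "MALWMRLLPLLALLALWGPDPAAA",  # duplicate, embedded only once
        ],
        batch_size=2,
        max_len=30,
    )
)
print(batches)
```

Grouping sequences of similar length keeps padding, and therefore wasted compute, low within each batch.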

Residue-wise embeddings are available by setting `full_embeddings=True`:

```python
residue_embeddings = model.embed_dataset(
    sequences=["MKTAYIAKQRQISFVKSHFSRQDILDLWIYHTQGYFP"],
    batch_size=1,
    max_len=512,
    full_embeddings=True,
    save=False,
)

print(residue_embeddings["MKTAYIAKQRQISFVKSHFSRQDILDLWIYHTQGYFP"].shape)
```

FASTA input is also supported:

```python
embedding_dict = model.embed_dataset(
    fasta_path="proteins.fasta",
    batch_size=4,
    pooling_types=["mean"],
    save_path="esm3_fasta_embeddings.pth",
)
```
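For reference, a FASTA file is plain text in which each record starts with a `>` header line followed by one or more sequence lines. The sketch below parses that layout with the standard library only. `read_fasta` and the record names `protein_1` and `protein_2` are hypothetical, and the parsing rules `embed_dataset` actually applies to `fasta_path` may differ.

```python
from typing import Dict, Iterable, Optional


def read_fasta(lines: Iterable[str]) -> Dict[str, str]:
    """Parse FASTA-formatted lines into {header: sequence} (hypothetical helper)."""
    records: Dict[str, str] = {}
    header: Optional[str] = None
    for raw in lines:
        line = raw.strip()
        if not line:
            continue  # skip blank lines
        if line.startswith(">"):
            header = line[1:]  # drop the leading ">"
            records[header] = ""
        elif header is not None:
            records[header] += line  # sequences may wrap over several lines
    return records


example = """>protein_1
MALWMRLLPLLALLALWGPDPAAA
>protein_2
MKTAYIAKQRQISFVKSHFS
RQDILDLWIYHTQGYFP
"""

records = read_fasta(example.splitlines())
print(records)
```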

`embed_dataset` currently supports pooled `mean`, `cls`, and `max` embeddings, plus unpooled residue embeddings. It supports `.pth` saves; SQLite streaming is not enabled for the ESM3 wrapper yet.
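As a rough picture of what the three pooling types compute, the sketch below reduces a `(seq_len, hidden_size)` matrix of residue embeddings to a single vector. It uses plain Python lists so it runs without PyTorch or a GPU. The `pool` helper is hypothetical, and treating `cls` as "the embedding at the first position" and excluding padded positions from `mean` and `max` are assumptions, not statements about the wrapper's internals.

```python
from typing import List

Matrix = List[List[float]]  # residue embeddings, shape (seq_len, hidden_size)


def pool(residues: Matrix, mask: List[int], kind: str) -> List[float]:
    """Reduce residue embeddings to one vector (hypothetical helper)."""
    if kind == "cls":
        # Assumption: "cls" means the embedding at the first position.
        return list(residues[0])
    # Keep only positions whose mask entry is 1 (real tokens, not padding).
    kept = [row for row, keep in zip(residues, mask) if keep]
    columns = list(zip(*kept))  # one tuple of values per hidden dimension
    if kind == "mean":
        return [sum(col) / len(col) for col in columns]
    if kind == "max":
        return [max(col) for col in columns]
    raise ValueError(f"unknown pooling type: {kind}")


residues = [
    [1.0, 2.0],
    [3.0, 0.0],
    [5.0, 4.0],
    [9.0, 9.0],  # padding position, masked out below
]
mask = [1, 1, 1, 0]

mean_vec = pool(residues, mask, "mean")
cls_vec = pool(residues, mask, "cls")
max_vec = pool(residues, mask, "max")
print(mean_vec, cls_vec, max_vec)
```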

## Multimodal Track Arguments

The default path is amino acid sequence inference. Additional ESM3 tracks can be supplied directly using the same tensor shapes as Biohub ESM3:

```python
tokens = model.tokenize_sequences(
    ["MKTAYIAKQRQISFVKSHFSRQDILDLWIYHTQGYFP"],
    device=model.device,
)

function_tokens = tokens["input_ids"].new_zeros((*tokens["input_ids"].shape, 8))

with torch.inference_mode():
    output = model(
        **tokens,
        function_tokens=function_tokens,
    )

print(output.sequence_logits.shape)
print(output.function_logits.shape)
```

Accepted track arguments include `sequence_tokens`, `structure_tokens`, `ss8_tokens`, `sasa_tokens`, `function_tokens`, `residue_annotation_tokens`, `average_plddt`, `per_res_plddt`, `structure_coords`, `chain_id`, and `sequence_id`. `input_ids` aliases `sequence_tokens`, and `attention_mask` is converted into `sequence_id` if no explicit `sequence_id` is provided.

## Loading Biohub Checkpoints Locally

You can build the FastPLMs wrapper from the Biohub checkpoint directly:

```python
from fastplms.esm3.modeling_esm3 import FastESM3Model

model = FastESM3Model.from_pretrained_esm("esm3-sm-open-v1", device="cuda")
```

This requires Hugging Face access to the gated `biohub/esm3-sm-open-v1` source repo.

## Biohub SDK Compatibility

The core forward path is self-contained. Higher-level Biohub SDK workflows are delegated lazily to the official `esm` submodule when available:

```python
# These methods use Biohub SDK dataclasses and generation configs.
encoded = model.encode(esm_protein)
decoded = model.decode(encoded)
generated = model.generate(esm_protein, generation_config)
```

Available delegated methods include `encode`, `decode`, `generate`, `batch_generate`, `logits`, and `forward_and_sample`.

## Source

- Biohub ESM repository: https://github.com/Biohub/esm
- Biohub ESM license: https://github.com/Biohub/esm/blob/main/LICENSE.md
- Paper: https://biohub.ai/papers/esm_protein.pdf
- Official model source: https://huggingface.co/biohub/esm3-sm-open-v1