"""Load official E1 model from the e1 package for comparison."""
import torch
import torch.nn as nn
class _OfficialE1ForwardWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(
self,
input_ids: torch.LongTensor,
within_seq_position_ids: torch.LongTensor,
global_position_ids: torch.LongTensor,
sequence_ids: torch.LongTensor,
attention_mask: torch.LongTensor,
**kwargs,
):
batch = {
"input_ids": input_ids,
"within_seq_position_ids": within_seq_position_ids,
"global_position_ids": global_position_ids,
"sequence_ids": sequence_ids,
}
outputs = self.model(**batch, output_hidden_states=True)
return outputs
def load_official_model(
    reference_repo_id: str,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
) -> tuple[nn.Module, object]:
    """Load the official E1 model from the e1 submodule.

    Args:
        reference_repo_id: e.g. "Profluent-Bio/E1-150m"
        device: target device
        dtype: target dtype (should be float32 for comparison)

    Returns (official_model, batch_preparer) where batch_preparer is an E1BatchPreparer.
    The official model is E1ForMaskedLM with standard HF forward interface.
    """
    # Function-scope imports: the E1 package is only needed when this
    # comparison path is actually exercised.
    from E1.modeling import E1ForMaskedLM
    from E1.batch_preparer import E1BatchPreparer

    official = E1ForMaskedLM.from_pretrained(
        reference_repo_id,
        tie_word_embeddings=False,
        device_map=device,
        dtype=dtype,
    ).eval()
    preparer = E1BatchPreparer()
    # Wrap so callers can use a flat positional forward signature.
    adapter = _OfficialE1ForwardWrapper(official).eval()
    return adapter, preparer
if __name__ == "__main__":
    # Smoke check: load the reference model on CPU and print its structure.
    wrapped_model, _preparer = load_official_model(
        "Profluent-Bio/E1-150m", torch.device("cpu")
    )
    print(wrapped_model)