nikraf's picture
Upload folder using huggingface_hub
714cf46 verified
"""Load official E1 model from the e1 package for comparison."""
import torch
import torch.nn as nn
class _OfficialE1ForwardWrapper(nn.Module):
    """Adapter that calls the wrapped E1 model with hidden states enabled.

    ``forward`` takes ``attention_mask`` and arbitrary extra keyword
    arguments, but neither is passed on: only the token ids and the three
    position/sequence id tensors reach the wrapped model.
    """

    def __init__(self, model):
        """Keep a reference to the model that ``forward`` delegates to."""
        super().__init__()
        self.model = model

    def forward(
        self,
        input_ids: torch.LongTensor,
        within_seq_position_ids: torch.LongTensor,
        global_position_ids: torch.LongTensor,
        sequence_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
        **kwargs,
    ):
        """Run the wrapped model and return its output unchanged.

        Args:
            input_ids: Forwarded as ``input_ids``.
            within_seq_position_ids: Forwarded as ``within_seq_position_ids``.
            global_position_ids: Forwarded as ``global_position_ids``.
            sequence_ids: Forwarded as ``sequence_ids``.
            attention_mask: Accepted but not forwarded to the wrapped model.
            **kwargs: Accepted but not forwarded to the wrapped model.

        Returns:
            Whatever the wrapped model returns when called with the four
            tensors above and ``output_hidden_states=True``.
        """
        return self.model(
            input_ids=input_ids,
            within_seq_position_ids=within_seq_position_ids,
            global_position_ids=global_position_ids,
            sequence_ids=sequence_ids,
            output_hidden_states=True,
        )
def load_official_model(
    reference_repo_id: str,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
) -> tuple[nn.Module, object]:
    """Build the official E1 reference model and its batch preparer.

    Args:
        reference_repo_id: Checkpoint identifier handed to
            ``E1ForMaskedLM.from_pretrained``, e.g. "Profluent-Bio/E1-150m".
        device: Target device; passed to ``from_pretrained`` as ``device_map``.
        dtype: Target dtype (should be float32 for comparison).

    Returns:
        A ``(wrapped_model, batch_preparer)`` pair. ``wrapped_model`` is the
        loaded ``E1ForMaskedLM`` inside ``_OfficialE1ForwardWrapper``, switched
        to eval mode; ``batch_preparer`` is a new ``E1BatchPreparer``.
    """
    # Imported inside the function, so importing this module does not by
    # itself require the E1 package.
    from E1.modeling import E1ForMaskedLM
    from E1.batch_preparer import E1BatchPreparer

    pretrained_options = {
        "tie_word_embeddings": False,
        "device_map": device,
        "dtype": dtype,
    }
    official = E1ForMaskedLM.from_pretrained(reference_repo_id, **pretrained_options)
    official.eval()

    preparer = E1BatchPreparer()

    wrapper = _OfficialE1ForwardWrapper(official)
    wrapper.eval()
    return wrapper, preparer
if __name__ == "__main__":
    # Manual check when run as a script: load the 150m checkpoint on CPU with
    # the default dtype and print the wrapped module.
    model, batch_preparer = load_official_model(
        reference_repo_id="Profluent-Bio/E1-150m",
        device=torch.device("cpu"),
    )
    print(model)