"""Load official E1 model from the e1 package for comparison.""" import torch import torch.nn as nn class _OfficialE1ForwardWrapper(nn.Module): def __init__(self, model): super().__init__() self.model = model def forward( self, input_ids: torch.LongTensor, within_seq_position_ids: torch.LongTensor, global_position_ids: torch.LongTensor, sequence_ids: torch.LongTensor, attention_mask: torch.LongTensor, **kwargs, ): batch = { "input_ids": input_ids, "within_seq_position_ids": within_seq_position_ids, "global_position_ids": global_position_ids, "sequence_ids": sequence_ids, } outputs = self.model(**batch, output_hidden_states=True) return outputs def load_official_model( reference_repo_id: str, device: torch.device, dtype: torch.dtype = torch.float32, ) -> tuple[nn.Module, object]: """Load the official E1 model from the e1 submodule. Args: reference_repo_id: e.g. "Profluent-Bio/E1-150m" device: target device dtype: target dtype (should be float32 for comparison) Returns (official_model, batch_preparer) where batch_preparer is an E1BatchPreparer. The official model is E1ForMaskedLM with standard HF forward interface. """ from E1.modeling import E1ForMaskedLM from E1.batch_preparer import E1BatchPreparer model = E1ForMaskedLM.from_pretrained( reference_repo_id, tie_word_embeddings=False, device_map=device, dtype=dtype, ).eval() batch_preparer = E1BatchPreparer() wrapped = _OfficialE1ForwardWrapper(model).eval() return wrapped, batch_preparer if __name__ == "__main__": model, batch_preparer = load_official_model("Profluent-Bio/E1-150m", torch.device("cpu")) print(model)