Commit
·
071db43
1
Parent(s):
01a1f08
model
Browse files
model.py
CHANGED
|
@@ -14,7 +14,7 @@ import gc
|
|
| 14 |
from torch.optim.lr_scheduler import _LRScheduler
|
| 15 |
from transformers import EsmModel, PreTrainedModel
|
| 16 |
from configuration import MetaLATTEConfig
|
| 17 |
-
|
| 18 |
seed_everything(42)
|
| 19 |
|
| 20 |
class GELU(nn.Module):
|
|
@@ -226,9 +226,19 @@ class MultitaskProteinModel(PreTrainedModel):
|
|
| 226 |
config = MetaLATTEConfig.from_pretrained(pretrained_model_name_or_path)
|
| 227 |
|
| 228 |
model = cls(config)
|
| 229 |
-
state_dict = torch.load(f"{pretrained_model_name_or_path}/pytorch_model.bin", map_location=torch.device('cpu'))['state_dict']
|
| 230 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
return model
|
|
|
|
| 232 |
|
| 233 |
def forward(self, input_ids, attention_mask=None):
|
| 234 |
outputs = self.esm_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
|
|
|
|
| 14 |
from torch.optim.lr_scheduler import _LRScheduler
|
| 15 |
from transformers import EsmModel, PreTrainedModel
|
| 16 |
from configuration import MetaLATTEConfig
|
| 17 |
+
from urllib.parse import urljoin
|
| 18 |
seed_everything(42)
|
| 19 |
|
| 20 |
class GELU(nn.Module):
|
|
|
|
| 226 |
config = MetaLATTEConfig.from_pretrained(pretrained_model_name_or_path)
|
| 227 |
|
| 228 |
model = cls(config)
|
| 229 |
+
#state_dict = torch.load(f"{pretrained_model_name_or_path}/pytorch_model.bin", map_location=torch.device('cpu'))['state_dict']
|
| 230 |
+
try:
|
| 231 |
+
state_dict_url = urljoin(f"https://huggingface.co/{pretrained_model_name_or_path}/resolve/main/", "pytorch_model.bin")
|
| 232 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
| 233 |
+
state_dict_url,
|
| 234 |
+
map_location=torch.device('cpu')
|
| 235 |
+
)['state_dict']
|
| 236 |
+
model.load_state_dict(state_dict, strict=False)
|
| 237 |
+
except Exception as e:
|
| 238 |
+
raise RuntimeError(f"Error loading state_dict from {pretrained_model_name_or_path}/pytorch_model.bin: {e}")
|
| 239 |
+
|
| 240 |
return model
|
| 241 |
+
|
| 242 |
|
| 243 |
def forward(self, input_ids, attention_mask=None):
|
| 244 |
outputs = self.esm_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
|