lhallee commited on
Commit
85e3cbd
·
verified ·
1 Parent(s): d1cc4f8

Upload modeling_esm_plusplus.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_esm_plusplus.py +2 -97
modeling_esm_plusplus.py CHANGED
@@ -33,14 +33,11 @@ except ImportError:
33
  flex_attention = None
34
 
35
  try:
36
- # when used from AutoModel, these are in the same directory
37
  from .embedding_mixin import EmbeddingMixin, Pooler
38
- except:
39
  try:
40
- # when importing as a submodule, embedding mixin is in the FastPLMs directory
41
  from ..embedding_mixin import EmbeddingMixin, Pooler
42
- except:
43
- # when running from our repo, these are in the base directory
44
  from embedding_mixin import EmbeddingMixin, Pooler
45
 
46
 
@@ -1142,95 +1139,3 @@ class EsmSequenceTokenizer(PreTrainedTokenizerFast):
1142
  return self.all_special_ids
1143
 
1144
 
1145
- if __name__ == "__main__":
1146
- # Set device to CPU for testing
1147
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1148
- print(f"Using device: {device}")
1149
-
1150
- # Test tokenizer
1151
- tokenizer = EsmSequenceTokenizer()
1152
- sample_sequence = "MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG"
1153
- encoding = tokenizer(sample_sequence, return_tensors="pt")
1154
- print(f"Input sequence length: {len(sample_sequence)}")
1155
- print(f"Tokenized sequence: {encoding['input_ids'].shape}")
1156
-
1157
- # Prepare inputs
1158
- input_ids = encoding['input_ids'].to(device)
1159
- attention_mask = encoding['attention_mask'].to(device)
1160
-
1161
- # Test base model with smaller config for quick testing
1162
- print("\n=== Testing ESMplusplus Base Model ===")
1163
- base_config = ESMplusplusConfig(
1164
- hidden_size=384,
1165
- num_attention_heads=6,
1166
- num_hidden_layers=4
1167
- )
1168
- base_model = ESMplusplusModel(base_config).to(device)
1169
-
1170
- with torch.no_grad():
1171
- outputs = base_model(input_ids=input_ids, attention_mask=attention_mask)
1172
-
1173
- print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
1174
-
1175
- # Test embedding functionality
1176
- print("\nTesting embedding functionality:")
1177
- with torch.no_grad():
1178
- embeddings = base_model._embed(input_ids, attention_mask)
1179
- print(f"Embedding shape: {embeddings.shape}")
1180
-
1181
- # Test masked language modeling
1182
- print("\n=== Testing ESMplusplus For Masked LM ===")
1183
- mlm_model = ESMplusplusForMaskedLM(base_config).to(device)
1184
-
1185
- with torch.no_grad():
1186
- outputs = mlm_model(input_ids=input_ids, attention_mask=attention_mask)
1187
-
1188
- print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
1189
- print(f"Logits shape: {outputs.logits.shape}")
1190
-
1191
- # Test sequence classification model
1192
- print("\n=== Testing Sequence Classification Model ===")
1193
- classification_model = ESMplusplusForSequenceClassification(base_config).to(device)
1194
-
1195
- with torch.no_grad():
1196
- outputs = classification_model(input_ids=input_ids, attention_mask=attention_mask)
1197
-
1198
- print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
1199
- print(f"Logits shape: {outputs.logits.shape}")
1200
-
1201
- # Test token classification model
1202
- print("\n=== Testing Token Classification Model ===")
1203
- token_model = ESMplusplusForTokenClassification(base_config).to(device)
1204
-
1205
- with torch.no_grad():
1206
- outputs = token_model(input_ids=input_ids, attention_mask=attention_mask)
1207
-
1208
- print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
1209
- print(f"Logits shape: {outputs.logits.shape}")
1210
-
1211
- # Test embedding dataset functionality with a mini dataset
1212
- print("\n=== Testing Embed Dataset Functionality ===")
1213
- mini_dataset = [sample_sequence, sample_sequence[:50], sample_sequence[:30]]
1214
- print(f"Creating embeddings for {len(mini_dataset)} sequences")
1215
-
1216
- # Only run this if save path doesn't exist to avoid overwriting
1217
- if not os.path.exists("test_embeddings.pth"):
1218
- embeddings = mlm_model.embed_dataset(
1219
- sequences=mini_dataset,
1220
- tokenizer=tokenizer,
1221
- batch_size=2,
1222
- max_len=100,
1223
- full_embeddings=False,
1224
- pooling_types=['mean'],
1225
- save_path="test_embeddings.pth"
1226
- )
1227
- if embeddings:
1228
- print(f"Embedding dictionary size: {len(embeddings)}")
1229
- for seq, emb in embeddings.items():
1230
- print(f"Sequence length: {len(seq)}, Embedding shape: {emb.shape}")
1231
- break
1232
- else:
1233
- print("Skipping embedding test as test_embeddings.pth already exists")
1234
-
1235
- print("\nAll tests completed successfully!")
1236
-
 
33
  flex_attention = None
34
 
35
  try:
 
36
  from .embedding_mixin import EmbeddingMixin, Pooler
37
+ except ImportError:
38
  try:
 
39
  from ..embedding_mixin import EmbeddingMixin, Pooler
40
+ except ImportError:
 
41
  from embedding_mixin import EmbeddingMixin, Pooler
42
 
43
 
 
1139
  return self.all_special_ids
1140
 
1141