Upload modeling_esm_plusplus.py with huggingface_hub
Browse files- modeling_esm_plusplus.py +2 -97
modeling_esm_plusplus.py
CHANGED
|
@@ -33,14 +33,11 @@ except ImportError:
|
|
| 33 |
flex_attention = None
|
| 34 |
|
| 35 |
try:
|
| 36 |
-
# when used from AutoModel, these are in the same directory
|
| 37 |
from .embedding_mixin import EmbeddingMixin, Pooler
|
| 38 |
-
except:
|
| 39 |
try:
|
| 40 |
-
# whem importing as a submodule, embedding mixin is in the FastPLMs directory
|
| 41 |
from ..embedding_mixin import EmbeddingMixin, Pooler
|
| 42 |
-
except:
|
| 43 |
-
# when running from our repo, these are in the base directory
|
| 44 |
from embedding_mixin import EmbeddingMixin, Pooler
|
| 45 |
|
| 46 |
|
|
@@ -1142,95 +1139,3 @@ class EsmSequenceTokenizer(PreTrainedTokenizerFast):
|
|
| 1142 |
return self.all_special_ids
|
| 1143 |
|
| 1144 |
|
| 1145 |
-
if __name__ == "__main__":
|
| 1146 |
-
# Set device to CPU for testing
|
| 1147 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 1148 |
-
print(f"Using device: {device}")
|
| 1149 |
-
|
| 1150 |
-
# Test tokenizer
|
| 1151 |
-
tokenizer = EsmSequenceTokenizer()
|
| 1152 |
-
sample_sequence = "MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG"
|
| 1153 |
-
encoding = tokenizer(sample_sequence, return_tensors="pt")
|
| 1154 |
-
print(f"Input sequence length: {len(sample_sequence)}")
|
| 1155 |
-
print(f"Tokenized sequence: {encoding['input_ids'].shape}")
|
| 1156 |
-
|
| 1157 |
-
# Prepare inputs
|
| 1158 |
-
input_ids = encoding['input_ids'].to(device)
|
| 1159 |
-
attention_mask = encoding['attention_mask'].to(device)
|
| 1160 |
-
|
| 1161 |
-
# Test base model with smaller config for quick testing
|
| 1162 |
-
print("\n=== Testing ESMplusplus Base Model ===")
|
| 1163 |
-
base_config = ESMplusplusConfig(
|
| 1164 |
-
hidden_size=384,
|
| 1165 |
-
num_attention_heads=6,
|
| 1166 |
-
num_hidden_layers=4
|
| 1167 |
-
)
|
| 1168 |
-
base_model = ESMplusplusModel(base_config).to(device)
|
| 1169 |
-
|
| 1170 |
-
with torch.no_grad():
|
| 1171 |
-
outputs = base_model(input_ids=input_ids, attention_mask=attention_mask)
|
| 1172 |
-
|
| 1173 |
-
print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
|
| 1174 |
-
|
| 1175 |
-
# Test embedding functionality
|
| 1176 |
-
print("\nTesting embedding functionality:")
|
| 1177 |
-
with torch.no_grad():
|
| 1178 |
-
embeddings = base_model._embed(input_ids, attention_mask)
|
| 1179 |
-
print(f"Embedding shape: {embeddings.shape}")
|
| 1180 |
-
|
| 1181 |
-
# Test masked language modeling
|
| 1182 |
-
print("\n=== Testing ESMplusplus For Masked LM ===")
|
| 1183 |
-
mlm_model = ESMplusplusForMaskedLM(base_config).to(device)
|
| 1184 |
-
|
| 1185 |
-
with torch.no_grad():
|
| 1186 |
-
outputs = mlm_model(input_ids=input_ids, attention_mask=attention_mask)
|
| 1187 |
-
|
| 1188 |
-
print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
|
| 1189 |
-
print(f"Logits shape: {outputs.logits.shape}")
|
| 1190 |
-
|
| 1191 |
-
# Test sequence classification model
|
| 1192 |
-
print("\n=== Testing Sequence Classification Model ===")
|
| 1193 |
-
classification_model = ESMplusplusForSequenceClassification(base_config).to(device)
|
| 1194 |
-
|
| 1195 |
-
with torch.no_grad():
|
| 1196 |
-
outputs = classification_model(input_ids=input_ids, attention_mask=attention_mask)
|
| 1197 |
-
|
| 1198 |
-
print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
|
| 1199 |
-
print(f"Logits shape: {outputs.logits.shape}")
|
| 1200 |
-
|
| 1201 |
-
# Test token classification model
|
| 1202 |
-
print("\n=== Testing Token Classification Model ===")
|
| 1203 |
-
token_model = ESMplusplusForTokenClassification(base_config).to(device)
|
| 1204 |
-
|
| 1205 |
-
with torch.no_grad():
|
| 1206 |
-
outputs = token_model(input_ids=input_ids, attention_mask=attention_mask)
|
| 1207 |
-
|
| 1208 |
-
print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
|
| 1209 |
-
print(f"Logits shape: {outputs.logits.shape}")
|
| 1210 |
-
|
| 1211 |
-
# Test embedding dataset functionality with a mini dataset
|
| 1212 |
-
print("\n=== Testing Embed Dataset Functionality ===")
|
| 1213 |
-
mini_dataset = [sample_sequence, sample_sequence[:50], sample_sequence[:30]]
|
| 1214 |
-
print(f"Creating embeddings for {len(mini_dataset)} sequences")
|
| 1215 |
-
|
| 1216 |
-
# Only run this if save path doesn't exist to avoid overwriting
|
| 1217 |
-
if not os.path.exists("test_embeddings.pth"):
|
| 1218 |
-
embeddings = mlm_model.embed_dataset(
|
| 1219 |
-
sequences=mini_dataset,
|
| 1220 |
-
tokenizer=tokenizer,
|
| 1221 |
-
batch_size=2,
|
| 1222 |
-
max_len=100,
|
| 1223 |
-
full_embeddings=False,
|
| 1224 |
-
pooling_types=['mean'],
|
| 1225 |
-
save_path="test_embeddings.pth"
|
| 1226 |
-
)
|
| 1227 |
-
if embeddings:
|
| 1228 |
-
print(f"Embedding dictionary size: {len(embeddings)}")
|
| 1229 |
-
for seq, emb in embeddings.items():
|
| 1230 |
-
print(f"Sequence length: {len(seq)}, Embedding shape: {emb.shape}")
|
| 1231 |
-
break
|
| 1232 |
-
else:
|
| 1233 |
-
print("Skipping embedding test as test_embeddings.pth already exists")
|
| 1234 |
-
|
| 1235 |
-
print("\nAll tests completed successfully!")
|
| 1236 |
-
|
|
|
|
| 33 |
flex_attention = None
|
| 34 |
|
| 35 |
try:
|
|
|
|
| 36 |
from .embedding_mixin import EmbeddingMixin, Pooler
|
| 37 |
+
except ImportError:
|
| 38 |
try:
|
|
|
|
| 39 |
from ..embedding_mixin import EmbeddingMixin, Pooler
|
| 40 |
+
except ImportError:
|
|
|
|
| 41 |
from embedding_mixin import EmbeddingMixin, Pooler
|
| 42 |
|
| 43 |
|
|
|
|
| 1139 |
return self.all_special_ids
|
| 1140 |
|
| 1141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|