Upload modeling_esm_plusplus.py with huggingface_hub
Browse files- modeling_esm_plusplus.py +97 -4
modeling_esm_plusplus.py
CHANGED
|
@@ -931,6 +931,7 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixi
|
|
| 931 |
self.mse = nn.MSELoss()
|
| 932 |
self.ce = nn.CrossEntropyLoss()
|
| 933 |
self.bce = nn.BCEWithLogitsLoss()
|
|
|
|
| 934 |
self.init_weights()
|
| 935 |
|
| 936 |
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
@@ -969,10 +970,7 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixi
|
|
| 969 |
output_hidden_states=output_hidden_states
|
| 970 |
)
|
| 971 |
x = output.last_hidden_state
|
| 972 |
-
|
| 973 |
-
mean_features = self.mean_pooling(x, attention_mask)
|
| 974 |
-
# we include mean pooling features to help with early convergence, the cost of this is basically zero
|
| 975 |
-
features = torch.cat([cls_features, mean_features], dim=-1)
|
| 976 |
logits = self.classifier(features)
|
| 977 |
loss = None
|
| 978 |
if labels is not None:
|
|
@@ -994,6 +992,7 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixi
|
|
| 994 |
loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
|
| 995 |
elif self.config.problem_type == "multi_label_classification":
|
| 996 |
loss = self.bce(logits, labels)
|
|
|
|
| 997 |
return ESMplusplusOutput(
|
| 998 |
loss=loss,
|
| 999 |
logits=logits,
|
|
@@ -1197,3 +1196,97 @@ class EsmSequenceTokenizer(PreTrainedTokenizerFast):
|
|
| 1197 |
@property
|
| 1198 |
def special_token_ids(self):
|
| 1199 |
return self.all_special_ids
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 931 |
self.mse = nn.MSELoss()
|
| 932 |
self.ce = nn.CrossEntropyLoss()
|
| 933 |
self.bce = nn.BCEWithLogitsLoss()
|
| 934 |
+
self.pooler = Pooler(['cls','mean'])
|
| 935 |
self.init_weights()
|
| 936 |
|
| 937 |
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
|
|
| 970 |
output_hidden_states=output_hidden_states
|
| 971 |
)
|
| 972 |
x = output.last_hidden_state
|
| 973 |
+
features = self.pooler(x, attention_mask)
|
|
|
|
|
|
|
|
|
|
| 974 |
logits = self.classifier(features)
|
| 975 |
loss = None
|
| 976 |
if labels is not None:
|
|
|
|
| 992 |
loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
|
| 993 |
elif self.config.problem_type == "multi_label_classification":
|
| 994 |
loss = self.bce(logits, labels)
|
| 995 |
+
|
| 996 |
return ESMplusplusOutput(
|
| 997 |
loss=loss,
|
| 998 |
logits=logits,
|
|
|
|
| 1196 |
@property
|
| 1197 |
def special_token_ids(self):
|
| 1198 |
return self.all_special_ids
|
| 1199 |
+
|
| 1200 |
+
|
| 1201 |
+
if __name__ == "__main__":
|
| 1202 |
+
# Set device to CPU for testing
|
| 1203 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 1204 |
+
print(f"Using device: {device}")
|
| 1205 |
+
|
| 1206 |
+
# Test tokenizer
|
| 1207 |
+
tokenizer = EsmSequenceTokenizer()
|
| 1208 |
+
sample_sequence = "MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG"
|
| 1209 |
+
encoding = tokenizer(sample_sequence, return_tensors="pt")
|
| 1210 |
+
print(f"Input sequence length: {len(sample_sequence)}")
|
| 1211 |
+
print(f"Tokenized sequence: {encoding['input_ids'].shape}")
|
| 1212 |
+
|
| 1213 |
+
# Prepare inputs
|
| 1214 |
+
input_ids = encoding['input_ids'].to(device)
|
| 1215 |
+
attention_mask = encoding['attention_mask'].to(device)
|
| 1216 |
+
|
| 1217 |
+
# Test base model with smaller config for quick testing
|
| 1218 |
+
print("\n=== Testing ESMplusplus Base Model ===")
|
| 1219 |
+
base_config = ESMplusplusConfig(
|
| 1220 |
+
hidden_size=384,
|
| 1221 |
+
num_attention_heads=6,
|
| 1222 |
+
num_hidden_layers=4
|
| 1223 |
+
)
|
| 1224 |
+
base_model = ESMplusplusModel(base_config).to(device)
|
| 1225 |
+
|
| 1226 |
+
with torch.no_grad():
|
| 1227 |
+
outputs = base_model(input_ids=input_ids, attention_mask=attention_mask)
|
| 1228 |
+
|
| 1229 |
+
print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
|
| 1230 |
+
|
| 1231 |
+
# Test embedding functionality
|
| 1232 |
+
print("\nTesting embedding functionality:")
|
| 1233 |
+
with torch.no_grad():
|
| 1234 |
+
embeddings = base_model._embed(input_ids, attention_mask)
|
| 1235 |
+
print(f"Embedding shape: {embeddings.shape}")
|
| 1236 |
+
|
| 1237 |
+
# Test masked language modeling
|
| 1238 |
+
print("\n=== Testing ESMplusplus For Masked LM ===")
|
| 1239 |
+
mlm_model = ESMplusplusForMaskedLM(base_config).to(device)
|
| 1240 |
+
|
| 1241 |
+
with torch.no_grad():
|
| 1242 |
+
outputs = mlm_model(input_ids=input_ids, attention_mask=attention_mask)
|
| 1243 |
+
|
| 1244 |
+
print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
|
| 1245 |
+
print(f"Logits shape: {outputs.logits.shape}")
|
| 1246 |
+
|
| 1247 |
+
# Test sequence classification model
|
| 1248 |
+
print("\n=== Testing Sequence Classification Model ===")
|
| 1249 |
+
classification_model = ESMplusplusForSequenceClassification(base_config).to(device)
|
| 1250 |
+
|
| 1251 |
+
with torch.no_grad():
|
| 1252 |
+
outputs = classification_model(input_ids=input_ids, attention_mask=attention_mask)
|
| 1253 |
+
|
| 1254 |
+
print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
|
| 1255 |
+
print(f"Logits shape: {outputs.logits.shape}")
|
| 1256 |
+
|
| 1257 |
+
# Test token classification model
|
| 1258 |
+
print("\n=== Testing Token Classification Model ===")
|
| 1259 |
+
token_model = ESMplusplusForTokenClassification(base_config).to(device)
|
| 1260 |
+
|
| 1261 |
+
with torch.no_grad():
|
| 1262 |
+
outputs = token_model(input_ids=input_ids, attention_mask=attention_mask)
|
| 1263 |
+
|
| 1264 |
+
print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
|
| 1265 |
+
print(f"Logits shape: {outputs.logits.shape}")
|
| 1266 |
+
|
| 1267 |
+
# Test embedding dataset functionality with a mini dataset
|
| 1268 |
+
print("\n=== Testing Embed Dataset Functionality ===")
|
| 1269 |
+
mini_dataset = [sample_sequence, sample_sequence[:50], sample_sequence[:30]]
|
| 1270 |
+
print(f"Creating embeddings for {len(mini_dataset)} sequences")
|
| 1271 |
+
|
| 1272 |
+
# Only run this if save path doesn't exist to avoid overwriting
|
| 1273 |
+
if not os.path.exists("test_embeddings.pth"):
|
| 1274 |
+
embeddings = mlm_model.embed_dataset(
|
| 1275 |
+
sequences=mini_dataset,
|
| 1276 |
+
tokenizer=tokenizer,
|
| 1277 |
+
batch_size=2,
|
| 1278 |
+
max_len=100,
|
| 1279 |
+
full_embeddings=False,
|
| 1280 |
+
pooling_types=['mean'],
|
| 1281 |
+
save_path="test_embeddings.pth"
|
| 1282 |
+
)
|
| 1283 |
+
if embeddings:
|
| 1284 |
+
print(f"Embedding dictionary size: {len(embeddings)}")
|
| 1285 |
+
for seq, emb in embeddings.items():
|
| 1286 |
+
print(f"Sequence length: {len(seq)}, Embedding shape: {emb.shape}")
|
| 1287 |
+
break
|
| 1288 |
+
else:
|
| 1289 |
+
print("Skipping embedding test as test_embeddings.pth already exists")
|
| 1290 |
+
|
| 1291 |
+
print("\nAll tests completed successfully!")
|
| 1292 |
+
|