Update chatNT.py
Browse files
chatNT.py
CHANGED
|
@@ -694,6 +694,7 @@ class TorchMultiOmicsModel(PreTrainedModel):
|
|
| 694 |
vocab_size - 1
|
| 695 |
)
|
| 696 |
|
|
|
|
| 697 |
if bio_token_ids is None:
|
| 698 |
projected_bio_embeddings = None
|
| 699 |
else:
|
|
@@ -708,14 +709,18 @@ class TorchMultiOmicsModel(PreTrainedModel):
|
|
| 708 |
|
| 709 |
|
| 710 |
# Project these embeddings
|
| 711 |
-
projected_bio_embeddings = [
|
| 712 |
-
|
|
|
|
|
|
|
| 713 |
bio_token_ids=bio_token_ids[:, bio_seq_num],
|
| 714 |
bio_embeddings=bio_embeddings,
|
| 715 |
english_token_ids=projection_english_tokens_ids,
|
| 716 |
)
|
| 717 |
-
|
| 718 |
-
|
|
|
|
|
|
|
| 719 |
projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
|
| 720 |
|
| 721 |
# decode
|
|
@@ -724,7 +729,8 @@ class TorchMultiOmicsModel(PreTrainedModel):
|
|
| 724 |
projected_bio_embeddings=projected_bio_embeddings,
|
| 725 |
)
|
| 726 |
|
| 727 |
-
outs
|
|
|
|
| 728 |
|
| 729 |
# Just for debugging
|
| 730 |
print("(debug) remember to remove bio_embeddings storage")
|
|
@@ -1848,8 +1854,12 @@ class TorchMultiModalPerceiverResamplerProjection(nn.Module):
|
|
| 1848 |
english_token_ids (torch.Tensor):
|
| 1849 |
Shape (batch_size, num_english_tokens)
|
| 1850 |
"""
|
|
|
|
| 1851 |
projected_bio_embeddings = self.bio_projection(bio_embeddings)
|
|
|
|
|
|
|
| 1852 |
english_embeddings = self.token_embedding(english_token_ids)
|
|
|
|
| 1853 |
|
| 1854 |
bio_attention_mask = build_perceiver_padding_attention_mask(
|
| 1855 |
bio_token_ids, self.config.resampled_length, self.bio_pad_token_id
|
|
@@ -1865,7 +1875,7 @@ class TorchMultiModalPerceiverResamplerProjection(nn.Module):
|
|
| 1865 |
attention_mask_2=english_attention_mask,
|
| 1866 |
)["embeddings"]
|
| 1867 |
|
| 1868 |
-
return projected_embeddings
|
| 1869 |
|
| 1870 |
|
| 1871 |
def build_perceiver_padding_attention_mask(
|
|
|
|
| 694 |
vocab_size - 1
|
| 695 |
)
|
| 696 |
|
| 697 |
+
outs = {}
|
| 698 |
if bio_token_ids is None:
|
| 699 |
projected_bio_embeddings = None
|
| 700 |
else:
|
|
|
|
| 709 |
|
| 710 |
|
| 711 |
# Project these embeddings
|
| 712 |
+
projected_bio_embeddings = []
|
| 713 |
+
print("(debug) remember to remove loop for projected")
|
| 714 |
+
for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list):
|
| 715 |
+
proj, output = self.projection_model(
|
| 716 |
bio_token_ids=bio_token_ids[:, bio_seq_num],
|
| 717 |
bio_embeddings=bio_embeddings,
|
| 718 |
english_token_ids=projection_english_tokens_ids,
|
| 719 |
)
|
| 720 |
+
projected_bio_embeddings.append(proj)
|
| 721 |
+
for key in output.keys():
|
| 722 |
+
outs[f"{key}_{bio_seq_num}"] = output[key]
|
| 723 |
+
|
| 724 |
projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
|
| 725 |
|
| 726 |
# decode
|
|
|
|
| 729 |
projected_bio_embeddings=projected_bio_embeddings,
|
| 730 |
)
|
| 731 |
|
| 732 |
+
outs["logits"] = logits
|
| 733 |
+
outs["projected_bio_embeddings"] = projected_bio_embeddings
|
| 734 |
|
| 735 |
# Just for debugging
|
| 736 |
print("(debug) remember to remove bio_embeddings storage")
|
|
|
|
| 1854 |
english_token_ids (torch.Tensor):
|
| 1855 |
Shape (batch_size, num_english_tokens)
|
| 1856 |
"""
|
| 1857 |
+
outs = {}
|
| 1858 |
projected_bio_embeddings = self.bio_projection(bio_embeddings)
|
| 1859 |
+
print("(debug) remember to remove this projected_bio_embeddings out, and 'outs' output")
|
| 1860 |
+
outs['projected_bio_embeddings'] = projected_bio_embeddings
|
| 1861 |
english_embeddings = self.token_embedding(english_token_ids)
|
| 1862 |
+
outs['english_embeddings'] = english_embeddings
|
| 1863 |
|
| 1864 |
bio_attention_mask = build_perceiver_padding_attention_mask(
|
| 1865 |
bio_token_ids, self.config.resampled_length, self.bio_pad_token_id
|
|
|
|
| 1875 |
attention_mask_2=english_attention_mask,
|
| 1876 |
)["embeddings"]
|
| 1877 |
|
| 1878 |
+
return projected_embeddings, outs
|
| 1879 |
|
| 1880 |
|
| 1881 |
def build_perceiver_padding_attention_mask(
|