Update chatNT.py
Browse files
chatNT.py
CHANGED
|
@@ -720,8 +720,11 @@ class TorchMultiOmicsModel(PreTrainedModel):
|
|
| 720 |
projected_bio_embeddings.append(proj)
|
| 721 |
for key in output.keys():
|
| 722 |
outs[f"{key}_{bio_seq_num}"] = output[key]
|
|
|
|
|
|
|
| 723 |
|
| 724 |
projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
|
|
|
|
| 725 |
|
| 726 |
# decode
|
| 727 |
logits = self.biobrain_decoder(
|
|
@@ -730,13 +733,6 @@ class TorchMultiOmicsModel(PreTrainedModel):
|
|
| 730 |
)
|
| 731 |
|
| 732 |
outs["logits"] = logits
|
| 733 |
-
outs["projected_bio_embeddings"] = projected_bio_embeddings
|
| 734 |
-
|
| 735 |
-
# Just for debugging
|
| 736 |
-
print("(debug) remember to remove bio_embeddings storage")
|
| 737 |
-
if projected_bio_embeddings is not None:
|
| 738 |
-
for i, embed in enumerate(bio_embeddings_list):
|
| 739 |
-
outs[f"bio_embeddings_list_{i}"] = embed
|
| 740 |
|
| 741 |
return outs
|
| 742 |
|
|
|
|
| 720 |
projected_bio_embeddings.append(proj)
|
| 721 |
for key in output.keys():
|
| 722 |
outs[f"{key}_{bio_seq_num}"] = output[key]
|
| 723 |
+
outs[f"bio_embeddings_list_{bio_seq_num}"] = proj
|
| 724 |
+
|
| 725 |
|
| 726 |
projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
|
| 727 |
+
outs["projected_bio_embeddings"] = projected_bio_embeddings.clone()
|
| 728 |
|
| 729 |
# decode
|
| 730 |
logits = self.biobrain_decoder(
|
|
|
|
| 733 |
)
|
| 734 |
|
| 735 |
outs["logits"] = logits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 736 |
|
| 737 |
return outs
|
| 738 |
|