Update chatNT.py
Browse files
chatNT.py
CHANGED
|
@@ -1788,7 +1788,7 @@ class TorchMultiModalPerceiverResampler(nn.Module):
|
|
| 1788 |
"""
|
| 1789 |
Computes the embeddings based on the input tokens.
|
| 1790 |
"""
|
| 1791 |
-
|
| 1792 |
assert (
|
| 1793 |
input_embeddings_1.shape[-1] == self.config.embed_dim
|
| 1794 |
), "The input embedding dim should match the model embed dim"
|
|
@@ -1803,7 +1803,7 @@ class TorchMultiModalPerceiverResampler(nn.Module):
|
|
| 1803 |
outs: Dict[str, torch.Tensor] = {}
|
| 1804 |
x = latent_queries
|
| 1805 |
|
| 1806 |
-
|
| 1807 |
|
| 1808 |
x, outs = self.apply_attention_blocks(
|
| 1809 |
x=x,
|
|
@@ -1814,9 +1814,12 @@ class TorchMultiModalPerceiverResampler(nn.Module):
|
|
| 1814 |
attention_mask_2=attention_mask_2,
|
| 1815 |
)
|
| 1816 |
|
|
|
|
|
|
|
|
|
|
| 1817 |
outs["embeddings"] = x
|
| 1818 |
|
| 1819 |
-
return outs
|
| 1820 |
|
| 1821 |
|
| 1822 |
class TorchMultiModalPerceiverResamplerProjection(nn.Module):
|
|
|
|
| 1788 |
"""
|
| 1789 |
Computes the embeddings based on the input tokens.
|
| 1790 |
"""
|
| 1791 |
+
new_outs = {}
|
| 1792 |
assert (
|
| 1793 |
input_embeddings_1.shape[-1] == self.config.embed_dim
|
| 1794 |
), "The input embedding dim should match the model embed dim"
|
|
|
|
| 1803 |
outs: Dict[str, torch.Tensor] = {}
|
| 1804 |
x = latent_queries
|
| 1805 |
|
| 1806 |
+
new_outs["latent_queries"] = x.clone()
|
| 1807 |
|
| 1808 |
x, outs = self.apply_attention_blocks(
|
| 1809 |
x=x,
|
|
|
|
| 1814 |
attention_mask_2=attention_mask_2,
|
| 1815 |
)
|
| 1816 |
|
| 1817 |
+
for key in outs.keys():
|
| 1818 |
+
new_outs[key] = outs[key].copy()
|
| 1819 |
+
|
| 1820 |
outs["embeddings"] = x
|
| 1821 |
|
| 1822 |
+
return outs, new_outs
|
| 1823 |
|
| 1824 |
|
| 1825 |
class TorchMultiModalPerceiverResamplerProjection(nn.Module):
|