Update chatNT.py
Browse files
chatNT.py
CHANGED
|
@@ -1763,6 +1763,9 @@ class TorchMultiModalPerceiverResampler(nn.Module):
|
|
| 1763 |
concat_input_1 = torch.cat([xf_1, x], dim=1)
|
| 1764 |
concat_input_2 = torch.cat([xf_2, x], dim=1)
|
| 1765 |
|
|
|
|
|
|
|
|
|
|
| 1766 |
output = layer(
|
| 1767 |
x=x,
|
| 1768 |
cross_attention_embeddings_1=concat_input_1,
|
|
@@ -1771,6 +1774,7 @@ class TorchMultiModalPerceiverResampler(nn.Module):
|
|
| 1771 |
attention_mask_2=attention_mask_2,
|
| 1772 |
)
|
| 1773 |
x = output["embeddings"]
|
|
|
|
| 1774 |
|
| 1775 |
return x, outs
|
| 1776 |
|
|
@@ -1784,6 +1788,7 @@ class TorchMultiModalPerceiverResampler(nn.Module):
|
|
| 1784 |
"""
|
| 1785 |
Computes the embeddings based on the input tokens.
|
| 1786 |
"""
|
|
|
|
| 1787 |
assert (
|
| 1788 |
input_embeddings_1.shape[-1] == self.config.embed_dim
|
| 1789 |
), "The input embedding dim should match the model embed dim"
|
|
@@ -1798,6 +1803,8 @@ class TorchMultiModalPerceiverResampler(nn.Module):
|
|
| 1798 |
outs: Dict[str, torch.Tensor] = {}
|
| 1799 |
x = latent_queries
|
| 1800 |
|
|
|
|
|
|
|
| 1801 |
x, outs = self.apply_attention_blocks(
|
| 1802 |
x=x,
|
| 1803 |
xf_1=input_embeddings_1,
|
|
@@ -1865,13 +1872,17 @@ class TorchMultiModalPerceiverResamplerProjection(nn.Module):
|
|
| 1865 |
english_token_ids, self.config.resampled_length, self.english_pad_token_id
|
| 1866 |
)
|
| 1867 |
|
| 1868 |
-
projected_embeddings = self.perceiver_resampler(
|
| 1869 |
input_embeddings_1=projected_bio_embeddings,
|
| 1870 |
attention_mask_1=bio_attention_mask,
|
| 1871 |
input_embeddings_2=english_embeddings,
|
| 1872 |
attention_mask_2=english_attention_mask,
|
| 1873 |
-
)
|
|
|
|
| 1874 |
|
|
|
|
|
|
|
|
|
|
| 1875 |
return projected_embeddings, outs
|
| 1876 |
|
| 1877 |
|
|
|
|
| 1763 |
concat_input_1 = torch.cat([xf_1, x], dim=1)
|
| 1764 |
concat_input_2 = torch.cat([xf_2, x], dim=1)
|
| 1765 |
|
| 1766 |
+
outs[f"concat_input_1_{layer_idx}"] = concat_input_1.clone()
|
| 1767 |
+
outs[f"concat_input_2_{layer_idx}"] = concat_input_2.clone()
|
| 1768 |
+
|
| 1769 |
output = layer(
|
| 1770 |
x=x,
|
| 1771 |
cross_attention_embeddings_1=concat_input_1,
|
|
|
|
| 1774 |
attention_mask_2=attention_mask_2,
|
| 1775 |
)
|
| 1776 |
x = output["embeddings"]
|
| 1777 |
+
outs[f"attention_embeddings_{layer_idx}"] = output["embeddings"].clone()
|
| 1778 |
|
| 1779 |
return x, outs
|
| 1780 |
|
|
|
|
| 1788 |
"""
|
| 1789 |
Computes the embeddings based on the input tokens.
|
| 1790 |
"""
|
| 1791 |
+
outs = {}
|
| 1792 |
assert (
|
| 1793 |
input_embeddings_1.shape[-1] == self.config.embed_dim
|
| 1794 |
), "The input embedding dim should match the model embed dim"
|
|
|
|
| 1803 |
outs: Dict[str, torch.Tensor] = {}
|
| 1804 |
x = latent_queries
|
| 1805 |
|
| 1806 |
+
outs["latent_queries"] = x.clone()
|
| 1807 |
+
|
| 1808 |
x, outs = self.apply_attention_blocks(
|
| 1809 |
x=x,
|
| 1810 |
xf_1=input_embeddings_1,
|
|
|
|
| 1872 |
english_token_ids, self.config.resampled_length, self.english_pad_token_id
|
| 1873 |
)
|
| 1874 |
|
| 1875 |
+
projected_embeddings, new_outs = self.perceiver_resampler(
|
| 1876 |
input_embeddings_1=projected_bio_embeddings,
|
| 1877 |
attention_mask_1=bio_attention_mask,
|
| 1878 |
input_embeddings_2=english_embeddings,
|
| 1879 |
attention_mask_2=english_attention_mask,
|
| 1880 |
+
)
|
| 1881 |
+
projected_embeddings = projected_embeddings["embeddings"]
|
| 1882 |
|
| 1883 |
+
for key in new_outs.keys():
|
| 1884 |
+
outs[f"{key}_perceiver"] = new_outs[key]
|
| 1885 |
+
|
| 1886 |
return projected_embeddings, outs
|
| 1887 |
|
| 1888 |
|