Update chatNT.py
Browse files
chatNT.py
CHANGED
|
@@ -1661,14 +1661,24 @@ class TorchMultiModalPerceiverResamplerBlock(nn.Module):
|
|
| 1661 |
)
|
| 1662 |
|
| 1663 |
def mlp(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
|
| 1664 |
x = self.norm_mlp(x)
|
|
|
|
| 1665 |
if self.use_glu_in_ffn:
|
| 1666 |
x1, x2 = torch.chunk(self.fc1(x), 2, dim=-1)
|
| 1667 |
x = self.activation_fn(x1) * x2
|
| 1668 |
else:
|
| 1669 |
-
x = self.
|
| 1670 |
-
|
|
|
|
|
|
|
| 1671 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1672 |
def forward(
|
| 1673 |
self,
|
| 1674 |
x: torch.Tensor,
|
|
@@ -1703,7 +1713,8 @@ class TorchMultiModalPerceiverResamplerBlock(nn.Module):
|
|
| 1703 |
outs_news["ATTENTION_layer3_cross_attention_layer_2"] = attn_output.clone()
|
| 1704 |
x = res + attn_output
|
| 1705 |
|
| 1706 |
-
|
|
|
|
| 1707 |
outs_news["ATTENTION_after_mlp"] = x.clone()
|
| 1708 |
|
| 1709 |
output = {}
|
|
|
|
| 1661 |
)
|
| 1662 |
|
def mlp(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
    """Feed-forward sub-block with per-layer activation capture.

    Applies layer norm, then either a GLU-style FFN (when
    ``self.use_glu_in_ffn`` is set) or a standard fc1 -> activation -> fc2
    stack, recording intermediate activations for inspection.

    NOTE: despite being the FFN of a residual block, this returns a dict,
    not a tensor (annotation fixed accordingly). Callers must read the
    final tensor from ``result["x"]``.

    Args:
        x: input tensor; last dimension must match ``self.norm_mlp`` /
           ``self.fc1`` input size (shape otherwise unconstrained).

    Returns:
        Dict of cloned intermediate tensors:
          - "MLP_layer0_layer_norm": post-norm input (always present)
          - "MLP_layer1_fc1", "MLP_layer2_activation": only on the
            non-GLU path (the GLU branch computes its activation inline)
          - "MLP_layer3_fc2": output of fc2
          - "x": final output tensor (same value as "MLP_layer3_fc2")
    """
    outs = {}
    x = self.norm_mlp(x)
    outs["MLP_layer0_layer_norm"] = x.clone()
    if self.use_glu_in_ffn:
        # GLU variant: fc1 produces both halves; gate = act(x1) * x2.
        # No extra activation is applied afterwards.
        x1, x2 = torch.chunk(self.fc1(x), 2, dim=-1)
        x = self.activation_fn(x1) * x2
    else:
        x = self.fc1(x)
        outs["MLP_layer1_fc1"] = x.clone()
        x = self.activation_fn(x)
        outs["MLP_layer2_activation"] = x.clone()

    x = self.fc2(x)
    outs["MLP_layer3_fc2"] = x.clone()
    # "x" duplicates the fc2 capture so callers have a stable key for the
    # final output regardless of which branch ran.
    outs["x"] = x.clone()

    return outs
|
| 1682 |
def forward(
|
| 1683 |
self,
|
| 1684 |
x: torch.Tensor,
|
|
|
|
| 1713 |
outs_news["ATTENTION_layer3_cross_attention_layer_2"] = attn_output.clone()
|
| 1714 |
x = res + attn_output
|
| 1715 |
|
| 1716 |
+
mlp_output = self.mlp(x)
|
| 1717 |
+
x = x + mlp_output["x"]
|
| 1718 |
outs_news["ATTENTION_after_mlp"] = x.clone()
|
| 1719 |
|
| 1720 |
output = {}
|