Upload model
Browse files- modeling_t5mimo.py +64 -74
modeling_t5mimo.py
CHANGED
|
@@ -125,6 +125,69 @@ class T5LayerFF(nn.Module):
|
|
| 125 |
return hidden_states
|
| 126 |
|
| 127 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
class T5Attention(nn.Module):
|
| 129 |
def __init__(self, config: T5MIMOConfig, has_relative_attention_bias=False):
|
| 130 |
super().__init__()
|
|
@@ -1265,7 +1328,7 @@ class T5MIMOForConditionalGeneration(T5PreTrainedModel):
|
|
| 1265 |
self.decoder = T5Stack(decoder_config, self.shared)
|
| 1266 |
|
| 1267 |
|
| 1268 |
-
self.conv_block = MultivariateConvBlock(config
|
| 1269 |
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
| 1270 |
|
| 1271 |
# Initialize weights and apply final processing
|
|
@@ -1676,76 +1739,3 @@ class T5MIMOEncoderModel(T5PreTrainedModel):
|
|
| 1676 |
|
| 1677 |
|
| 1678 |
|
| 1679 |
-
|
| 1680 |
-
class MultivariateConvBlock(nn.Module):
    def __init__(self, num_seqs, model_dim, kernel_size=3, num_filters=64, stride=1, padding=1):
        """
        Multivariate convolutional block to capture cross-sequence interactions and temporal patterns.

        Args:
            num_seqs (int): Number of sequences (multivariate time series).
            model_dim (int): Dimension of each feature vector (typically 256).
                NOTE(review): unused in this implementation — kept in the
                signature only; confirm whether it was meant to size a layer.
            kernel_size (int): Size of the convolutional kernel. Default is 3.
            num_filters (int): Number of convolutional filters (output channels). Default is 64.
            stride (int): Stride of the convolutional kernel. Default is 1.
                NOTE(review): applied only to conv2, not conv1.
            padding (int): Padding for the convolutional kernel. Default is 1 (to preserve sequence length).
                NOTE(review): applied only to conv2; conv1 hardcodes padding=1,
                which preserves spatial dims only when kernel_size == 3.
        """
        super(MultivariateConvBlock, self).__init__()

        # First conv: num_seqs acts as the channel axis, so this mixes
        # information across sequences; square kernel over the last two dims.
        self.conv1 = nn.Conv2d(
            in_channels=num_seqs,
            out_channels=num_filters,
            kernel_size=kernel_size,
            stride=1,  # hardcoded: the `stride` argument is NOT applied here
            padding=1  # hardcoded: the `padding` argument is NOT applied here
        )

        # Batch normalization for stabilization and faster convergence.
        self.bn1 = nn.BatchNorm2d(num_filters)

        # Second conv: kernel (kernel_size, 1) slides along dim 2 only.
        # NOTE(review): after the permute in forward(), dim 2 is model_dim,
        # not time — confirm the intended axis ("temporal" may be a misnomer).
        self.conv2 = nn.Conv2d(
            in_channels=num_filters,
            out_channels=num_filters,
            kernel_size=(kernel_size, 1),
            stride=(stride, 1),
            padding=(padding, 0)
        )

        # Batch normalization after second convolution.
        self.bn2 = nn.BatchNorm2d(num_filters)

        # 1x1 convolution projects the filter channels back to num_seqs.
        self.conv3 = nn.Conv2d(
            in_channels=num_filters,
            out_channels=num_seqs,  # back to the original number of sequences (channels)
            kernel_size=(1, 1)
        )

    def forward(self, x):
        """
        Forward pass of the multivariate convolutional block.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, num_seqs, seq_len, model_dim].

        Returns:
            torch.Tensor: Output tensor of shape [batch_size, num_seqs, seq_len, model_dim]
            (shape is preserved only with the default kernel_size/stride/padding).
        """
        # [batch_size, num_seqs, seq_len, model_dim] -> [batch_size, num_seqs, model_dim, seq_len]
        x = x.permute(0, 1, 3, 2)

        # First convolution + batch norm + ReLU.
        x = nn.functional.relu(self.bn1(self.conv1(x)))
        # Second convolution + batch norm + ReLU.
        x = nn.functional.relu(self.bn2(self.conv2(x)))

        # Reduce channel dimension back to num_seqs.
        x = self.conv3(x)

        # Permute back to the original layout [batch_size, num_seqs, seq_len, model_dim].
        x = x.permute(0, 1, 3, 2)

        return x
|
|
|
|
| 125 |
return hidden_states
|
| 126 |
|
| 127 |
|
| 128 |
+
|
| 129 |
+
class MultivariateConvBlock(nn.Module):
    """Convolutional block that mixes information across the sequence channel.

    Treats the ``num_seqs`` axis of a multivariate series as 2D-conv channels,
    applies two conv+batchnorm+ReLU stages, then a 1x1 conv to project back to
    ``num_seqs`` channels, preserving the input shape at default settings.

    Args:
        config: model config; only ``config.num_seqs`` and
            ``config.num_filters`` are read here.
        kernel_size (int): Convolution kernel size. Default 3.
        stride (int): Stride for the second convolution (dim 2 only). Default 1.
        padding (int): Padding for the second convolution (dim 2 only). Default 1.
    """

    def __init__(self, config: "T5MIMOConfig", kernel_size=3, stride=1, padding=1):
        super().__init__()
        # First conv: num_seqs acts as the channel axis, so this mixes
        # information across sequences with a square kernel over the last
        # two dims.
        self.conv1 = nn.Conv2d(
            in_channels=config.num_seqs,
            out_channels=config.num_filters,
            kernel_size=kernel_size,
            stride=1,  # keep stride 1 here so spatial dims are preserved
            # "same" padding for odd kernels (was hardcoded to 1, which broke
            # shape preservation for any kernel_size != 3)
            padding=kernel_size // 2,
        )

        # Batch normalization for stabilization and faster convergence.
        self.bn1 = nn.BatchNorm2d(config.num_filters)

        # Second conv: kernel (kernel_size, 1) slides along dim 2 only.
        # NOTE(review): after the permute in forward(), dim 2 is model_dim,
        # not time — confirm the intended axis ("temporal" in the original
        # comments may be a misnomer).
        self.conv2 = nn.Conv2d(
            in_channels=config.num_filters,
            out_channels=config.num_filters,
            kernel_size=(kernel_size, 1),
            stride=(stride, 1),
            padding=(padding, 0),
        )

        # Batch normalization after the second convolution.
        self.bn2 = nn.BatchNorm2d(config.num_filters)

        # 1x1 convolution projects the filter channels back to num_seqs.
        self.conv3 = nn.Conv2d(
            in_channels=config.num_filters,
            out_channels=config.num_seqs,
            kernel_size=(1, 1),
        )

    def forward(self, x):
        """
        Forward pass of the multivariate convolutional block.

        Args:
            x (torch.Tensor): Input tensor of shape
                [batch_size, num_seqs, seq_len, model_dim].

        Returns:
            torch.Tensor: Output tensor of shape
            [batch_size, num_seqs, seq_len, model_dim] (preserved when
            conv2's padding matches its kernel, e.g. the defaults).
        """
        # [batch_size, num_seqs, seq_len, model_dim]
        #   -> [batch_size, num_seqs, model_dim, seq_len]
        x = x.permute(0, 1, 3, 2)

        # Two conv + batchnorm + ReLU stages.
        x = nn.functional.relu(self.bn1(self.conv1(x)))
        x = nn.functional.relu(self.bn2(self.conv2(x)))

        # Reduce channel dimension back to num_seqs.
        x = self.conv3(x)

        # Restore the original layout [batch_size, num_seqs, seq_len, model_dim].
        return x.permute(0, 1, 3, 2)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
|
| 191 |
class T5Attention(nn.Module):
|
| 192 |
def __init__(self, config: T5MIMOConfig, has_relative_attention_bias=False):
|
| 193 |
super().__init__()
|
|
|
|
| 1328 |
self.decoder = T5Stack(decoder_config, self.shared)
|
| 1329 |
|
| 1330 |
|
| 1331 |
+
self.conv_block = MultivariateConvBlock(config)
|
| 1332 |
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
| 1333 |
|
| 1334 |
# Initialize weights and apply final processing
|
|
|
|
| 1739 |
|
| 1740 |
|
| 1741 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|