| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import PreTrainedModel |
| from .configuration_hcae import HCAEConfig |
|
|
class SwiGLU(nn.Module):
    """Gated feed-forward layer computing ``w3(silu(w1(x)) * w2(x))``.

    Args:
        in_features: size of the last dimension of the input.
        hidden_features: width of the gated hidden projection.
        out_features: output size; falls back to ``in_features`` when falsy.
    """

    def __init__(self, in_features, hidden_features, out_features=None):
        super().__init__()
        # Attribute names are state_dict keys (must stay w1/w2/w3); creation
        # order is kept so seeded initialisation draws the same weights.
        self.w1 = nn.Linear(in_features, hidden_features)
        self.w2 = nn.Linear(in_features, hidden_features)
        out_dim = out_features if out_features else in_features
        self.w3 = nn.Linear(hidden_features, out_dim)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)
|
|
class LayerScale(nn.Module):
    """Learnable per-channel scale applied to the last dimension of the input.

    Args:
        dim: number of channels (size of the last input dimension).
        init_values: initial value of every entry of ``gamma``.
    """

    def __init__(self, dim, init_values=1e-5):
        super().__init__()
        initial = init_values * torch.ones(dim)
        # ``gamma`` is a state_dict key; keep the name.
        self.gamma = nn.Parameter(initial)

    def forward(self, x):
        scale = self.gamma
        return x * scale
|
|
class ConvBlockV2(nn.Module):
    """Residual token-mixing block: depthwise + pointwise 1-D conv, then a SwiGLU FFN.

    The convolution branch is normalised after the convolution (``norm1``); the
    FFN branch is pre-normalised (``norm2``). Each branch is scaled by
    ``LayerScale`` and passed through dropout before being added to the residual.

    Args:
        hidden_size: channel width. ``forward`` transposes ``x`` to
            (batch, hidden_size, seq_len) for ``nn.Conv1d``, so ``x`` is expected
            as (batch, seq_len, hidden_size).
        dropout: dropout probability applied to each residual branch.
    """

    def __init__(self, hidden_size, dropout=0.1):
        super().__init__()
        # Depthwise (groups=hidden_size) kernel-3 conv mixes each token with its
        # immediate neighbours; padding=1 keeps the sequence length unchanged.
        self.conv_dw = nn.Conv1d(hidden_size, hidden_size, 3, padding=1, groups=hidden_size)
        # Pointwise (kernel-1) conv mixes channels.
        self.conv_pw = nn.Conv1d(hidden_size, hidden_size, 1)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.ffn = SwiGLU(hidden_size, int(hidden_size * 8 / 3))
        self.ls1 = LayerScale(hidden_size)
        self.ls2 = LayerScale(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attention_mask=None):
        """Apply the conv and FFN residual branches.

        Args:
            x: hidden states, (batch, seq_len, hidden_size).
            attention_mask: optional (batch, seq_len) tensor, 1 for real tokens and
                0 for padding (the convention used by the rest of this model).

        Returns:
            Tensor with the same shape as ``x``.
        """
        res = x
        conv_in = x
        if attention_mask is not None:
            # Padded positions carry non-zero hidden states, and the kernel-3 conv
            # would mix them into adjacent real tokens, making outputs depend on
            # how much padding the batch has. Zeroing them reproduces the conv's
            # own zero padding at the sequence edge.
            conv_in = x * attention_mask.unsqueeze(-1).to(x.dtype)
        h = self.conv_pw(self.conv_dw(conv_in.transpose(1, 2))).transpose(1, 2)
        x = res + self.dropout(self.ls1(self.norm1(h)))
        x = x + self.dropout(self.ls2(self.ffn(self.norm2(x))))
        return x
|
|
class AttentionBlockV2(nn.Module):
    """Pre-norm multi-head self-attention block followed by a SwiGLU FFN.

    Each residual branch is scaled by ``LayerScale`` and passed through dropout.

    Args:
        hidden_size: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to each residual branch (attention
            probabilities themselves are not dropped out).

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_size, num_heads=12, dropout=0.1):
        super().__init__()
        # Fail at construction instead of with an opaque reshape error in forward.
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv = nn.Linear(hidden_size, hidden_size * 3)
        self.proj = nn.Linear(hidden_size, hidden_size)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.ffn = SwiGLU(hidden_size, int(hidden_size * 8 / 3))
        self.ls1 = LayerScale(hidden_size)
        self.ls2 = LayerScale(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attention_mask=None):
        """Apply self-attention and FFN residual branches.

        Args:
            x: hidden states, (batch, seq_len, hidden_size).
            attention_mask: optional (batch, seq_len) tensor; 1 = attend to this
                key, 0 = padding key to be suppressed.

        Returns:
            Tensor with the same shape as ``x``.
        """
        B, L, D = x.size()
        res = x
        # (B, L, 3*D) -> (3, B, num_heads, L, head_dim)
        qkv = self.qkv(self.norm1(x)).reshape(B, L, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        m = None
        if attention_mask is not None:
            # Additive key mask broadcast over heads and queries: 0 for kept keys,
            # -10000 for padding. A finite value (not -inf) avoids NaN rows when a
            # sequence is entirely padding.
            m = (1.0 - attention_mask[:, None, None, :].to(dtype=x.dtype)) * -10000.0

        attn = F.scaled_dot_product_attention(q, k, v, attn_mask=m, dropout_p=0.0)
        # (B, num_heads, L, head_dim) -> (B, L, D)
        x = res + self.dropout(self.ls1(self.proj(attn.permute(0, 2, 1, 3).reshape(B, L, D))))
        x = x + self.dropout(self.ls2(self.ffn(self.norm2(x))))
        return x
|
|
class HCAEPreTrainedModel(PreTrainedModel):
    """Base class connecting HCAE models to the Hugging Face ``PreTrainedModel`` API.

    Declares the config class and base-model prefix, and defines the per-module
    weight initialisation that HF's init machinery applies to each submodule.
    """

    config_class = HCAEConfig
    base_model_prefix = "model"

    def _init_weights(self, module):
        """Initialise the parameters of a single submodule in place.

        - ``nn.Linear`` / ``nn.Embedding``: weight ~ N(0, config.initializer_range).
          Linear biases (if present) and the embedding row at ``padding_idx``
          (if set) are zeroed.
        - ``nn.LayerNorm``: bias 0, weight 1.
        - Any other module type is not touched here.
        """
        is_linear = isinstance(module, nn.Linear)
        if is_linear or isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if is_linear:
                if module.bias is not None:
                    module.bias.data.zero_()
            elif module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
|
|
class HCAEModel(HCAEPreTrainedModel):
    """HCAE sentence encoder returning one L2-normalised vector per input sequence.

    Pipeline: word + position embeddings -> LayerNorm -> dropout ->
    ``config.conv_layers`` ``ConvBlockV2`` layers -> ``config.attn_layers``
    ``AttentionBlockV2`` layers -> (masked) mean pooling -> L2 normalisation.
    """

    def __init__(self, config):
        super().__init__(config)
        # Attribute names double as state_dict keys; keep them stable for
        # checkpoint loading (including the capitalised ``LayerNorm``).
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)

        self.layers = nn.ModuleList(
            [ConvBlockV2(config.hidden_size, config.dropout) for _ in range(config.conv_layers)]
            + [AttentionBlockV2(config.hidden_size, config.num_heads, config.dropout) for _ in range(config.attn_layers)]
        )
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding table (needed by HF utilities such as
        ``resize_token_embeddings``)."""
        return self.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the token embedding table."""
        self.word_embeddings = value

    def forward(self, input_ids, attention_mask=None):
        """Encode a batch of token ids into unit-length embeddings.

        Args:
            input_ids: (batch, seq_len) token ids.
            attention_mask: optional (batch, seq_len) mask, 1 for real tokens and
                0 for padding; passed to every layer and used for mean pooling.

        Returns:
            (batch, hidden_size) tensor, L2-normalised along dim 1. Without a
            mask, all positions are averaged.

        Raises:
            ValueError: if seq_len exceeds the number of position embeddings.
        """
        seq_len = input_ids.size(1)
        max_positions = self.position_embeddings.num_embeddings
        if seq_len > max_positions:
            # Otherwise the position lookup goes out of range (IndexError on CPU,
            # device-side assert on CUDA).
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_position_embeddings ({max_positions})"
            )
        pos = torch.arange(seq_len, device=input_ids.device)
        x = self.dropout(self.LayerNorm(self.word_embeddings(input_ids) + self.position_embeddings(pos)))
        for layer in self.layers:
            x = layer(x, attention_mask=attention_mask)

        if attention_mask is not None:
            # ``.float()`` makes pooling (and the returned embedding) float32 even
            # for half-precision hidden states.
            mask = attention_mask.unsqueeze(-1).float()
            # clamp avoids division by zero for an all-padding row.
            emb = (x * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        else:
            emb = x.mean(1)

        return F.normalize(emb, p=2, dim=1)
|
|