# HCAE-21M-v1.1-Base / modeling_hcae.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from .configuration_hcae import HCAEConfig


class SwiGLU(nn.Module):
    """SwiGLU feed-forward unit: w3(silu(w1(x)) * w2(x))."""

    def __init__(self, in_features, hidden_features, out_features=None):
        super().__init__()
        self.w1 = nn.Linear(in_features, hidden_features)
        self.w2 = nn.Linear(in_features, hidden_features)
        self.w3 = nn.Linear(hidden_features, out_features or in_features)

    def forward(self, x):
        return self.w3(F.silu(self.w1(x)) * self.w2(x))


class LayerScale(nn.Module):
    """Learnable per-channel scaling of a residual branch, initialised near zero."""

    def __init__(self, dim, init_values=1e-5):
        super().__init__()
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x * self.gamma


class ConvBlockV2(nn.Module):
    """Depthwise-separable 1D convolution block with LayerScale residuals and a SwiGLU FFN."""

    def __init__(self, hidden_size, dropout=0.1):
        super().__init__()
        # Depthwise conv (per-channel, kernel size 3) followed by a pointwise (1x1) conv.
        self.conv_dw = nn.Conv1d(hidden_size, hidden_size, 3, padding=1, groups=hidden_size)
        self.conv_pw = nn.Conv1d(hidden_size, hidden_size, 1)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.ffn = SwiGLU(hidden_size, int(hidden_size * 8 / 3))
        self.ls1 = LayerScale(hidden_size)
        self.ls2 = LayerScale(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attention_mask=None):
        res = x
        # Conv branch: (B, L, D) -> (B, D, L) for Conv1d, then back to (B, L, D).
        h = self.conv_pw(self.conv_dw(x.transpose(1, 2))).transpose(1, 2)
        x = res + self.dropout(self.ls1(self.norm1(h)))
        # FFN branch with pre-norm and LayerScale.
        x = x + self.dropout(self.ls2(self.ffn(self.norm2(x))))
        return x


class AttentionBlockV2(nn.Module):
    """Pre-norm multi-head self-attention block with LayerScale residuals and a SwiGLU FFN."""

    def __init__(self, hidden_size, num_heads=12, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv = nn.Linear(hidden_size, hidden_size * 3)
        self.proj = nn.Linear(hidden_size, hidden_size)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.ffn = SwiGLU(hidden_size, int(hidden_size * 8 / 3))
        self.ls1 = LayerScale(hidden_size)
        self.ls2 = LayerScale(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attention_mask=None):
        B, L, D = x.size()
        res = x
        # Fused QKV projection: (B, L, 3*D) -> (3, B, num_heads, L, head_dim).
        qkv = self.qkv(self.norm1(x)).reshape(B, L, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        m = None
        if attention_mask is not None:
            # Turn the 0/1 padding mask into an additive bias: 0 for kept tokens, -10000 for padding.
            m = (1.0 - attention_mask[:, None, None, :].to(dtype=x.dtype)) * -10000.0
        attn = F.scaled_dot_product_attention(q, k, v, attn_mask=m, dropout_p=0.0)
        x = res + self.dropout(self.ls1(self.proj(attn.permute(0, 2, 1, 3).reshape(B, L, D))))
        x = x + self.dropout(self.ls2(self.ffn(self.norm2(x))))
        return x


class HCAEPreTrainedModel(PreTrainedModel):
    """Base class tying HCAEConfig into the Transformers weight-initialisation machinery."""

    config_class = HCAEConfig
    base_model_prefix = "model"

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class HCAEModel(HCAEPreTrainedModel):
    """Hybrid conv + attention encoder returning L2-normalised sentence embeddings."""

    def __init__(self, config):
        super().__init__(config)
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)
        # Conv blocks first (local feature extraction), then attention blocks (global mixing).
        self.layers = nn.ModuleList(
            [ConvBlockV2(config.hidden_size, config.dropout) for _ in range(config.conv_layers)] +
            [AttentionBlockV2(config.hidden_size, config.num_heads, config.dropout) for _ in range(config.attn_layers)]
        )
        self.post_init()

    def forward(self, input_ids, attention_mask=None):
        # Token embeddings plus learned absolute position embeddings.
        pos = torch.arange(input_ids.size(1), device=input_ids.device)
        x = self.dropout(self.LayerNorm(self.word_embeddings(input_ids) + self.position_embeddings(pos)))
        for layer in self.layers:
            x = layer(x, attention_mask=attention_mask)
        # Masked mean pooling over the sequence, then L2 normalisation.
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).float()
            emb = (x * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        else:
            emb = x.mean(1)
        return F.normalize(emb, p=2, dim=1)
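

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the released checkpoint code).
# It assumes HCAEConfig() can be constructed with its default field values; the
# batch size, sequence length, and the half-padded second sequence are arbitrary
# choices for the demo. Because of the relative import above, this block has to
# be run as a module from whatever package contains this file (python -m
# <package>.modeling_hcae), or the classes imported from elsewhere instead.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = HCAEConfig()
    model = HCAEModel(config)
    model.eval()

    # Two dummy sequences of 16 token ids; the second one is half padding.
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones(2, 16, dtype=torch.long)
    attention_mask[1, 8:] = 0

    with torch.no_grad():
        embeddings = model(input_ids, attention_mask=attention_mask)

    # Mean-pooled, L2-normalised embeddings of shape (batch, hidden_size).
    print(embeddings.shape, embeddings.norm(dim=1))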