| import math |
| import torch |
| import torch.nn as nn |
| from transformers import BertConfig |
| from torch.utils.checkpoint import checkpoint |
|
|
class ConvBlock(nn.Module):
    """Depthwise-separable 1D convolution sub-layer followed by a position-wise FFN.

    Each sub-layer uses a post-norm residual connection:
    ``x = LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        hidden_size: Feature/channel dimension of both input and output.
        kernel_size: Width of the depthwise convolution.
        padding: Padding of the depthwise convolution. ``None`` (default) selects
            ``"same"`` padding so the sequence length is preserved for any
            ``kernel_size`` -- the residual addition requires equal lengths.
            An explicit value is passed to ``nn.Conv1d`` unchanged (``1`` is
            equivalent to the default for ``kernel_size=3``).
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, hidden_size, kernel_size=3, padding=None, dropout=0.1):
        super().__init__()
        if padding is None:
            padding = "same"
        # groups=hidden_size makes this a per-channel (depthwise) convolution.
        self.conv_dw = nn.Conv1d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=kernel_size,
            padding=padding,
            groups=hidden_size,
        )
        # 1x1 convolution mixes channels (pointwise).
        self.conv_pw = nn.Conv1d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=1,
        )
        self.norm1 = nn.LayerNorm(hidden_size)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )
        self.norm2 = nn.LayerNorm(hidden_size)
        # A single Dropout module is reused for both sub-layers (it holds no state).
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the convolution and FFN sub-layers.

        Args:
            x: Tensor laid out as (batch, seq_len, hidden_size); LayerNorm and
                the Linear layers act on the last dimension.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # nn.Conv1d expects (batch, channels, length), so move features to dim 1 and back.
        conv_out = self.conv_pw(self.conv_dw(x.transpose(1, 2))).transpose(1, 2)
        x = self.norm1(x + self.dropout(conv_out))
        x = self.norm2(x + self.dropout(self.ffn(x)))
        return x
|
|
|
|
class AttentionBlock(nn.Module):
    """Multi-head self-attention sub-layer followed by a position-wise FFN.

    Each sub-layer uses a post-norm residual connection:
    ``x = LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        hidden_size: Model width; must be divisible by ``num_heads``
            (enforced by ``nn.MultiheadAttention``).
        num_heads: Number of attention heads.
        dropout: Dropout probability for the attention weights and for each
            sub-layer output.
    """

    def __init__(self, hidden_size, num_heads, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.norm1 = nn.LayerNorm(hidden_size)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )
        self.norm2 = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attention_mask=None):
        """Apply self-attention and FFN sub-layers.

        Args:
            x: Tensor of shape (batch, seq_len, hidden_size) (``batch_first=True``).
            attention_mask: Optional (batch, seq_len) tensor where 0 marks a
                padding position that must not be attended to.

        Returns:
            Tensor with the same shape as ``x``.
        """
        key_padding_mask = None
        if attention_mask is not None:
            # nn.MultiheadAttention wants True where a key should be ignored.
            key_padding_mask = attention_mask == 0
            # If every key in a row is masked, the softmax over all -inf scores
            # yields NaN and poisons that whole row. Unmask such rows so they stay
            # finite; their positions carry no valid tokens anyway.
            fully_masked = key_padding_mask.all(dim=-1, keepdim=True)
            key_padding_mask = key_padding_mask & ~fully_masked

        attn_output, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        x = self.norm1(x + self.dropout(attn_output))
        x = self.norm2(x + self.dropout(self.ffn(x)))
        return x
|
|
|
|
class HCAEModel(nn.Module):
    """Hybrid convolution/attention encoder that returns one vector per sequence.

    Token embeddings and learned absolute position embeddings are summed,
    normalized and passed through ``conv_layers`` ConvBlocks followed by
    ``attn_layers`` AttentionBlocks; the final token states are mean-pooled
    (masked by ``attention_mask`` when one is given).

    Args:
        vocab_size: Number of token ids; id 0 is the embedding padding index.
        hidden_size: Model width.
        max_seq_len: Number of learned position embeddings, i.e. the longest
            accepted input sequence.
        conv_layers: Number of ConvBlocks.
        attn_layers: Number of AttentionBlocks.
        num_heads: Attention heads per AttentionBlock.
    """

    def __init__(self, vocab_size=30522, hidden_size=384, max_seq_len=512,
                 conv_layers=5, attn_layers=3, num_heads=12):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(max_seq_len, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(0.1)
        self.conv_blocks = nn.ModuleList([
            ConvBlock(hidden_size) for _ in range(conv_layers)
        ])
        self.attn_blocks = nn.ModuleList([
            AttentionBlock(hidden_size, num_heads) for _ in range(attn_layers)
        ])
        # When True (and in training mode), block activations are recomputed
        # during backward instead of being stored.
        self.use_gradient_checkpointing = False

    def forward(self, input_ids, attention_mask=None):
        """Encode token ids into pooled sentence embeddings.

        Args:
            input_ids: Integer token ids of shape (batch, seq_len).
            attention_mask: Optional (batch, seq_len) tensor; 0 marks padding.
                Padding is excluded from attention keys and from pooling.

        Returns:
            Tensor of shape (batch, hidden_size).

        Raises:
            ValueError: If ``seq_len`` exceeds the number of position embeddings.
        """
        seq_length = input_ids.size(1)
        max_positions = self.position_embeddings.num_embeddings
        if seq_length > max_positions:
            # Without this check the position lookup fails with an opaque
            # IndexError (or a device-side assert on CUDA).
            raise ValueError(
                f"sequence length {seq_length} exceeds max_seq_len={max_positions}"
            )

        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        x = self.word_embeddings(input_ids) + self.position_embeddings(position_ids)
        x = self.dropout(self.LayerNorm(x))

        use_checkpointing = self.use_gradient_checkpointing and self.training

        for block in self.conv_blocks:
            if use_checkpointing:
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)

        for block in self.attn_blocks:
            if use_checkpointing:
                # attention_mask is forwarded positionally to block.forward(x, attention_mask).
                x = checkpoint(block, x, attention_mask, use_reentrant=False)
            else:
                x = block(x, attention_mask=attention_mask)

        if attention_mask is None:
            return x.mean(dim=1)

        # Masked mean over the sequence dimension; the clamp avoids 0/0 for
        # rows that contain no valid tokens.
        mask = attention_mask.unsqueeze(-1).float()
        summed = torch.sum(x * mask, dim=1)
        counts = torch.clamp(mask.sum(dim=1), min=1e-9)
        return summed / counts
|
|
if __name__ == "__main__":
    # Smoke test: build the default model, run one mixed-precision forward pass
    # with gradient checkpointing enabled, and report sizes / CUDA memory.
    model = HCAEModel()
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params / 1e6:.2f} M")

    batch_size = 32
    seq_len = 128
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    model.to(device)

    # Draw ids from the model's own vocabulary rather than a hard-coded size.
    dummy_input = torch.randint(0, model.vocab_size, (batch_size, seq_len), device=device)
    dummy_mask = torch.ones((batch_size, seq_len), device=device)

    # Only takes effect in training mode, which is the module default.
    model.use_gradient_checkpointing = True

    # torch.cuda.amp.autocast is deprecated and, without CUDA, warns and disables
    # itself. torch.autocast keyed on the real device keeps fp16 autocast on CUDA
    # and plain fp32 on CPU (bfloat16 is named only because it is a CPU-valid dtype).
    amp_dtype = torch.float16 if use_cuda else torch.bfloat16
    with torch.autocast(device_type=device.type, dtype=amp_dtype, enabled=use_cuda):
        output = model(dummy_input, attention_mask=dummy_mask)

    print(f"Output shape: {output.shape}")

    if use_cuda:
        memory_allocated = torch.cuda.memory_allocated(device) / (1024 ** 2)
        memory_reserved = torch.cuda.memory_reserved(device) / (1024 ** 2)
        print(f"CUDA memory allocated: {memory_allocated:.2f} MB")
        print(f"CUDA memory reserved: {memory_reserved:.2f} MB")
|
|