import math  # noqa: F401  (not used in this module; left in place)

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

try:
    # NOTE(review): BertConfig is not referenced anywhere in this module. The
    # import is kept in case other code re-imports it from here, but it is
    # guarded so a missing `transformers` install does not make the model
    # classes below unimportable.
    from transformers import BertConfig  # noqa: F401
except ImportError:
    BertConfig = None


def _keep_mask(attention_mask, dtype):
    """Return a ``(batch, seq, 1)`` multiplier: 1 at real tokens, 0 at padding.

    ``attention_mask`` is a ``(batch, seq)`` tensor whose non-zero entries mark
    real tokens -- the same convention ``AttentionBlock`` uses (``== 0`` means
    padding).
    """
    return (attention_mask != 0).unsqueeze(-1).to(dtype)


class ConvBlock(nn.Module):
    """Post-norm residual block: depthwise-separable 1-D convolution, then an FFN.

    Input and output are ``(batch, seq, hidden_size)``.

    Args:
        hidden_size: Channel width of the block.
        kernel_size: Kernel size of the depthwise convolution.
        padding: Zero padding the depthwise convolution adds at both ends of
            the sequence. It must keep the sequence length unchanged (e.g.
            ``(kernel_size - 1) // 2`` for odd kernels), otherwise the residual
            addition fails.
    """

    def __init__(self, hidden_size, kernel_size=3, padding=1):
        super().__init__()
        # Depthwise: groups == channels, so each channel is convolved
        # independently along the sequence axis.
        self.conv_dw = nn.Conv1d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=kernel_size,
            padding=padding,
            groups=hidden_size,
        )
        # Pointwise (1x1) convolution mixes channels at each position.
        self.conv_pw = nn.Conv1d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=1,
        )
        self.norm1 = nn.LayerNorm(hidden_size)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )
        self.norm2 = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x, attention_mask=None):
        """Apply the conv sub-layer and the FFN sub-layer, each with residual + LayerNorm.

        Args:
            x: ``(batch, seq, hidden_size)`` activations.
            attention_mask: Optional ``(batch, seq)`` mask, non-zero at real
                tokens. When given, padded positions are zeroed before the
                convolution. They then look exactly like the convolution's own
                zero padding, so padding cannot bleed into neighbouring real
                tokens.

        Returns:
            ``(batch, seq, hidden_size)`` activations.
        """
        conv_in = x
        if attention_mask is not None:
            conv_in = conv_in * _keep_mask(attention_mask, conv_in.dtype)
        # Conv1d operates on (batch, channels, seq); transpose in and back out.
        conv_out = self.conv_pw(self.conv_dw(conv_in.transpose(1, 2))).transpose(1, 2)
        x = self.norm1(x + self.dropout(conv_out))
        x = self.norm2(x + self.dropout(self.ffn(x)))
        return x


class AttentionBlock(nn.Module):
    """Post-norm residual block: multi-head self-attention, then an FFN.

    Input and output are ``(batch, seq, hidden_size)`` (``batch_first=True``).

    Args:
        hidden_size: Embedding width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_heads,
            dropout=0.1,
            batch_first=True,
        )
        self.norm1 = nn.LayerNorm(hidden_size)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )
        self.norm2 = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x, attention_mask=None):
        """Apply self-attention and the FFN, each with residual + LayerNorm.

        Args:
            x: ``(batch, seq, hidden_size)`` activations.
            attention_mask: Optional ``(batch, seq)`` mask; positions equal to
                0 are excluded as attention keys.

        Returns:
            ``(batch, seq, hidden_size)`` activations.
        """
        # nn.MultiheadAttention expects True at positions to IGNORE.
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None
        attn_output, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        x = self.norm1(x + self.dropout(attn_output))
        x = self.norm2(x + self.dropout(self.ffn(x)))
        return x


class HCAEModel(nn.Module):
    """Sentence encoder: embeddings -> ConvBlocks -> AttentionBlocks -> masked mean pool.

    Args:
        vocab_size: Number of token ids; id 0 is the padding index of the word
            embedding.
        hidden_size: Model width.
        max_seq_len: Number of learned position embeddings (longest input).
        conv_layers: Number of ``ConvBlock`` layers.
        attn_layers: Number of ``AttentionBlock`` layers.
        num_heads: Attention heads per ``AttentionBlock``.

    Set ``use_gradient_checkpointing = True`` to recompute block activations
    during backward instead of storing them. It only takes effect in training
    mode.
    """

    def __init__(self, vocab_size=30522, hidden_size=384, max_seq_len=512,
                 conv_layers=5, attn_layers=3, num_heads=12):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(max_seq_len, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(0.1)
        self.conv_blocks = nn.ModuleList(
            [ConvBlock(hidden_size) for _ in range(conv_layers)]
        )
        self.attn_blocks = nn.ModuleList(
            [AttentionBlock(hidden_size, num_heads) for _ in range(attn_layers)]
        )
        self.use_gradient_checkpointing = False

    def _embed(self, input_ids):
        """Sum word and learned position embeddings, then LayerNorm + dropout."""
        seq_length = input_ids.size(1)
        max_len = self.position_embeddings.num_embeddings
        if seq_length > max_len:
            # Without this check the position lookup fails with an opaque
            # IndexError (CPU) or a device-side assert (CUDA).
            raise ValueError(
                f"input sequence length {seq_length} exceeds max_seq_len {max_len}"
            )
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        x = self.word_embeddings(input_ids) + self.position_embeddings(position_ids)
        return self.dropout(self.LayerNorm(x))

    @staticmethod
    def _run_block(block, x, attention_mask, use_checkpoint):
        """Run one block, optionally under non-reentrant gradient checkpointing."""
        if use_checkpoint:
            # Non-reentrant checkpointing accepts a None mask argument.
            return checkpoint(block, x, attention_mask, use_reentrant=False)
        return block(x, attention_mask=attention_mask)

    @staticmethod
    def _pool(x, attention_mask):
        """Mean over the sequence axis, weighting positions by ``attention_mask``."""
        if attention_mask is None:
            return x.mean(dim=1)
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(x.size()).float()
        sum_embeddings = torch.sum(x * input_mask_expanded, 1)
        # Clamp avoids division by zero for rows whose mask sums to 0.
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def forward(self, input_ids, attention_mask=None):
        """Encode token ids into one embedding per sequence.

        Args:
            input_ids: ``(batch, seq)`` integer token ids, with
                ``seq <= max_seq_len``.
            attention_mask: Optional ``(batch, seq)`` mask, expected to be 0/1.
                Non-zero entries are real tokens. The mask is applied in the
                conv blocks, the attention blocks and the final pooling.

        Returns:
            ``(batch, hidden_size)`` pooled sentence embeddings.

        Raises:
            ValueError: if ``seq`` exceeds the number of position embeddings.
        """
        x = self._embed(input_ids)
        use_checkpoint = self.use_gradient_checkpointing and self.training
        for block in self.conv_blocks:
            x = self._run_block(block, x, attention_mask, use_checkpoint)
        for block in self.attn_blocks:
            x = self._run_block(block, x, attention_mask, use_checkpoint)
        return self._pool(x, attention_mask)


def _demo():
    """Print parameter count, output shape and CUDA memory for one dummy forward pass."""
    model = HCAEModel()
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params / 1e6:.2f} M")

    batch_size = 32
    seq_len = 128
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    dummy_input = torch.randint(0, model.vocab_size, (batch_size, seq_len), device=device)
    dummy_mask = torch.ones((batch_size, seq_len), device=device)

    model.use_gradient_checkpointing = True
    # fp16 autocast only on CUDA; `torch.cuda.amp.autocast` is deprecated and
    # merely warns and disables itself when CUDA is unavailable.
    with torch.autocast(device_type="cuda", dtype=torch.float16,
                        enabled=device.type == "cuda"):
        output = model(dummy_input, attention_mask=dummy_mask)
    print(f"Output shape: {output.shape}")

    if torch.cuda.is_available():
        memory_allocated = torch.cuda.memory_allocated(device) / (1024 ** 2)
        memory_reserved = torch.cuda.memory_reserved(device) / (1024 ** 2)
        print(f"CUDA memory allocated: {memory_allocated:.2f} MB")
        print(f"CUDA memory reserved: {memory_reserved:.2f} MB")


if __name__ == "__main__":
    _demo()