# HCAE-21M-Instruct / model.py
# Uploaded by HeavensHackDev: "Upload model.py with huggingface_hub"
# (commit bb7acef, verified).
import math
import torch
import torch.nn as nn
from transformers import BertConfig
from torch.utils.checkpoint import checkpoint
class ConvBlock(nn.Module):
    """Depthwise-separable 1D convolution block followed by a feed-forward net.

    Both sub-layers are wrapped in a residual connection with dropout and are
    normalised *after* the residual add (post-norm).

    Args:
        hidden_size: Size of the last (feature) dimension of the input.
        kernel_size: Width of the depthwise convolution kernel.
        padding: Padding of the depthwise convolution. ``None`` (default)
            selects a length-preserving padding for ``kernel_size``, which is
            ``1`` for the default kernel of 3. An explicit value is passed to
            ``nn.Conv1d`` unchanged; it must keep the sequence length intact,
            otherwise the residual add in ``forward`` fails.
        dropout: Dropout probability applied to each sub-layer output before
            it is added to the residual.
    """

    def __init__(self, hidden_size, kernel_size=3, padding=None, dropout=0.1):
        super().__init__()
        if padding is None:
            # The residual add needs output length == input length. Odd
            # kernels get the symmetric integer padding; even kernels cannot
            # be balanced symmetrically, so defer to PyTorch's "same" mode.
            padding = kernel_size // 2 if kernel_size % 2 else "same"

        # Depthwise: groups == channels, so each channel is filtered on its own.
        self.conv_dw = nn.Conv1d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=kernel_size,
            padding=padding,
            groups=hidden_size,
        )
        # Pointwise: kernel of 1 mixes information across channels.
        self.conv_pw = nn.Conv1d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=1,
        )
        self.norm1 = nn.LayerNorm(hidden_size)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )
        self.norm2 = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(batch, seq_len, hidden_size)``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        # nn.Conv1d expects channels before length, so swap the last two axes
        # around the convolutions.
        conv_out = self.conv_pw(self.conv_dw(x.transpose(1, 2))).transpose(1, 2)
        x = self.norm1(x + self.dropout(conv_out))
        return self.norm2(x + self.dropout(self.ffn(x)))
class AttentionBlock(nn.Module):
    """Post-norm self-attention block followed by a feed-forward net.

    Args:
        hidden_size: Size of the last (feature) dimension of the input.
        num_heads: Number of attention heads.
        dropout: Dropout probability used inside the attention module and on
            each sub-layer output before it is added to the residual.
    """

    def __init__(self, hidden_size, num_heads, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.norm1 = nn.LayerNorm(hidden_size)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )
        self.norm2 = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attention_mask=None):
        """Apply the block to ``x`` of shape ``(batch, seq_len, hidden_size)``.

        Args:
            x: Input hidden states.
            attention_mask: Optional ``(batch, seq_len)`` tensor in which
                entries equal to 0 mark positions that must not be attended
                to. ``None`` lets every position attend to every other one.

        Returns:
            A tensor with the same shape as ``x``.
        """
        key_padding_mask = None
        if attention_mask is not None:
            # nn.MultiheadAttention wants True where a key must be ignored.
            key_padding_mask = attention_mask == 0
            # A sequence without a single valid token would mask every key,
            # which can turn the attention softmax into NaN. Leave such rows
            # unmasked so the output stays finite; rows that have at least one
            # valid token are not affected.
            all_padded = key_padding_mask.all(dim=-1, keepdim=True)
            key_padding_mask = key_padding_mask & ~all_padded

        attn_output, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        x = self.norm1(x + self.dropout(attn_output))
        return self.norm2(x + self.dropout(self.ffn(x)))
class HCAEModel(nn.Module):
    """Sentence encoder: embeddings, convolution blocks, attention blocks, mean pooling.

    Token and learned position embeddings are summed, normalised and passed
    through ``conv_layers`` ``ConvBlock`` modules followed by ``attn_layers``
    ``AttentionBlock`` modules. The final hidden states are mean-pooled over
    the sequence dimension into one vector per input sequence.

    Args:
        vocab_size: Number of rows in the token embedding table. Index 0 is
            registered as the padding index.
        hidden_size: Width of the embeddings and of every block.
        max_seq_len: Number of rows in the position embedding table, i.e. the
            longest sequence ``forward`` accepts.
        conv_layers: Number of ``ConvBlock`` modules.
        attn_layers: Number of ``AttentionBlock`` modules.
        num_heads: Attention heads per ``AttentionBlock``.

    Attributes:
        use_gradient_checkpointing: When set to ``True`` and the module is in
            training mode, every block is run under
            ``torch.utils.checkpoint.checkpoint``.
    """

    def __init__(self, vocab_size=30522, hidden_size=384, max_seq_len=512,
                 conv_layers=5, attn_layers=3, num_heads=12):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(max_seq_len, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(0.1)
        self.conv_blocks = nn.ModuleList(
            [ConvBlock(hidden_size) for _ in range(conv_layers)]
        )
        self.attn_blocks = nn.ModuleList(
            [AttentionBlock(hidden_size, num_heads) for _ in range(attn_layers)]
        )
        self.use_gradient_checkpointing = False

    def forward(self, input_ids, attention_mask=None):
        """Encode a batch of token id sequences into sentence embeddings.

        Args:
            input_ids: Integer tensor of shape ``(batch, seq_len)``.
            attention_mask: Optional ``(batch, seq_len)`` tensor; entries equal
                to 0 are hidden from the attention blocks and weighted by 0 in
                the pooling. ``None`` treats every position as valid.

        Returns:
            Tensor of shape ``(batch, hidden_size)``.

        Raises:
            IndexError: If ``seq_len`` exceeds the size of the position
                embedding table.
        """
        seq_length = input_ids.size(1)
        max_positions = self.position_embeddings.num_embeddings
        if seq_length > max_positions:
            # Fail early with a readable message instead of an opaque
            # embedding lookup error (a device-side assert on CUDA).
            raise IndexError(
                f"Sequence length {seq_length} exceeds the maximum supported "
                f"length of {max_positions}"
            )

        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        x = self.word_embeddings(input_ids) + self.position_embeddings(position_ids)
        x = self.dropout(self.LayerNorm(x))

        # Checkpointing trades compute for memory and only matters when
        # gradients are needed, hence the training-mode condition.
        use_checkpoint = self.use_gradient_checkpointing and self.training

        for block in self.conv_blocks:
            if use_checkpoint:
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)

        for block in self.attn_blocks:
            if use_checkpoint:
                x = checkpoint(block, x, attention_mask, use_reentrant=False)
            else:
                x = block(x, attention_mask=attention_mask)

        if attention_mask is None:
            return x.mean(dim=1)

        # Masked mean over the sequence axis; the clamp avoids a division by
        # zero for sequences whose mask is entirely 0.
        mask = attention_mask.unsqueeze(-1).float()
        token_counts = torch.clamp(mask.sum(1), min=1e-9)
        return torch.sum(x * mask, 1) / token_counts
if __name__ == "__main__":
    # Smoke test: build the default model, report its size, run one forward
    # pass with gradient checkpointing enabled and print CUDA memory usage
    # when a GPU is present.
    model = HCAEModel()
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params / 1e6:.2f} M")

    batch_size = 32
    seq_len = 128
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # Create the dummy batch directly on the target device and take the
    # vocabulary size from the model instead of repeating the constant.
    dummy_input = torch.randint(0, model.vocab_size, (batch_size, seq_len), device=device)
    dummy_mask = torch.ones((batch_size, seq_len), device=device)

    model.use_gradient_checkpointing = True

    # float16 autocast is only used on CUDA. On CPU the context is disabled
    # (bfloat16 is named there because it is the dtype CPU autocast accepts),
    # so the forward pass runs in full precision without warnings.
    on_cuda = device.type == 'cuda'
    amp_dtype = torch.float16 if on_cuda else torch.bfloat16
    with torch.autocast(device_type=device.type, dtype=amp_dtype, enabled=on_cuda):
        output = model(dummy_input, attention_mask=dummy_mask)
    print(f"Output shape: {output.shape}")

    if on_cuda:
        memory_allocated = torch.cuda.memory_allocated(device) / (1024 ** 2)
        memory_reserved = torch.cuda.memory_reserved(device) / (1024 ** 2)
        print(f"CUDA memory allocated: {memory_allocated:.2f} MB")
        print(f"CUDA memory reserved: {memory_reserved:.2f} MB")