HeavensHackDev commited on
Commit
dd4e5d3
·
verified ·
1 Parent(s): 4bbfb5c

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +166 -0
model.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers import BertConfig
5
+ from torch.utils.checkpoint import checkpoint
6
+
7
class ConvBlock(nn.Module):
    """Depthwise-separable convolution block with a position-wise FFN.

    Sub-layer 1: depthwise Conv1d (one filter per channel) followed by a
    pointwise 1x1 Conv1d to mix channels, applied over the sequence axis.
    Sub-layer 2: a GELU feed-forward network with 4x expansion.
    Each sub-layer uses dropout, a residual connection, and post-LayerNorm.
    Input and output are both (batch, seq_len, hidden_size).
    """

    def __init__(self, hidden_size, kernel_size=3, padding=1):
        super().__init__()
        # Depthwise: groups == channels, so each channel is convolved alone.
        self.conv_dw = nn.Conv1d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=kernel_size,
            padding=padding,
            groups=hidden_size,
        )
        # Pointwise 1x1 convolution mixes information across channels.
        self.conv_pw = nn.Conv1d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=1,
        )
        self.norm1 = nn.LayerNorm(hidden_size)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )
        self.norm2 = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Apply conv and FFN sub-layers; x is (batch, seq_len, hidden)."""
        # Conv1d expects (batch, channels, seq), hence the two transposes.
        conv_out = self.conv_pw(self.conv_dw(x.transpose(1, 2))).transpose(1, 2)
        x = self.norm1(x + self.dropout(conv_out))
        # Feed-forward sub-layer, same residual + post-norm pattern.
        return self.norm2(x + self.dropout(self.ffn(x)))
42
+
43
+
44
class AttentionBlock(nn.Module):
    """Transformer-style block: multi-head self-attention plus a GELU FFN.

    Both sub-layers use dropout, a residual connection, and post-LayerNorm.
    Input and output are (batch, seq_len, hidden_size); an optional
    attention_mask marks valid tokens with nonzero entries.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_heads,
            dropout=0.1,
            batch_first=True,
        )
        self.norm1 = nn.LayerNorm(hidden_size)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )
        self.norm2 = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x, attention_mask=None):
        """Self-attend over x, optionally ignoring padded positions.

        attention_mask: (batch, seq_len), 0 marks padding; positions where
        it is 0 are excluded from attention via key_padding_mask.
        """
        pad_mask = None if attention_mask is None else attention_mask.eq(0)
        attn_out, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=pad_mask,
            need_weights=False,
        )
        x = self.norm1(x + self.dropout(attn_out))
        return self.norm2(x + self.dropout(self.ffn(x)))
81
+
82
+
83
class HCAEModel(nn.Module):
    """Hybrid Conv/Attention Encoder producing sentence embeddings.

    Token embeddings plus learned position embeddings feed a stack of
    ConvBlocks followed by AttentionBlocks; the final hidden states are
    mean-pooled (mask-aware when attention_mask is given) into one
    (hidden_size,) vector per sequence.

    Set ``use_gradient_checkpointing = True`` to trade compute for memory
    during training (checkpointing is only active in training mode).
    """

    def __init__(self, vocab_size=30522, hidden_size=384, max_seq_len=512,
                 conv_layers=5, attn_layers=3, num_heads=12):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        # Kept so forward() can validate input length with a clear error.
        self.max_seq_len = max_seq_len
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(max_seq_len, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(0.1)
        self.conv_blocks = nn.ModuleList([
            ConvBlock(hidden_size) for _ in range(conv_layers)
        ])
        self.attn_blocks = nn.ModuleList([
            AttentionBlock(hidden_size, num_heads) for _ in range(attn_layers)
        ])
        self.use_gradient_checkpointing = False

    def forward(self, input_ids, attention_mask=None):
        """Encode token ids into pooled sentence embeddings.

        Args:
            input_ids: (batch, seq_len) long tensor of token ids.
            attention_mask: optional (batch, seq_len) tensor, nonzero for
                valid tokens and 0 for padding.

        Returns:
            (batch, hidden_size) tensor of pooled embeddings.

        Raises:
            ValueError: if seq_len exceeds max_seq_len (would otherwise
                fail with an opaque embedding index error).
        """
        seq_length = input_ids.size(1)
        if seq_length > self.max_seq_len:
            raise ValueError(
                f"sequence length {seq_length} exceeds max_seq_len {self.max_seq_len}"
            )
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        x = self.word_embeddings(input_ids) + self.position_embeddings(position_ids)
        x = self.dropout(self.LayerNorm(x))

        # checkpoint() accepts the module directly (modules are callable),
        # so no per-iteration closure factory is needed.
        use_ckpt = self.use_gradient_checkpointing and self.training
        for block in self.conv_blocks:
            if use_ckpt:
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)

        for block in self.attn_blocks:
            if use_ckpt:
                # attention_mask is passed positionally through checkpoint.
                x = checkpoint(block, x, attention_mask, use_reentrant=False)
            else:
                x = block(x, attention_mask=attention_mask)

        # Mask-aware mean pooling: average only over non-padding positions.
        if attention_mask is not None:
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(x.size()).float()
            sum_embeddings = torch.sum(x * input_mask_expanded, 1)
            # Clamp avoids division by zero for fully-padded rows.
            sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
            sentence_embeddings = sum_embeddings / sum_mask
        else:
            sentence_embeddings = x.mean(dim=1)

        return sentence_embeddings
141
+
142
if __name__ == "__main__":
    # Smoke test: report parameter count, run one forward pass with
    # gradient checkpointing enabled, then report CUDA memory usage.
    model = HCAEModel()
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params / 1e6:.2f} M")

    batch_size = 32
    seq_len = 128
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # Build test inputs directly on the target device.
    dummy_input = torch.randint(0, 30522, (batch_size, seq_len), device=device)
    dummy_mask = torch.ones((batch_size, seq_len), device=device)

    model.use_gradient_checkpointing = True

    # torch.autocast replaces the deprecated torch.cuda.amp.autocast;
    # fp16 autocast is only enabled when actually running on CUDA, so
    # this script also runs cleanly on CPU-only machines.
    with torch.autocast(device_type="cuda", dtype=torch.float16,
                        enabled=device.type == "cuda"):
        output = model(dummy_input, attention_mask=dummy_mask)

    print(f"Output shape: {output.shape}")

    if torch.cuda.is_available():
        memory_allocated = torch.cuda.memory_allocated(device) / (1024 ** 2)
        memory_reserved = torch.cuda.memory_reserved(device) / (1024 ** 2)
        print(f"CUDA memory allocated: {memory_allocated:.2f} MB")
        print(f"CUDA memory reserved: {memory_reserved:.2f} MB")