Sh2425 committed
Commit d16bbc2 · verified · 1 parent: 2065930

Upload modeling_dolphy.py

Files changed (1)
  1. modeling_dolphy.py +67 -0
modeling_dolphy.py ADDED
@@ -0,0 +1,67 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_outputs import CausalLMOutputWithPast


class RMSNorm(nn.Module):
    """Root-mean-square layer norm: learned scale, no mean subtraction."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        norm = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(norm + self.eps)


class MLP(nn.Module):
    """GELU-gated feed-forward block (no biases)."""

    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        return self.down_proj(F.gelu(self.gate_proj(x)) * self.up_proj(x))


class DolphyBlock(nn.Module):
    """Pre-norm transformer block: self-attention plus an optional gated MLP.

    When fused=True the feed-forward sub-layer is omitted entirely, so the
    block is attention-only.
    """

    def __init__(self, hidden_size, intermediate_size, num_heads, fused=False):
        super().__init__()
        self.norm1 = RMSNorm(hidden_size)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.norm2 = RMSNorm(hidden_size)
        self.mlp = None if fused else MLP(hidden_size, intermediate_size)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        # Normalize once and reuse for query, key, and value.
        h = self.norm1(x)
        attn_out, _ = self.attn(
            h, h, h,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        x = x + attn_out
        if self.mlp is not None:
            x = x + self.mlp(self.norm2(x))
        return x


class Dolphy1ForCausalLM(nn.Module):
    def __init__(self, vocab_size=32000, hidden_size=4096, intermediate_size=16384,
                 num_layers=32, num_heads=32, moe_fused=True):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.blocks = nn.ModuleList([
            DolphyBlock(hidden_size, intermediate_size, num_heads, fused=moe_fused)
            for _ in range(num_layers)
        ])
        self.norm = RMSNorm(hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None, labels=None):
        x = self.embed(input_ids)

        # Causal mask: True above the diagonal means "may not attend".
        seq_len = input_ids.size(1)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=input_ids.device),
            diagonal=1,
        )
        # HF-style padding mask (1 = real token, 0 = pad) -> key_padding_mask
        # (True = ignore this key position).
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None

        for block in self.blocks:
            x = block(x, attn_mask=causal_mask, key_padding_mask=key_padding_mask)
        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Standard causal-LM loss: tokens < n predict token n.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )
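
A minimal smoke test of the uploaded module, as a sketch only: it assumes the file is importable as modeling_dolphy, and the tiny hyperparameters below are illustrative, not the real checkpoint configuration.

import torch
from modeling_dolphy import Dolphy1ForCausalLM

# Tiny illustrative configuration (not the shipped 32-layer/4096-dim model),
# with moe_fused=False so each block keeps its feed-forward sub-layer.
model = Dolphy1ForCausalLM(vocab_size=100, hidden_size=64, intermediate_size=256,
                           num_layers=2, num_heads=4, moe_fused=False)
input_ids = torch.randint(0, 100, (2, 16))     # (batch, seq_len)
attention_mask = torch.ones_like(input_ids)    # no padding in this toy batch
out = model(input_ids, attention_mask=attention_mask, labels=input_ids)
print(out.logits.shape)  # torch.Size([2, 16, 100])
print(out.loss)          # scalar next-token loss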