Sh2425 committed on
Commit 6ccae00 · verified · 1 Parent(s): 6eaf811

Update modeling_dolphy.py

Files changed (1):
  1. modeling_dolphy.py +40 -67
modeling_dolphy.py CHANGED
@@ -1,67 +1,40 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from transformers.modeling_outputs import CausalLMOutputWithPast
-
-class RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.eps = eps
-
-    def forward(self, x):
-        norm = x.pow(2).mean(-1, keepdim=True)
-        return self.weight * x * torch.rsqrt(norm + self.eps)
-
-class MLP(nn.Module):
-    def __init__(self, hidden_size, intermediate_size):
-        super().__init__()
-        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
-        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
-        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
-
-    def forward(self, x):
-        return self.down_proj(F.gelu(self.gate_proj(x)) * self.up_proj(x))
-
-class DolphyBlock(nn.Module):
-    def __init__(self, hidden_size, intermediate_size, num_heads, fused=False):
-        super().__init__()
-        self.norm1 = RMSNorm(hidden_size)
-        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
-        self.norm2 = RMSNorm(hidden_size)
-        self.mlp = None if fused else MLP(hidden_size, intermediate_size)
-
-    def forward(self, x, attn_mask=None):
-        x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x), attn_mask=attn_mask)[0]
-        if self.mlp:
-            x = x + self.mlp(self.norm2(x))
-        return x
-
-class Dolphy1ForCausalLM(nn.Module):
-    def __init__(self, vocab_size=32000, hidden_size=4096, intermediate_size=16384, num_layers=32, num_heads=32, moe_fused=True):
-        super().__init__()
-        self.embed = nn.Embedding(vocab_size, hidden_size)
-        self.blocks = nn.ModuleList([
-            DolphyBlock(hidden_size, intermediate_size, num_heads, fused=moe_fused) for _ in range(num_layers)
-        ])
-        self.norm = RMSNorm(hidden_size)
-        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
-
-    def forward(self, input_ids, attention_mask=None, labels=None):
-        x = self.embed(input_ids)
-        for block in self.blocks:
-            x = block(x, attention_mask)
-        x = self.norm(x)
-        logits = self.lm_head(x)
-
-        loss = None
-        if labels is not None:
-            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-100)
-
-        return CausalLMOutputWithPast(
-            loss=loss,
-            logits=logits,
-            past_key_values=None,
-            hidden_states=None,
-            attentions=None,
-        )
+from transformers import PreTrainedModel
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from torch import nn
+
+class DolphyBlock(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.attn = nn.Linear(config.hidden_size, config.hidden_size)  # placeholder
+        self.mlp = nn.Linear(config.hidden_size, config.hidden_size)  # placeholder
+
+    def forward(self, x):
+        x = self.attn(x)
+        x = self.mlp(x)
+        return x
+
+class DolphyModel(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.layers = nn.ModuleList([DolphyBlock(config) for _ in range(config.num_hidden_layers)])
+        self.norm = nn.LayerNorm(config.hidden_size)
+
+    def forward(self, input_ids):
+        x = self.embed_tokens(input_ids)
+        for layer in self.layers:
+            x = layer(x)
+        return self.norm(x)
+
+class Dolphy1ForCausalLM(PreTrainedModel):
+    _auto_class = True
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = DolphyModel(config)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+    def forward(self, input_ids, attention_mask=None, **kwargs):
+        hidden_states = self.model(input_ids)
+        logits = self.lm_head(hidden_states)
+        return CausalLMOutputWithPast(logits=logits)
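
The new `Dolphy1ForCausalLM` subclasses `PreTrainedModel` and reads `vocab_size`, `hidden_size`, and `num_hidden_layers` from a config object that this diff does not define. Below is a minimal sketch of how the class could be exercised, assuming a hypothetical `DolphyConfig`: the field names mirror the attributes the new code accesses, but the config class itself, its `model_type` string, and the small sizes are illustrative assumptions, not part of this commit.

```python
import torch
from transformers import PretrainedConfig

from modeling_dolphy import Dolphy1ForCausalLM  # the class added in this commit

class DolphyConfig(PretrainedConfig):
    # Hypothetical config, not part of this commit; field names mirror the
    # attributes Dolphy1ForCausalLM and DolphyModel read off `config`.
    model_type = "dolphy"  # assumed identifier

    def __init__(self, vocab_size=32000, hidden_size=256, num_hidden_layers=2, **kwargs):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        super().__init__(**kwargs)

config = DolphyConfig()
model = Dolphy1ForCausalLM(config)

# Forward pass on a dummy batch: logits come back as (batch, seq_len, vocab_size).
input_ids = torch.randint(0, config.vocab_size, (1, 8))
out = model(input_ids)
print(out.logits.shape)  # torch.Size([1, 8, 32000])
```

One note on the design: in `transformers`, `_auto_class` is normally a string such as `"AutoModelForCausalLM"` (set via `register_for_auto_class`) rather than `True`, and a Hub checkpoint with custom code would typically reference this class from an `auto_map` entry in `config.json` so it loads with `trust_remote_code=True`.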