avankumar committed on
Commit
9cc5b38
·
verified ·
1 Parent(s): b7469e3

Update modeling_phi3.py

Browse files
Files changed (1) hide show
  1. modeling_phi3.py +10 -65
modeling_phi3.py CHANGED
@@ -7,77 +7,15 @@ from transformers.modeling_utils import PreTrainedModel
7
  from .configuration_phi3 import Phi3Config
8
 
9
 
10
class Phi3Attention(nn.Module):
    """Multi-head self-attention for the simplified Phi-3 block.

    Projects the input into per-head queries/keys/values, applies scaled
    dot-product attention (optionally masked), then merges the heads back
    through an output projection. No dropout and no KV cache.
    """

    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.scale = self.head_dim ** -0.5

        # NOTE: submodules are created in q/k/v/out order; keep this order
        # so parameter initialization under a fixed seed is reproducible.
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, x, mask=None):
        """Attend over `x` of shape (batch, seq, hidden); `mask` entries equal
        to 0 are excluded from attention. Returns the same shape as `x`."""
        batch, seq_len, width = x.size()

        def split_heads(proj):
            # (B, T, C) -> (B, H, T, head_dim)
            return proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj)
        k = split_heads(self.k_proj)
        v = split_heads(self.v_proj)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        if mask is not None:
            # zero mask entries become -inf so softmax assigns them no weight
            scores = scores.masked_fill(mask == 0, float('-inf'))
        probs = torch.softmax(scores, dim=-1)

        context = torch.matmul(probs, v).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.out_proj(context)
36
-
37
-
38
class Phi3Block(nn.Module):
    """One pre-norm transformer layer: self-attention then an MLP, each
    wrapped in a residual connection around a normalized sub-layer."""

    def __init__(self, config):
        super().__init__()
        # NOTE(review): the config field is named `rms_norm_eps` but an
        # ordinary LayerNorm is used here — confirm this is intentional.
        self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attn = Phi3Attention(config)
        self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
        # feed-forward: expand to intermediate_size, GELU, project back
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size),
            nn.GELU(),
            nn.Linear(config.intermediate_size, config.hidden_size),
        )

    def forward(self, x, mask=None):
        """Apply attention and MLP sub-layers with residuals; `mask` is
        forwarded to the attention module unchanged."""
        x = x + self.attn(self.ln1(x), mask=mask)
        return x + self.mlp(self.ln2(x))
54
-
55
-
56
class Phi3Model(PreTrainedModel):
    """Decoder backbone: token embedding, a stack of Phi3Block layers, and
    a final LayerNorm. `forward` returns the last hidden states with shape
    (batch, seq, hidden_size)."""

    config_class = Phi3Config

    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.blocks = nn.ModuleList(
            Phi3Block(config) for _ in range(config.num_hidden_layers)
        )
        self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, input_ids, attention_mask=None):
        """Embed `input_ids`, run every block in order (passing the mask
        through), and normalize the final hidden states."""
        hidden = self.embed_tokens(input_ids)
        for layer in self.blocks:
            hidden = layer(hidden, attention_mask)
        return self.ln_f(hidden)
71
-
72
-
73
  class Phi3ForCausalLM(PreTrainedModel):
74
  config_class = Phi3Config
 
 
75
 
76
  def __init__(self, config):
77
  super().__init__(config)
78
  self.model = Phi3Model(config)
79
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
80
-
81
  self.post_init()
82
 
83
  def forward(self, input_ids, attention_mask=None, labels=None):
@@ -94,7 +32,14 @@ class Phi3ForCausalLM(PreTrainedModel):
94
  return CausalLMOutputWithPast(
95
  loss=loss,
96
  logits=logits,
97
- past_key_values=None,
98
  hidden_states=None,
99
  attentions=None,
100
  )
 
 
 
 
 
 
 
 
7
  from .configuration_phi3 import Phi3Config
8
 
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  class Phi3ForCausalLM(PreTrainedModel):
11
  config_class = Phi3Config
12
+ model_type = "phi3"
13
+ base_model_prefix = "model"
14
 
15
  def __init__(self, config):
16
  super().__init__(config)
17
  self.model = Phi3Model(config)
18
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
19
  self.post_init()
20
 
21
  def forward(self, input_ids, attention_mask=None, labels=None):
 
32
  return CausalLMOutputWithPast(
33
  loss=loss,
34
  logits=logits,
35
+ past_key_values=None, # Optional: implement later
36
  hidden_states=None,
37
  attentions=None,
38
  )
39
+
40
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
41
+ return {
42
+ "input_ids": input_ids,
43
+ "attention_mask": attention_mask,
44
+ "past_key_values": past_key_values,
45
+ }