avankumar committed on
Commit
f1a09b3
·
verified ·
1 Parent(s): 9cc5b38

Update modeling_phi3.py

Browse files
Files changed (1) hide show
  1. modeling_phi3.py +64 -1
modeling_phi3.py CHANGED
@@ -7,6 +7,69 @@ from transformers.modeling_utils import PreTrainedModel
7
  from .configuration_phi3 import Phi3Config
8
 
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  class Phi3ForCausalLM(PreTrainedModel):
11
  config_class = Phi3Config
12
  model_type = "phi3"
@@ -32,7 +95,7 @@ class Phi3ForCausalLM(PreTrainedModel):
32
  return CausalLMOutputWithPast(
33
  loss=loss,
34
  logits=logits,
35
- past_key_values=None, # Optional: implement later
36
  hidden_states=None,
37
  attentions=None,
38
  )
 
7
  from .configuration_phi3 import Phi3Config
8
 
9
 
10
class Phi3Attention(nn.Module):
    """Multi-head self-attention for the simplified Phi-3 blocks.

    Projects the input into per-head query/key/value tensors, applies
    scaled dot-product attention (optionally masked), then merges the
    heads back through an output projection.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        heads = config.num_attention_heads
        self.num_heads = heads
        self.head_dim = hidden // heads
        # 1/sqrt(head_dim) keeps logit variance independent of head size.
        self.scale = self.head_dim ** -0.5

        # Creation order (q, k, v, out) is kept stable so parameter
        # initialization is reproducible under a fixed RNG seed.
        self.q_proj = nn.Linear(hidden, hidden)
        self.k_proj = nn.Linear(hidden, hidden)
        self.v_proj = nn.Linear(hidden, hidden)
        self.out_proj = nn.Linear(hidden, hidden)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape (batch, seq, hidden).

        ``mask`` entries equal to 0 are excluded from attention; the mask
        must broadcast against (batch, heads, seq, seq).
        """
        batch, seq_len, hidden = x.size()

        def split_heads(t):
            # (B, T, C) -> (B, heads, T, head_dim)
            return t.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = torch.softmax(scores, dim=-1)

        merged = torch.matmul(weights, v).transpose(1, 2).contiguous().view(batch, seq_len, hidden)
        return self.out_proj(merged)
36
+
37
+
38
class Phi3Block(nn.Module):
    """One pre-norm transformer layer: attention then MLP, each with a residual add."""

    def __init__(self, config):
        super().__init__()
        # NOTE(review): uses LayerNorm with the config's rms_norm_eps; upstream
        # Phi-3 uses RMSNorm — confirm this simplification is intended.
        # Submodule creation order matches the original so seeded init is reproducible.
        self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attn = Phi3Attention(config)
        self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size),
            nn.GELU(),
            nn.Linear(config.intermediate_size, config.hidden_size),
        )

    def forward(self, x, mask=None):
        """Run the block on ``x`` (batch, seq, hidden); ``mask`` is forwarded to attention."""
        attn_out = self.attn(self.ln1(x), mask=mask)
        x = x + attn_out
        mlp_out = self.mlp(self.ln2(x))
        return x + mlp_out
54
+
55
+
56
class Phi3Model(PreTrainedModel):
    """Simplified Phi-3 backbone: token embedding, N transformer blocks, final norm.

    Fix: a causal (lower-triangular) mask is now always built and applied so a
    position can never attend to future tokens. Previously only the optional
    padding ``attention_mask`` was forwarded, giving fully bidirectional
    attention — incorrect for an autoregressive (causal) language model.
    """

    config_class = Phi3Config

    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.blocks = nn.ModuleList([Phi3Block(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, input_ids, attention_mask=None):
        """Encode ``input_ids`` (batch, seq) into hidden states (batch, seq, hidden).

        ``attention_mask`` (batch, seq), if given, marks padding with 0 and is
        combined with the causal mask before being passed to each block.
        """
        x = self.embed_tokens(input_ids)
        batch, seq_len = input_ids.shape

        # Causal mask: position i may attend only to positions <= i.
        causal = torch.tril(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=input_ids.device)
        )
        # Shape (1, 1, T, T) so it broadcasts over (batch, heads, T, T) scores.
        mask = causal.view(1, 1, seq_len, seq_len)
        if attention_mask is not None:
            # Padding mask removes key positions marked 0 ("do not attend").
            pad = attention_mask.to(torch.bool).view(batch, 1, 1, seq_len)
            mask = mask & pad

        for block in self.blocks:
            x = block(x, mask)
        x = self.ln_f(x)
        return x
71
+
72
+
73
  class Phi3ForCausalLM(PreTrainedModel):
74
  config_class = Phi3Config
75
  model_type = "phi3"
 
95
  return CausalLMOutputWithPast(
96
  loss=loss,
97
  logits=logits,
98
+ past_key_values=None, # Future: return actual cache if implemented
99
  hidden_states=None,
100
  attentions=None,
101
  )