AlexAISandro committed
Commit 4706a47 · verified · 1 Parent(s): e4c2335

Add custom modeling code

Files changed (1)
  1. modeling_nebula.py +12 -18
modeling_nebula.py CHANGED
@@ -12,17 +12,9 @@ class NebulaConfig(PretrainedConfig):
     def __init__(self, dim=1280, n_layers=14, n_heads=10, n_kv_heads=10, vocab_size=60729,
                  multiple_of=256, ffn_dim_multiplier=8/3, norm_eps=1e-5, max_seq_len=2048,
                  dropout=0.1, use_cache=True, **kwargs):
-        self.dim = dim
-        self.n_layers = n_layers
-        self.n_heads = n_heads
-        self.n_kv_heads = n_kv_heads
-        self.vocab_size = vocab_size
-        self.multiple_of = multiple_of
-        self.ffn_dim_multiplier = ffn_dim_multiplier
-        self.norm_eps = norm_eps
-        self.max_seq_len = max_seq_len
-        self.dropout = dropout
-        self.use_cache = use_cache
+        self.dim, self.n_layers, self.n_heads, self.n_kv_heads = dim, n_layers, n_heads, n_kv_heads
+        self.vocab_size, self.multiple_of, self.ffn_dim_multiplier = vocab_size, multiple_of, ffn_dim_multiplier
+        self.norm_eps, self.max_seq_len, self.dropout, self.use_cache = norm_eps, max_seq_len, dropout, use_cache
         super().__init__(**kwargs)
 
 class RMSNorm(nn.Module):
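For orientation, a quick sketch of the sizes these defaults imply, assuming the usual `head_dim = dim // n_heads` split that the linear-layer shapes further down suggest (`head_dim` and `n_rep` are illustrative names here, not attributes the config stores):

```python
# Sizes implied by the defaults above, assuming the standard per-head split
# (head_dim and n_rep are illustrative names, not NebulaConfig attributes).
dim, n_heads, n_kv_heads = 1280, 10, 10

head_dim = dim // n_heads        # 1280 // 10 = 128, per-head width behind wq/wk/wv/wo
n_rep = n_heads // n_kv_heads    # 10 // 10 = 1, so repeat_kv below early-returns

assert (head_dim, n_rep) == (128, 1)
```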
@@ -48,7 +40,7 @@ class RoPE(nn.Module):
         self.register_buffer('cos_cached', freqs.cos(), persistent=False)
         self.register_buffer('sin_cached', freqs.sin(), persistent=False)
     def forward(self, x: torch.Tensor, start_pos: int = 0):
-        seq_len = x.shape[-2]  # Use -2 for sequence length dimension
+        seq_len = x.shape[-2]
         cos = self.cos_cached[start_pos : start_pos + seq_len]
         sin = self.sin_cached[start_pos : start_pos + seq_len]
         x1 = x[..., : self.dim // 2]
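The slicing above is what makes the cached-RoPE path work for incremental decoding: on a decode step, `start_pos` skips past the positions already rotated into the cache. A standalone sketch, assuming `freqs` is the usual outer product of positions and inverse frequencies with base 10000 (the buffer registration suggests this, but the construction isn't shown in the hunk):

```python
import torch

max_seq_len, head_dim = 2048, 128  # defaults from the config hunk above
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(max_seq_len).float(), inv_freq)  # (2048, 64)
cos_cached, sin_cached = freqs.cos(), freqs.sin()

def apply_rope(x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
    # x: (bs, n_heads, seq_len, head_dim); indexing shape[-2] keeps this
    # correct regardless of how many leading dimensions x carries.
    seq_len = x.shape[-2]
    cos = cos_cached[start_pos : start_pos + seq_len]  # (seq_len, head_dim//2)
    sin = sin_cached[start_pos : start_pos + seq_len]
    x1, x2 = x[..., : head_dim // 2], x[..., head_dim // 2 :]
    # rotate-half rotation: (x1, x2) -> (x1*cos - x2*sin, x1*sin + x2*cos)
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

q = torch.randn(1, 10, 16, head_dim)
assert apply_rope(q, start_pos=5).shape == q.shape
```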
@@ -81,10 +73,10 @@ class Attention(nn.Module):
         self.wv = nn.Linear(config.dim, self.n_kv_heads * self.head_dim, bias=False)
         self.wo = nn.Linear(self.n_heads * self.head_dim, config.dim, bias=False)
         self.rope = RoPE(config)
-    def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
-        bs, n_kv_heads, seq_len, head_dim = x.shape
-        if n_rep == 1: return x
-        return x.unsqueeze(3).expand(bs, n_kv_heads, seq_len, n_rep, head_dim).reshape(bs, self.n_heads, seq_len, head_dim)
+    def repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
+        bs, n_kv_heads, seq_len_kv, head_dim = x.shape
+        if self.n_rep == 1: return x
+        return x.unsqueeze(2).expand(bs, n_kv_heads, self.n_rep, seq_len_kv, head_dim).reshape(bs, self.n_heads, seq_len_kv, head_dim)
     def forward(self, x: torch.Tensor, past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, attention_mask: Optional[torch.Tensor] = None):
         bs, seq_len_q, _ = x.shape
         start_pos = past_key_values[0].shape[2] if past_key_values is not None else 0
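For reference, `repeat_kv` is the standard grouped-query-attention expansion: each of the `n_kv_heads` key/value heads is tiled `n_rep` times so keys and values line up with the query heads. The repeat axis has to go in before the sequence axis so the final reshape maps query head `j` onto KV head `j // n_rep` rather than interleaving repeats with positions. With the shipped defaults (`n_heads == n_kv_heads == 10`, so `n_rep == 1`) the early return fires and none of this runs. A self-contained sketch with hypothetical shapes:

```python
import torch

# GQA expansion with illustrative sizes (not the shipped defaults, where
# n_rep == 1 and this path is skipped).
bs, n_kv_heads, n_rep, seq_len_kv, head_dim = 2, 2, 4, 7, 16
n_heads = n_kv_heads * n_rep  # 8 query heads sharing 2 KV heads

x = torch.randn(bs, n_kv_heads, seq_len_kv, head_dim)
out = (x.unsqueeze(2)                                     # (bs, kv, 1, seq, hd)
         .expand(bs, n_kv_heads, n_rep, seq_len_kv, head_dim)
         .reshape(bs, n_heads, seq_len_kv, head_dim))

# Query head j attends with KV head j // n_rep.
for j in range(n_heads):
    assert torch.equal(out[:, j], x[:, j // n_rep])
```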
@@ -99,7 +91,7 @@ class Attention(nn.Module):
         xk = torch.cat([past_k, xk], dim=2)
         xv = torch.cat([past_v, xv], dim=2)
         present_key_values = (xk, xv) if use_cache else None
-        xk_rep, xv_rep = self.repeat_kv(xk, self.n_rep), self.repeat_kv(xv, self.n_rep)
+        xk_rep, xv_rep = self.repeat_kv(xk), self.repeat_kv(xv)
         output = F.scaled_dot_product_attention(xq, xk_rep, xv_rep, attn_mask=attention_mask)
         output = output.transpose(1, 2).contiguous().view(bs, seq_len_q, -1)
         return self.wo(output), present_key_values
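A shape walk-through of a single cached decode step through this path, with illustrative sizes (`head_dim = 128` matches `dim // n_heads` for the defaults; since `n_rep == 1` the KV repeat is a no-op and is omitted):

```python
import torch
import torch.nn.functional as F

bs, n_heads, head_dim, t_past, t_new = 1, 10, 128, 9, 1

xq = torch.randn(bs, n_heads, t_new, head_dim)            # new-token query
past_k = torch.randn(bs, n_heads, t_past, head_dim)       # cached keys
past_v = torch.randn(bs, n_heads, t_past, head_dim)       # cached values
xk = torch.cat([past_k, torch.randn(bs, n_heads, t_new, head_dim)], dim=2)
xv = torch.cat([past_v, torch.randn(bs, n_heads, t_new, head_dim)], dim=2)

# One new query position attending over all t_past + t_new keys; no mask is
# needed here because a single decode token may see the entire cache.
out = F.scaled_dot_product_attention(xq, xk, xv)          # (bs, n_heads, 1, hd)
out = out.transpose(1, 2).contiguous().view(bs, t_new, -1)
assert out.shape == (bs, t_new, n_heads * head_dim)       # 10 * 128 = 1280 = dim
```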
@@ -139,8 +131,10 @@ class NebulaForCausalLM(PreTrainedModel, GenerationMixin):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         x = self.dropout(self.model.tok_embeddings(input_ids))
         present_key_values_list = [] if use_cache else None
+        if past_key_values is None:
+            past_key_values = tuple([None] * self.config.n_layers)
         for i, layer in enumerate(self.model.layers):
-            past_kv = past_key_values[i] if past_key_values is not None else None
+            past_kv = past_key_values[i]
             x, present_kv = layer(x, past_key_values=past_kv, use_cache=use_cache, attention_mask=attention_mask)
             if use_cache and present_key_values_list is not None:
                 present_key_values_list.append(present_kv)
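With the cache normalized to a tuple of `None`s up front, the loop can index `past_key_values[i]` unconditionally, and callers follow the usual prefill-then-decode pattern. A hedged sketch with a toy stand-in for the forward pass (the function name and the `(logits, cache)` return shape are illustrative, not this repo's exact API):

```python
import torch

def toy_forward(input_ids, past_key_values=None, use_cache=True):
    n_layers, vocab_size = 14, 60729                       # defaults from the config hunk
    if past_key_values is None:
        past_key_values = tuple([None] * n_layers)         # same normalization as above
    start = 0 if past_key_values[0] is None else past_key_values[0][0].shape[2]
    # ... real per-layer work elided; fake logits and a grown fake cache:
    logits = torch.randn(input_ids.shape[0], input_ids.shape[1], vocab_size)
    kv = torch.zeros(input_ids.shape[0], 10, start + input_ids.shape[1], 128)
    return logits, (tuple((kv, kv) for _ in range(n_layers)) if use_cache else None)

ids = torch.randint(0, 60729, (1, 8))                      # prefill the prompt once
logits, cache = toy_forward(ids)
for _ in range(3):                                         # then decode one token at a time
    next_id = logits[:, -1:].argmax(-1)
    logits, cache = toy_forward(next_id, past_key_values=cache)
```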
 
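Since the point of this commit is shipping custom modeling code alongside the checkpoint, loading it goes through transformers' remote-code path. A sketch with a placeholder repo id (the actual Hub id isn't shown on this page, and this assumes the repo's config.json maps `NebulaConfig`/`NebulaForCausalLM` via `auto_map`):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "AlexAISandro/nebula"  # placeholder -- substitute the real repo id

tokenizer = AutoTokenizer.from_pretrained(repo_id)
# trust_remote_code=True lets transformers import modeling_nebula.py from the
# repo instead of resolving to a built-in architecture.
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

# GenerationMixin is in the class bases above, so generate() should be available.
out = model.generate(**tokenizer("Hello", return_tensors="pt"), max_new_tokens=20)
print(tokenizer.decode(out[0]))
```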