NeuraCraft commited on
Commit
0a4eb39
·
1 Parent(s): bb8f99e

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. config.json +1 -1
  2. lance_ai_model.py +11 -0
config.json CHANGED
@@ -8,7 +8,7 @@
8
  "AutoModelForCausalLM": "lance_ai_model.LanceAI"
9
  },
10
  "bos_token_id": 151643,
11
- "dtype": "bfloat16",
12
  "eos_token_id": 151643,
13
  "head_dim": 64,
14
  "hidden_act": "silu",
 
8
  "AutoModelForCausalLM": "lance_ai_model.LanceAI"
9
  },
10
  "bos_token_id": 151643,
11
+ "dtype": "float32",
12
  "eos_token_id": 151643,
13
  "head_dim": 64,
14
  "hidden_act": "silu",
lance_ai_model.py CHANGED
@@ -60,6 +60,7 @@ class LanceAIConfig(PretrainedConfig):
60
  self.bos_token_id = bos_token_id
61
  self.eos_token_id = eos_token_id
62
 
 
63
  class LanceAIRMSNorm(nn.Module):
64
  def __init__(self, hidden_size, eps=1e-6):
65
  super().__init__()
@@ -73,6 +74,7 @@ class LanceAIRMSNorm(nn.Module):
73
  hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
74
  return self.weight * hidden_states.to(input_dtype)
75
 
 
76
  class LanceAIRotaryEmbedding(nn.Module):
77
  def __init__(self, config):
78
  super().__init__()
@@ -92,6 +94,7 @@ class LanceAIRotaryEmbedding(nn.Module):
92
  sin = emb.sin()
93
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
94
 
 
95
  def rotate_half(x):
96
  x1 = x[..., :x.shape[-1] // 2]
97
  x2 = x[..., x.shape[-1] // 2:]
@@ -113,6 +116,7 @@ def repeat_kv(hidden_states, n_rep):
113
  hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
114
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
115
 
 
116
  def eager_attention_forward(
117
  module: nn.Module,
118
  query: torch.Tensor,
@@ -138,6 +142,7 @@ def eager_attention_forward(
138
 
139
  return attn_output, attn_weights
140
 
 
141
  class LanceAIAttention(nn.Module):
142
  def __init__(self, config, layer_idx):
143
  super().__init__()
@@ -200,6 +205,7 @@ class LanceAIAttention(nn.Module):
200
  attn_output = self.o_proj(attn_output)
201
  return attn_output, attn_weights
202
 
 
203
  class LanceAIMLP(nn.Module):
204
  def __init__(self, config):
205
  super().__init__()
@@ -212,6 +218,7 @@ class LanceAIMLP(nn.Module):
212
  def forward(self, x):
213
  return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
214
 
 
215
  class LanceAIDecoderLayer(GradientCheckpointingLayer):
216
  def __init__(self, config, layer_idx):
217
  super().__init__()
@@ -249,6 +256,7 @@ class LanceAIDecoderLayer(GradientCheckpointingLayer):
249
 
250
  return hidden_states
251
 
 
252
  class LanceAIPreTrainedModel(PreTrainedModel):
253
  config_class = LanceAIConfig
254
  base_model_prefix = "model"
@@ -260,6 +268,7 @@ class LanceAIPreTrainedModel(PreTrainedModel):
260
  _supports_flex_attn = True
261
  _can_compile_fullgraph = True
262
 
 
263
  class LanceAIModel(LanceAIPreTrainedModel):
264
  def __init__(self, config):
265
  super().__init__(config)
@@ -335,6 +344,7 @@ class LanceAIModel(LanceAIPreTrainedModel):
335
  mask = torch.triu(mask, diagonal=1 + past_len)
336
  return mask[None, None, :, :]
337
 
 
338
  class LanceAI(LanceAIPreTrainedModel, GenerationMixin):
339
  _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
340
 
@@ -404,6 +414,7 @@ class LanceAI(LanceAIPreTrainedModel, GenerationMixin):
404
  reordered.append((layer_k.index_select(0, beam_idx), layer_v.index_select(0, beam_idx)))
405
  return reordered
406
 
 
407
  CONFIG_MAPPING.register("lance_ai", LanceAIConfig)
408
  MODEL_FOR_CAUSAL_LM_MAPPING.register(LanceAIConfig, LanceAI)
409
  LanceAIConfig.register_for_auto_class("AutoConfig")
 
60
  self.bos_token_id = bos_token_id
61
  self.eos_token_id = eos_token_id
62
 
63
+
64
  class LanceAIRMSNorm(nn.Module):
65
  def __init__(self, hidden_size, eps=1e-6):
66
  super().__init__()
 
74
  hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
75
  return self.weight * hidden_states.to(input_dtype)
76
 
77
+
78
  class LanceAIRotaryEmbedding(nn.Module):
79
  def __init__(self, config):
80
  super().__init__()
 
94
  sin = emb.sin()
95
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
96
 
97
+
98
  def rotate_half(x):
99
  x1 = x[..., :x.shape[-1] // 2]
100
  x2 = x[..., x.shape[-1] // 2:]
 
116
  hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
117
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
118
 
119
+
120
  def eager_attention_forward(
121
  module: nn.Module,
122
  query: torch.Tensor,
 
142
 
143
  return attn_output, attn_weights
144
 
145
+
146
  class LanceAIAttention(nn.Module):
147
  def __init__(self, config, layer_idx):
148
  super().__init__()
 
205
  attn_output = self.o_proj(attn_output)
206
  return attn_output, attn_weights
207
 
208
+
209
  class LanceAIMLP(nn.Module):
210
  def __init__(self, config):
211
  super().__init__()
 
218
  def forward(self, x):
219
  return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
220
 
221
+
222
  class LanceAIDecoderLayer(GradientCheckpointingLayer):
223
  def __init__(self, config, layer_idx):
224
  super().__init__()
 
256
 
257
  return hidden_states
258
 
259
+
260
  class LanceAIPreTrainedModel(PreTrainedModel):
261
  config_class = LanceAIConfig
262
  base_model_prefix = "model"
 
268
  _supports_flex_attn = True
269
  _can_compile_fullgraph = True
270
 
271
+
272
  class LanceAIModel(LanceAIPreTrainedModel):
273
  def __init__(self, config):
274
  super().__init__(config)
 
344
  mask = torch.triu(mask, diagonal=1 + past_len)
345
  return mask[None, None, :, :]
346
 
347
+
348
  class LanceAI(LanceAIPreTrainedModel, GenerationMixin):
349
  _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
350
 
 
414
  reordered.append((layer_k.index_select(0, beam_idx), layer_v.index_select(0, beam_idx)))
415
  return reordered
416
 
417
+
418
  CONFIG_MAPPING.register("lance_ai", LanceAIConfig)
419
  MODEL_FOR_CAUSAL_LM_MAPPING.register(LanceAIConfig, LanceAI)
420
  LanceAIConfig.register_for_auto_class("AutoConfig")