nev8r commited on
Commit
5d054fe
·
verified ·
1 Parent(s): 2c01a11

Upload VerMind model

Browse files
chat_template.jinja ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
{#- VerMind chat template (ChatML / Qwen style).
    Inputs: `messages`, optional `tools`, `add_generation_prompt`, `enable_thinking`.
    Each message renders as "<|im_start|>{role}\n{content}<|im_end|>\n". -#}
{%- if tools %}
    {#- Tool mode: emit the system text (if any) followed by the function
        signatures and the <tool_call> response-format instructions. -#}
    {{- '<|im_start|>system\n' }}
    {%- if messages[0].role == 'system' %}
        {{- messages[0].content + '\n\n' }}
    {%- endif %}
    {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
    {%- for tool in tools %}
        {{- "\n" }}
        {{- tool | tojson }}
    {%- endfor %}
    {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
    {#- No tools: plain system message, or a default assistant persona. -#}
    {%- if messages[0]['role'] == 'system' -%}
        {{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}
    {%- else -%}
        {{- '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}
    {%- endif %}
{%- endif %}
{#- Scan messages backwards for the last real user query, skipping
    <tool_response> payloads relayed through the user role.
    NOTE(review): ns.last_query_index is computed but never read below —
    likely vestigial from a template variant that strips old think blocks. -#}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{%- for message in messages[::-1] %}
    {%- set index = (messages|length - 1) - loop.index0 %}
    {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
        {%- set ns.multi_step_tool = false %}
        {%- set ns.last_query_index = index %}
    {%- endif %}
{%- endfor %}
{#- Main rendering loop. Non-string content collapses to the empty string. -#}
{%- for message in messages %}
    {%- if message.content is string %}
        {%- set content = message.content %}
    {%- else %}
        {%- set content = '' %}
    {%- endif %}
    {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
        {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
    {%- elif message.role == "assistant" %}
        {{- '<|im_start|>' + message.role + '\n' + content }}
        {%- if message.tool_calls %}
            {#- Serialize each tool call as a <tool_call> JSON object,
                unwrapping OpenAI-style {"function": {...}} entries. -#}
            {%- for tool_call in message.tool_calls %}
                {%- if (loop.first and content) or (not loop.first) %}
                    {{- '\n' }}
                {%- endif %}
                {%- if tool_call.function %}
                    {%- set tool_call = tool_call.function %}
                {%- endif %}
                {{- '<tool_call>\n{"name": "' }}
                {{- tool_call.name }}
                {{- '", "arguments": ' }}
                {%- if tool_call.arguments is string %}
                    {{- tool_call.arguments }}
                {%- else %}
                    {{- tool_call.arguments | tojson }}
                {%- endif %}
                {{- '}\n</tool_call>' }}
            {%- endfor %}
        {%- endif %}
        {{- '<|im_end|>\n' }}
    {%- elif message.role == "tool" %}
        {#- Consecutive tool results are folded into a single user turn. -#}
        {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
            {{- '<|im_start|>user' }}
        {%- endif %}
        {{- '\n<tool_response>\n' }}
        {{- content }}
        {{- '\n</tool_response>' }}
        {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
            {{- '<|im_end|>\n' }}
        {%- endif %}
    {%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '<|im_start|>assistant\n' }}
    {#- enable_thinking=false pre-fills an empty think block so the model
        skips chain-of-thought. -#}
    {%- if enable_thinking is defined and enable_thinking is false %}
        {{- '<think>\n\n</think>\n\n' }}
    {%- endif %}
{%- endif %}
config.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "VerMindForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_vermind.VerMindConfig",
7
+ "AutoModelForCausalLM": "modeling_vermind.VerMindForCausalLM"
8
+ },
9
+ "aux_loss_alpha": 0.01,
10
+ "bos_token_id": 1,
11
+ "dropout": 0.0,
12
+ "dtype": "float32",
13
+ "eos_token_id": 2,
14
+ "flash_attn": true,
15
+ "hidden_act": "silu",
16
+ "hidden_size": 768,
17
+ "inference_rope_scaling": false,
18
+ "intermediate_size": 2048,
19
+ "max_position_embeddings": 32768,
20
+ "model_type": "vermind",
21
+ "n_routed_experts": 4,
22
+ "n_shared_experts": 1,
23
+ "norm_topk_prob": true,
24
+ "num_attention_heads": 8,
25
+ "num_experts_per_tok": 2,
26
+ "num_hidden_layers": 16,
27
+ "num_key_value_heads": 2,
28
+ "rms_norm_eps": 1e-05,
29
+ "rope_scaling": null,
30
+ "rope_theta": 1000000.0,
31
+ "scoring_func": "softmax",
32
+ "seq_aux": true,
33
+ "transformers_version": "4.57.6",
34
+ "use_moe": false,
35
+ "vocab_size": 6400
36
+ }
configuration_vermind.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ """
3
+ Configuration file for VerMind model - Standalone Version
4
+ """
5
+
6
+ from transformers import PretrainedConfig, AutoConfig
7
+
8
+
9
class VerMindConfig(PretrainedConfig):
    """Configuration class for the VerMind model.

    Holds the architecture hyper-parameters consumed by
    ``modeling_vermind.VerMindForCausalLM``. The defaults mirror the shipped
    ``config.json`` (768-dim, 16 layers, GQA with 2 KV heads, 6400 vocab).
    """

    # Key under which this config is registered with AutoConfig/AutoModel.
    model_type = "vermind"

    def __init__(
        self,
        dropout: float = 0.0,                     # dropout prob shared by embeddings / attention / FFN
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        hidden_act: str = 'silu',                 # activation key looked up in transformers' ACT2FN
        hidden_size: int = 768,
        intermediate_size: "int | None" = None,   # None -> derived in FeedForward (8/3 * hidden, 64-aligned)
        max_position_embeddings: int = 32768,
        num_attention_heads: int = 8,
        num_hidden_layers: int = 16,
        num_key_value_heads: int = 2,             # < num_attention_heads => grouped-query attention
        vocab_size: int = 6400,
        rms_norm_eps: float = 1e-05,
        rope_theta: float = 1000000.0,            # RoPE base frequency
        inference_rope_scaling: bool = False,     # True -> enable the fixed YaRN dict built below
        flash_attn: bool = True,                  # allow the F.scaled_dot_product_attention fast path
        use_moe: bool = False,
        num_experts_per_tok: int = 2,
        n_routed_experts: int = 4,
        n_shared_experts: int = 1,
        scoring_func: str = 'softmax',
        aux_loss_alpha: float = 0.01,
        seq_aux: bool = True,
        norm_topk_prob: bool = True,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.dropout = dropout
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.hidden_act = hidden_act
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.num_key_value_heads = num_key_value_heads
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.inference_rope_scaling = inference_rope_scaling

        # NOTE(review): rope_scaling is always rebuilt from inference_rope_scaling,
        # so any rope_scaling value arriving via **kwargs (e.g. from config.json)
        # is overwritten here — confirm this is intended.
        self.rope_scaling = {
            "beta_fast": 32,
            "beta_slow": 1,
            "factor": 16,
            "original_max_position_embeddings": 2048,
            "attention_factor": 1.0,
            "type": "yarn"
        } if self.inference_rope_scaling else None
        self.flash_attn = flash_attn

        # Mixture-of-experts settings. Stored for completeness; the shipped
        # modeling_vermind.py contains no MoE modules, so they are unused there.
        self.use_moe = use_moe
        self.num_experts_per_tok = num_experts_per_tok
        self.n_routed_experts = n_routed_experts
        self.n_shared_experts = n_shared_experts
        self.scoring_func = scoring_func
        self.aux_loss_alpha = aux_loss_alpha
        self.seq_aux = seq_aux
        self.norm_topk_prob = norm_topk_prob
74
+
75
+
76
# Register the config class so AutoConfig.from_pretrained(...,
# trust_remote_code=True) can resolve model_type "vermind" to this class.
AutoConfig.register("vermind", VerMindConfig)

__all__ = ["VerMindConfig"]
generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "transformers_version": "4.57.6"
6
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea7ce1490895e993d593287156c563db5fea450f9924aeac2ae0cf4843dfb5e4
3
+ size 435801008
modeling_vermind.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ """
3
+ Model file for VerMind model - Standalone Version
4
+ Contains complete implementation without external dependencies
5
+ """
6
+
7
+ import math
8
+ from typing import Optional, Tuple, List, Union
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from transformers import PreTrainedModel, GenerationMixin, AutoModelForCausalLM
14
+ from transformers.activations import ACT2FN
15
+ from transformers.modeling_outputs import CausalLMOutputWithPast
16
+
17
+ from .configuration_vermind import VerMindConfig
18
+
19
+
20
+ # ==================== Base Module Functions ====================
21
+
22
def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6,
                         rope_scaling: Optional[dict] = None):
    """Precompute the cos/sin tables used by rotary position embeddings.

    Args:
        dim: per-head dimension; frequencies are generated for dim // 2 bands
             and then duplicated so each table has `dim` columns.
        end: number of positions to tabulate.
        rope_base: RoPE base frequency (theta).
        rope_scaling: optional YaRN-style dict with keys
            original_max_position_embeddings / factor / beta_fast / beta_slow /
            attention_factor; when given and end exceeds the original window,
            low-frequency bands are interpolated by `factor` with a linear ramp
            between the beta_fast/beta_slow correction dims.

    Returns:
        (freqs_cos, freqs_sin): two float tensors of shape (end, dim), each
        scaled by the YaRN attention factor (1.0 when no scaling is applied).
    """
    half_dim = dim // 2
    inv_freq = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[:half_dim].float() / dim))
    scale = 1.0

    if rope_scaling is not None:
        orig_max = rope_scaling.get("original_max_position_embeddings", 2048)
        factor = rope_scaling.get("factor", 16)
        beta_fast = rope_scaling.get("beta_fast", 32.0)
        beta_slow = rope_scaling.get("beta_slow", 1.0)
        scale = rope_scaling.get("attention_factor", 1.0)

        # Only rescale when we actually extend past the original window.
        if end / orig_max > 1.0:
            def correction_dim(num_rotations):
                # Dimension index whose wavelength equals orig_max / num_rotations.
                return (dim * math.log(orig_max / (num_rotations * 2 * math.pi))) / (2 * math.log(rope_base))

            low = max(math.floor(correction_dim(beta_fast)), 0)
            high = min(math.ceil(correction_dim(beta_slow)), half_dim - 1)
            band = torch.arange(half_dim, device=inv_freq.device).float()
            # 0 below `low` (keep extrapolation), 1 above `high` (full interpolation).
            ramp = torch.clamp((band - low) / max(high - low, 0.001), 0, 1)
            inv_freq = inv_freq * (1 - ramp + ramp / factor)

    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq).float()
    freqs_cos = torch.cat([torch.cos(angles), torch.cos(angles)], dim=-1) * scale
    freqs_sin = torch.cat([torch.sin(angles), torch.sin(angles)], dim=-1) * scale
    return freqs_cos, freqs_sin
45
+
46
+
47
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embeddings to query/key tensors.

    q, k are (batch, seq, heads, head_dim); cos/sin are position tables of
    shape (positions, head_dim). When position_ids is None the first seq_len
    rows of the tables are used; a 1-D position_ids indexes rows shared across
    the batch; a 2-D position_ids selects per-sample rows. Results are cast
    back to the inputs' original dtype. (`unsqueeze_dim` is accepted for
    signature compatibility but unused.)
    """
    orig_dtype = q.dtype  # remember so we can cast back after float math

    def _rotate_half(t):
        half = t.shape[-1] // 2
        return torch.cat((-t[..., half:], t[..., :half]), dim=-1)

    # Pick the cos/sin rows and reshape to broadcast over (batch, seq, heads, dim).
    if position_ids is None:
        seq_len = q.shape[1]
        cos_sel = cos[:seq_len].unsqueeze(0).unsqueeze(2)
        sin_sel = sin[:seq_len].unsqueeze(0).unsqueeze(2)
    elif position_ids.dim() == 1:
        cos_sel = cos[position_ids].unsqueeze(0).unsqueeze(2)
        sin_sel = sin[position_ids].unsqueeze(0).unsqueeze(2)
    else:
        cos_sel = cos[position_ids].unsqueeze(2)
        sin_sel = sin[position_ids].unsqueeze(2)

    q_embed = q * cos_sel + _rotate_half(q) * sin_sel
    k_embed = k * cos_sel + _rotate_half(k) * sin_sel
    return q_embed.to(orig_dtype), k_embed.to(orig_dtype)
83
+
84
+
85
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head n_rep times along the head axis (GQA).

    x is (batch, seq, n_kv_heads, head_dim); the result is
    (batch, seq, n_kv_heads * n_rep, head_dim) with every head duplicated
    consecutively (h0, h0, ..., h1, h1, ...). Returns x unchanged when
    n_rep == 1.
    """
    if n_rep == 1:
        return x
    # Equivalent to the expand+reshape trick: consecutive duplication per head.
    return torch.repeat_interleave(x, n_rep, dim=2)
93
+
94
+
95
+ # ==================== Module Classes ====================
96
+
97
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean centering, no bias).

    Normalizes the last dimension by its RMS (computed in float32 for
    stability) and applies a learned per-channel gain.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps  # numerical floor inside the rsqrt
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # x / rms(x) along the last dim, written via rsqrt of the mean square.
        mean_sq = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)

    def forward(self, x):
        normalized = self._norm(x.float()).type_as(x)  # float32 math, original dtype out
        return self.weight * normalized
109
+
110
+
111
class FeedForward(nn.Module):
    """SwiGLU feed-forward block: down(act(gate(x)) * up(x)), then dropout.

    When config.intermediate_size is None it is derived as 8/3 * hidden_size
    rounded up to a multiple of 64, and written back onto the config.
    """

    def __init__(self, config: VerMindConfig):
        super().__init__()
        if config.intermediate_size is None:
            # Default hidden width: 8/3 of hidden_size, 64-aligned (mutates config).
            raw = int(config.hidden_size * 8 / 3)
            config.intermediate_size = 64 * ((raw + 64 - 1) // 64)
        hidden, inner = config.hidden_size, config.intermediate_size
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.dropout = nn.Dropout(config.dropout)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.dropout(self.down_proj(gated))
126
+
127
+
128
class Attention(nn.Module):
    """Grouped-query attention (GQA) with rotary position embeddings.

    Queries use ``num_attention_heads`` heads; keys/values use
    ``num_key_value_heads`` heads, each shared across ``n_rep`` query heads.
    Uses ``F.scaled_dot_product_attention`` on the prefill path when enabled,
    otherwise an explicit masked-softmax fallback.
    """
    def __init__(self, args: VerMindConfig):
        super().__init__()
        # Fall back to full multi-head attention when no KV-head count is given.
        self.num_key_value_heads = args.num_attention_heads if args.num_key_value_heads is None else args.num_key_value_heads
        assert args.num_attention_heads % self.num_key_value_heads == 0
        self.n_local_heads = args.num_attention_heads
        self.n_local_kv_heads = self.num_key_value_heads
        # How many query heads share each KV head.
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.hidden_size // args.num_attention_heads
        self.q_proj = nn.Linear(args.hidden_size, args.num_attention_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(args.num_attention_heads * self.head_dim, args.hidden_size, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        # SDPA fast path only if torch provides it AND the config allows it.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn

    def forward(self, x, position_embeddings, past_key_value=None, use_cache=False,
                attention_mask=None, position_ids=None, cu_seqlens=None):
        """Run attention over x (batch, seq, hidden).

        position_embeddings is the (cos, sin) table pair; past_key_value is an
        optional (k, v) tuple concatenated along the sequence axis. Returns
        (output, past_kv) where past_kv is (k, v) when use_cache else None.
        NOTE(review): cu_seqlens is accepted but never used here.
        """
        bsz, seq_len, _ = x.shape
        # Cast the input to the projection weights' dtype (the dtype the model was loaded in).
        weight_dtype = self.q_proj.weight.dtype
        if x.dtype != weight_dtype:
            x = x.to(weight_dtype)
        xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        # Force a uniform dtype across projections (guards against mismatched proj dtypes).
        xq = xq.to(weight_dtype)
        xk = xk.to(weight_dtype)
        xv = xv.to(weight_dtype)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)

        cos, sin = position_embeddings
        xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin, position_ids=position_ids)

        # Prepend cached keys/values along the sequence dimension.
        if past_key_value is not None:
            xk = torch.cat([past_key_value[0], xk], dim=1)
            xv = torch.cat([past_key_value[1], xv], dim=1)
        past_kv = (xk, xv) if use_cache else None

        # Expand KV heads to match query heads, then go to (batch, heads, seq, dim).
        xq, xk, xv = xq.transpose(1, 2), repeat_kv(xk, self.n_rep).transpose(1, 2), repeat_kv(xv, self.n_rep).transpose(1, 2)

        # NOTE(review): despite the name, is_2d_mask flags a 3-D mask of shape
        # (bsz, q_len, k_len); a plain padding mask is expected to be 2-D (bsz, k_len).
        is_2d_mask = attention_mask is not None and attention_mask.dim() == 3
        attn_mask_for_flash = None
        use_flash = False

        # SDPA is only used for multi-token prefill without a cache; any
        # non-trivial mask forces the explicit fallback path.
        if self.flash and (seq_len > 1) and (past_key_value is None):
            if attention_mask is None:
                use_flash = True
                attn_mask_for_flash = None
            elif is_2d_mask:
                use_flash = False
            elif torch.all(attention_mask == 1):
                # All-ones padding mask is equivalent to no mask.
                use_flash = True
                attn_mask_for_flash = None
            else:
                use_flash = False

        if use_flash:
            if attn_mask_for_flash is not None:
                # Currently unreachable (attn_mask_for_flash is always None above);
                # kept for the explicit-mask SDPA variant.
                output = F.scaled_dot_product_attention(
                    xq, xk, xv,
                    attn_mask=attn_mask_for_flash,
                    dropout_p=self.dropout if self.training else 0.0,
                    is_causal=False
                )
            else:
                output = F.scaled_dot_product_attention(
                    xq, xk, xv,
                    dropout_p=self.dropout if self.training else 0.0,
                    is_causal=True
                )
        else:
            # Explicit attention: scores are (batch, heads, q_len, k_len).
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            if not is_2d_mask:
                # Causal mask over the trailing seq_len keys (the new tokens);
                # cached positions remain fully visible.
                scores[:, :, :, -seq_len:] += torch.triu(torch.full((seq_len, seq_len), float("-inf"), device=scores.device), diagonal=1)
            if attention_mask is not None:
                if is_2d_mask:
                    # Collapse a 3-D mask to its first query row — assumes all
                    # rows agree (i.e. it encodes padding); TODO confirm callers.
                    attention_mask = attention_mask[:, 0, :] if attention_mask.dim() == 3 else attention_mask
                # Additive mask: 0 where visible, -1e9 where padded.
                extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
                extended_attention_mask = (1.0 - extended_attention_mask.float()) * -1e9
                scores = scores + extended_attention_mask
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv

        # Merge heads back to (batch, seq, hidden) and project out.
        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.o_proj(output))
        return output, past_kv
220
+
221
+
222
+ # ==================== Main Model Classes ====================
223
+
224
class VerMindBlock(nn.Module):
    """Pre-norm transformer decoder block: GQA attention then a SwiGLU MLP,
    each wrapped in a residual connection."""

    def __init__(self, layer_id: int, config: VerMindConfig):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.self_attn = Attention(config)
        self.layer_id = layer_id  # index of this block within the stack
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = FeedForward(config)

    def forward(self, hidden_states, position_embeddings, past_key_value=None, use_cache=False,
                attention_mask=None, position_ids=None, cu_seqlens=None):
        # Attention sub-layer on the pre-normed input, residual added after.
        attn_out, present_key_value = self.self_attn(
            self.input_layernorm(hidden_states),
            position_embeddings,
            past_key_value,
            use_cache,
            attention_mask,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens
        )
        hidden_states = hidden_states + attn_out
        # MLP sub-layer, also pre-normed with its own residual.
        hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states, present_key_value
252
+
253
+
254
class VerMindModel(nn.Module):
    """VerMind transformer backbone: token embedding, a stack of
    VerMindBlock decoder layers, and a final RMSNorm.

    RoPE cos/sin tables are precomputed once for the full
    max_position_embeddings window and held as non-persistent buffers.
    """
    def __init__(self, config: VerMindConfig):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.num_hidden_layers = config.num_hidden_layers
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)
        self.layers = nn.ModuleList([VerMindBlock(l, config) for l in range(self.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        freqs_cos, freqs_sin = precompute_freqs_cis(
            dim=config.hidden_size // config.num_attention_heads,
            end=config.max_position_embeddings,
            rope_base=config.rope_theta,
            rope_scaling=config.rope_scaling
        )
        # persistent=False keeps the tables out of the state dict / checkpoint.
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(self, input_ids=None, attention_mask=None, past_key_values=None,
                use_cache=False, position_ids=None, cu_seqlens=None, **kwargs):
        """Return (hidden_states, presents, aux_loss).

        past_key_values is a per-layer list of (k, v) tuples (legacy format);
        presents is the updated list (all None when use_cache is False).
        aux_loss is always 0 in this non-MoE implementation.
        """
        # NOTE(review): transformers Cache objects (which expose .layers) are
        # silently discarded — only legacy per-layer tuples are supported, so
        # generation with a new-style Cache recomputes from scratch.
        if past_key_values is not None and hasattr(past_key_values, 'layers'):
            past_key_values = None
        past_key_values = past_key_values or [None] * len(self.layers)
        # NOTE(review): start_pos is computed but never used; correct cached
        # decoding relies on the caller supplying position_ids, because
        # apply_rotary_pos_emb otherwise slices the tables from position 0.
        start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0

        hidden_states = self.dropout(self.embed_tokens(input_ids))
        # Full cos/sin tables; per-position selection happens inside attention.
        position_embeddings = (self.freqs_cos, self.freqs_sin)

        presents = []
        for layer_idx, (layer, past_key_value) in enumerate(zip(self.layers, past_key_values)):
            hidden_states, present = layer(
                hidden_states,
                position_embeddings,
                past_key_value=past_key_value,
                use_cache=use_cache,
                attention_mask=attention_mask,
                position_ids=position_ids,
                cu_seqlens=cu_seqlens
            )
            presents.append(present)

        hidden_states = self.norm(hidden_states)
        aux_loss = 0  # placeholder: no MoE routing loss in this variant
        return hidden_states, presents, aux_loss
301
+
302
+
303
class VerMindForCausalLM(PreTrainedModel, GenerationMixin):
    """VerMind causal language model: the VerMindModel backbone plus a tied
    LM head, exposed through the Hugging Face PreTrainedModel/GenerationMixin
    interface."""
    config_class = VerMindConfig

    def __init__(self, config: Optional[VerMindConfig] = None):
        # Assigned before super().__init__ so a default config exists even
        # when the model is constructed without one.
        self.config = config or VerMindConfig()
        super().__init__(self.config)
        self.model = VerMindModel(self.config)
        self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
        # Weight tying: the embedding table shares storage with the LM head.
        self.model.embed_tokens.weight = self.lm_head.weight

    def forward(self, input_ids=None, attention_mask=None, labels=None,
                past_key_values=None, use_cache=False, logits_to_keep=0,
                position_ids=None, cu_seqlens=None, **args):
        """Compute logits (and loss when labels are given).

        logits_to_keep: number of trailing positions to compute logits for;
        0 means all positions (slice(-0, None) == slice(0, None)). Ignored
        when cu_seqlens is set. Returns CausalLMOutputWithPast with an extra
        `aux_loss` attribute (always 0 here).
        """
        hidden_states, past_key_values, aux_loss = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            **args
        )

        is_varlen = cu_seqlens is not None
        if is_varlen:
            logits = self.lm_head(hidden_states)
        else:
            slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
            logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            if is_varlen:
                # NOTE(review): this path indexes logits[:-1] on the FIRST axis,
                # i.e. it assumes packed 2-D (total_tokens, vocab) logits /
                # 1-D labels — confirm against the varlen training caller.
                shift_logits = logits[:-1, :].contiguous()
                shift_labels = labels[1:].contiguous()
                loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)
            else:
                # Standard next-token shift over the sequence axis.
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=-100)

        output = CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=past_key_values, hidden_states=hidden_states)
        output.aux_loss = aux_loss  # attached dynamically; not a standard field
        return output
348
+
349
+
350
# Register the model class so AutoModelForCausalLM.from_pretrained(...,
# trust_remote_code=True) can instantiate it from the "vermind" config.
AutoModelForCausalLM.register(VerMindForCausalLM.config_class, VerMindForCausalLM)

__all__ = ["VerMindForCausalLM", "VerMindModel", "VerMindBlock", "Attention", "FeedForward", "RMSNorm"]
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|im_start|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|im_end|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<|endoftext|>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<|endoftext|>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<|im_start|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "2": {
23
+ "content": "<|im_end|>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": true
29
+ }
30
+ },
31
+ "additional_special_tokens": [],
32
+ "bos_token": "<|im_start|>",
33
+ "clean_up_tokenization_spaces": false,
34
+ "eos_token": "<|im_end|>",
35
+ "extra_special_tokens": {},
36
+ "legacy": true,
37
+ "model_max_length": 32768,
38
+ "pad_token": "<|endoftext|>",
39
+ "sp_model_kwargs": {},
40
+ "spaces_between_special_tokens": false,
41
+ "tokenizer_class": "PreTrainedTokenizerFast",
42
+ "unk_token": "<|endoftext|>"
43
+ }