ThingsAI commited on
Commit
f950097
·
verified ·
1 Parent(s): ffe51ed

Upload modeling_quark.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_quark.py +103 -0
modeling_quark.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, torch.nn as nn, torch.nn.functional as F
2
+ from typing import Optional
3
+ from transformers import PreTrainedModel
4
+ from transformers.modeling_outputs import CausalLMOutputWithPast
5
+ from .configuration_quark import QuarkConfig
6
+
7
+ class QuarkRMSNorm(nn.Module):
8
+ def __init__(self, dim, eps=1e-5):
9
+ super().__init__(); self.eps=eps; self.scale=nn.Parameter(torch.ones(dim))
10
+ def forward(self, x):
11
+ return (x.float()*(x.float().pow(2).mean(-1,keepdim=True).add(self.eps).rsqrt())).to(x.dtype)*self.scale
12
+
13
+ class QuarkRoPE(nn.Module):
14
+ def __init__(self, hd, ml, th=10000.):
15
+ super().__init__()
16
+ self.register_buffer("inv",1./(th**(torch.arange(0,hd,2).float()/hd)),persistent=False); self._b(ml)
17
+ def _b(self, sl):
18
+ f=torch.outer(torch.arange(sl,device=self.inv.device).float(),self.inv); e=torch.cat([f,f],-1)
19
+ self.register_buffer("cos_c",e.cos()[None,None],persistent=False); self.register_buffer("sin_c",e.sin()[None,None],persistent=False); self._m=sl
20
+ @staticmethod
21
+ def _r(x): a,b=x.chunk(2,-1); return torch.cat([-b,a],-1)
22
+ def forward(self, q, k):
23
+ T=q.size(2)
24
+ if T>self._m: self._b(T)
25
+ c,s=self.cos_c[:,:,:T,:],self.sin_c[:,:,:T,:]
26
+ return q*c+self._r(q)*s, k*c+self._r(k)*s
27
+
28
+ class QuarkAttention(nn.Module):
29
+ def __init__(self, cfg):
30
+ super().__init__()
31
+ self.nh,self.nkv,self.ng,self.hd=cfg.n_heads,cfg.n_kv_heads,cfg.n_heads//cfg.n_kv_heads,cfg.head_dim
32
+ self.q_proj=nn.Linear(cfg.d_model,cfg.n_heads*cfg.head_dim,bias=cfg.qkv_bias)
33
+ self.k_proj=nn.Linear(cfg.d_model,cfg.n_kv_heads*cfg.head_dim,bias=cfg.qkv_bias)
34
+ self.v_proj=nn.Linear(cfg.d_model,cfg.n_kv_heads*cfg.head_dim,bias=cfg.qkv_bias)
35
+ self.o_proj=nn.Linear(cfg.n_heads*cfg.head_dim,cfg.d_model,bias=False)
36
+ self.rope=QuarkRoPE(cfg.head_dim,cfg.max_seq_len,cfg.rope_theta)
37
+ def forward(self, x):
38
+ B,T,_=x.shape
39
+ q=self.q_proj(x).view(B,T,self.nh,self.hd).transpose(1,2)
40
+ k=self.k_proj(x).view(B,T,self.nkv,self.hd).transpose(1,2)
41
+ v=self.v_proj(x).view(B,T,self.nkv,self.hd).transpose(1,2)
42
+ q,k=self.rope(q,k); q,k=q.to(v.dtype),k.to(v.dtype)
43
+ if self.ng>1: k=k.repeat_interleave(self.ng,1); v=v.repeat_interleave(self.ng,1)
44
+ return self.o_proj(F.scaled_dot_product_attention(q,k,v,is_causal=True).transpose(1,2).contiguous().view(B,T,-1))
45
+
46
+ class QuarkFFN(nn.Module):
47
+ def __init__(self, cfg):
48
+ super().__init__()
49
+ self.gate_proj=nn.Linear(cfg.d_model,cfg.d_ff,bias=False)
50
+ self.up_proj=nn.Linear(cfg.d_model,cfg.d_ff,bias=False)
51
+ self.down_proj=nn.Linear(cfg.d_ff,cfg.d_model,bias=False)
52
+ def forward(self, x): return self.down_proj(F.silu(self.gate_proj(x))*self.up_proj(x))
53
+
54
+ class QuarkBlock(nn.Module):
55
+ def __init__(self, cfg):
56
+ super().__init__()
57
+ self.norm_attn=QuarkRMSNorm(cfg.d_model,cfg.rms_eps); self.attn=QuarkAttention(cfg)
58
+ self.norm_ffn=QuarkRMSNorm(cfg.d_model,cfg.rms_eps); self.ffn=QuarkFFN(cfg)
59
+ def forward(self, x):
60
+ x=x+self.attn(self.norm_attn(x)); return x+self.ffn(self.norm_ffn(x))
61
+
62
+ class QuarkPreTrainedModel(PreTrainedModel):
63
+ config_class=QuarkConfig; base_model_prefix="model"; supports_gradient_checkpointing=False
64
+ def _init_weights(self, module):
65
+ if isinstance(module, nn.Linear):
66
+ module.weight.data.normal_(0.0, 0.02)
67
+ if module.bias is not None: module.bias.data.zero_()
68
+ elif isinstance(module, nn.Embedding): module.weight.data.normal_(0.0, 0.02)
69
+
70
+ class QuarkForCausalLM(QuarkPreTrainedModel):
71
+ def __init__(self, config):
72
+ super().__init__(config); self.config=config
73
+ self.embed_tokens=nn.Embedding(config.vocab_size,config.d_model)
74
+ self.layers=nn.ModuleList([QuarkBlock(config) for _ in range(config.n_layers)])
75
+ self.norm=QuarkRMSNorm(config.d_model,config.rms_eps)
76
+ self.lm_head=nn.Linear(config.d_model,config.vocab_size,bias=False)
77
+ self.lm_head.weight=self.embed_tokens.weight # weight tying
78
+ self.post_init()
79
+
80
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
81
+ """Se lm_head.weight manca, copia da embed_tokens.weight (weight tying)"""
82
+ lm_key = f"{prefix}lm_head.weight"
83
+ emb_key = f"{prefix}embed_tokens.weight"
84
+ if lm_key not in state_dict and emb_key in state_dict:
85
+ state_dict[lm_key] = state_dict[emb_key].clone()
86
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
87
+
88
+ def get_input_embeddings(self): return self.embed_tokens
89
+ def set_input_embeddings(self, v): self.embed_tokens=v
90
+ def get_output_embeddings(self): return self.lm_head
91
+ def set_output_embeddings(self, v): self.lm_head=v
92
+
93
+ def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
94
+ h=self.embed_tokens(input_ids)
95
+ for layer in self.layers: h=layer(h)
96
+ logits=self.lm_head(self.norm(h))
97
+ loss=None
98
+ if labels is not None:
99
+ loss=F.cross_entropy(logits[...,:-1,:].contiguous().view(-1,self.config.vocab_size),
100
+ labels[...,1:].contiguous().view(-1),ignore_index=-100)
101
+ return CausalLMOutputWithPast(loss=loss, logits=logits)
102
+
103
+ def prepare_inputs_for_generation(self, input_ids, **kwargs): return {"input_ids": input_ids}