# Quark-135m-Bilingual / modeling_quark.py
# Author: ThingsAI — commit 8c9cdca ("Upload modeling_quark.py", verified)
"""
Quark model implementation for HuggingFace Transformers.
Usage:
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("ThingAI/Quark-135m-v0.2", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("ThingAI/Quark-135m-v0.2")
inputs = tokenizer("Ciao, come stai?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=100, temperature=0.7, do_sample=True)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
"""
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_quark import QuarkConfig
class QuarkRMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned gain.

    Computes ``x / sqrt(mean(x**2) + eps)`` in float32, casts the result back to
    the input dtype, then multiplies by the per-channel ``scale`` parameter.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Learned per-channel gain. The attribute name determines the state_dict
        # key, so it must stay ``scale``.
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        x_fp32 = x.float()
        mean_sq = x_fp32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        normed = (x_fp32 * inv_rms).to(x.dtype)
        return normed * self.scale
class QuarkRotaryEmbedding(nn.Module):
    """Rotary position embeddings (RoPE) for query/key tensors.

    Cos/sin tables for positions ``0 .. max_seq_len - 1`` are precomputed as
    non-persistent buffers and rebuilt on demand when a longer input arrives.
    Dim 2 of ``q``/``k`` is treated as the sequence axis and positions always
    start at 0, so the tensors must cover the sequence from its first token.
    Outputs keep the dtype of their inputs.
    """

    def __init__(self, head_dim: int, max_seq_len: int, theta: float = 10000.0):
        super().__init__()
        if head_dim % 2 != 0:
            # The rotation pairs channel i with channel i + head_dim // 2.
            raise ValueError(f"head_dim must be even for rotary embeddings, got {head_dim}")
        # One inverse frequency per channel pair: theta ** (-2i / head_dim).
        inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        """(Re)build cos/sin tables of shape (1, 1, seq_len, head_dim) in float32."""
        t = torch.arange(seq_len, device=self.inv_freq.device).float()
        # .float(): keep the angles in float32 even if the module was cast to half precision.
        freqs = torch.outer(t, self.inv_freq.float())
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos_cache", emb.cos()[None, None], persistent=False)
        self.register_buffer("sin_cache", emb.sin()[None, None], persistent=False)
        self._max_cached = seq_len

    @staticmethod
    def _rotate_half(x):
        """Map the two halves (x1, x2) of the last dimension to (-x2, x1)."""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([-x2, x1], dim=-1)

    def _rotate(self, x, cos, sin):
        """Apply the rotation in float32 and return a tensor with ``x``'s dtype."""
        x32 = x.float()
        return (x32 * cos + self._rotate_half(x32) * sin).to(x.dtype)

    def forward(self, q, k):
        """Rotate ``q`` and ``k`` by their positions; returns ``(q, k)`` in the input dtypes."""
        seq_len = q.size(2)
        if seq_len > self._max_cached:
            self._build_cache(seq_len)
        # The tables are float32. Multiplying half-precision q/k by them directly
        # would silently promote the result to float32, so _rotate casts back.
        cos = self.cos_cache[:, :, :seq_len, :].float()
        sin = self.sin_cache[:, :, :seq_len, :].float()
        return self._rotate(q, cos, sin), self._rotate(k, cos, sin)
class QuarkAttention(nn.Module):
    """Causal grouped-query attention (GQA) with rotary position embeddings.

    ``config.n_heads`` query heads share ``config.n_kv_heads`` key/value heads.
    Each K/V head is repeated ``n_heads // n_kv_heads`` times so that
    ``F.scaled_dot_product_attention`` sees matching head counts. Only a causal
    mask is applied; there is no padding mask.
    """

    def __init__(self, config: QuarkConfig):
        super().__init__()
        if config.n_kv_heads <= 0 or config.n_heads % config.n_kv_heads != 0:
            # Without this check the head counts of q and the repeated k/v disagree
            # and attention fails later with an opaque shape error.
            raise ValueError(
                f"n_heads ({config.n_heads}) must be divisible by "
                f"n_kv_heads ({config.n_kv_heads}), and n_kv_heads must be > 0"
            )
        self.n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads
        self.n_groups = config.n_heads // config.n_kv_heads
        self.head_dim = config.head_dim
        self.q_proj = nn.Linear(config.d_model, config.n_heads * config.head_dim, bias=config.qkv_bias)
        self.k_proj = nn.Linear(config.d_model, config.n_kv_heads * config.head_dim, bias=config.qkv_bias)
        self.v_proj = nn.Linear(config.d_model, config.n_kv_heads * config.head_dim, bias=config.qkv_bias)
        self.o_proj = nn.Linear(config.n_heads * config.head_dim, config.d_model, bias=False)
        self.rope = QuarkRotaryEmbedding(config.head_dim, config.max_seq_len, config.rope_theta)

    def forward(self, x):
        """Attend over ``x`` of shape (batch, seq, d_model); returns the same shape."""
        B, T, _ = x.shape
        # Project and split into heads: (batch, heads, seq, head_dim).
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        q, k = self.rope(q, k)
        if self.n_groups > 1:
            # Expand K/V so every query head has a matching key/value head.
            k = k.repeat_interleave(self.n_groups, dim=1)
            v = v.repeat_interleave(self.n_groups, dim=1)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # Merge heads back: (batch, seq, n_heads * head_dim).
        out = out.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim)
        return self.o_proj(out)
class QuarkFFN(nn.Module):
    """SwiGLU feed-forward network: ``down_proj(silu(gate_proj(x)) * up_proj(x))``.

    Maps ``d_model -> d_ff -> d_model``; none of the projections has a bias.
    """

    def __init__(self, config: QuarkConfig):
        super().__init__()
        d_model, d_ff = config.d_model, config.d_ff
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        gated = F.silu(self.gate_proj(x))
        hidden = gated * self.up_proj(x)
        return self.down_proj(hidden)
class QuarkBlock(nn.Module):
    """Pre-norm transformer block.

    Applies RMSNorm -> attention with a residual add, then RMSNorm -> SwiGLU FFN
    with a second residual add.
    """

    def __init__(self, config: QuarkConfig):
        super().__init__()
        eps = config.rms_eps
        self.norm_attn = QuarkRMSNorm(config.d_model, eps)
        self.attn = QuarkAttention(config)
        self.norm_ffn = QuarkRMSNorm(config.d_model, eps)
        self.ffn = QuarkFFN(config)

    def forward(self, x):
        attn_out = self.attn(self.norm_attn(x))
        h = x + attn_out
        ffn_out = self.ffn(self.norm_ffn(h))
        return h + ffn_out
class QuarkPreTrainedModel(PreTrainedModel):
    """Base class connecting Quark models to the HuggingFace ``PreTrainedModel`` API."""

    config_class = QuarkConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = False
    # Transformer blocks are kept whole on a single device when sharding.
    # transformers requires this attribute before it allows device_map="auto".
    _no_split_modules = ["QuarkBlock"]

    def _init_weights(self, module):
        """Initialize one submodule's parameters.

        Linear and Embedding weights are drawn from N(0, 0.02**2), Linear biases
        are zeroed, and QuarkRMSNorm gains are set to 1.
        """
        std = 0.02
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)
        elif isinstance(module, QuarkRMSNorm):
            # Gains are already ones after __init__. Setting them here keeps
            # initialization complete when parameters were created empty.
            nn.init.ones_(module.scale)
class QuarkForCausalLM(QuarkPreTrainedModel, GenerationMixin):
    """Quark decoder-only transformer for causal language modeling.

    The output projection (``lm_head``) shares its weight with ``embed_tokens``.
    ``GenerationMixin`` is inherited explicitly so that ``generate()`` keeps working
    on transformers releases (v4.50+) where ``PreTrainedModel`` no longer provides it.
    The model keeps no key/value cache: every forward call re-encodes the full
    ``input_ids`` sequence.
    """

    def __init__(self, config: QuarkConfig):
        # PreTrainedModel.__init__ stores ``config`` on ``self.config``.
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList([QuarkBlock(config) for _ in range(config.n_layers)])
        self.norm = QuarkRMSNorm(config.d_model, config.rms_eps)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying: the output projection reuses the embedding matrix.
        self.lm_head.weight = self.embed_tokens.weight
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run the decoder and optionally compute the next-token loss.

        Args:
            input_ids: token ids of shape (batch, seq).
            attention_mask: accepted for API compatibility but ignored; it is not
                passed to the layers, so padded positions are not masked out.
            labels: token ids shaped like ``input_ids``. The logits at position t
                are scored against ``labels[..., t + 1]``, and entries equal to
                -100 are ignored.
            **kwargs: ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` of shape
            (batch, seq, vocab) and ``loss`` (``None`` when ``labels`` is None).
        """
        h = self.embed_tokens(input_ids)
        for layer in self.layers:
            h = layer(h)
        logits = self.lm_head(self.norm(h))

        loss = None
        if labels is not None:
            # Shift so that position t predicts token t + 1.
            shift_logits = logits[..., :-1, :]
            # Labels may live on another device under model parallelism.
            shift_labels = labels[..., 1:].to(shift_logits.device)
            # Cross-entropy in float32 for numerical stability with half-precision
            # logits. The class count comes from the logits themselves.
            loss = F.cross_entropy(
                shift_logits.float().reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                ignore_index=-100,
            )
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Feed the whole sequence at every generation step; no cache or mask is forwarded."""
        return {"input_ids": input_ids}