""" Quark model implementation for HuggingFace Transformers. Usage: from transformers import AutoModelForCausalLM, AutoTokenizer model = AutoModelForCausalLM.from_pretrained("ThingAI/Quark-135m-v0.2", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained("ThingAI/Quark-135m-v0.2") inputs = tokenizer("Ciao, come stai?", return_tensors="pt") outputs = model.generate(**inputs, max_new_tokens=100, temperature=0.7, do_sample=True) print(tokenizer.decode(outputs[0], skip_special_tokens=True)) """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional from transformers import PreTrainedModel from transformers.modeling_outputs import CausalLMOutputWithPast from .configuration_quark import QuarkConfig class QuarkRMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-5): super().__init__() self.eps = eps self.scale = nn.Parameter(torch.ones(dim)) def forward(self, x): rms = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt() return (x.float() * rms).to(x.dtype) * self.scale class QuarkRotaryEmbedding(nn.Module): def __init__(self, head_dim: int, max_seq_len: int, theta: float = 10000.0): super().__init__() assert head_dim % 2 == 0 inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) self._build_cache(max_seq_len) def _build_cache(self, seq_len: int): t = torch.arange(seq_len, device=self.inv_freq.device).float() freqs = torch.outer(t, self.inv_freq) emb = torch.cat([freqs, freqs], dim=-1) self.register_buffer("cos_cache", emb.cos()[None, None], persistent=False) self.register_buffer("sin_cache", emb.sin()[None, None], persistent=False) self._max_cached = seq_len @staticmethod def _rotate_half(x): x1, x2 = x.chunk(2, dim=-1) return torch.cat([-x2, x1], dim=-1) def forward(self, q, k): T = q.size(2) if T > self._max_cached: self._build_cache(T) cos = self.cos_cache[:, :, :T, :] sin = self.sin_cache[:, :, :T, :] q = q * cos + self._rotate_half(q) * sin k = k * cos + self._rotate_half(k) * sin return q, k class QuarkAttention(nn.Module): """Grouped Query Attention (GQA).""" def __init__(self, config: QuarkConfig): super().__init__() self.n_heads = config.n_heads self.n_kv_heads = config.n_kv_heads self.n_groups = config.n_heads // config.n_kv_heads self.head_dim = config.head_dim self.q_proj = nn.Linear(config.d_model, config.n_heads * config.head_dim, bias=config.qkv_bias) self.k_proj = nn.Linear(config.d_model, config.n_kv_heads * config.head_dim, bias=config.qkv_bias) self.v_proj = nn.Linear(config.d_model, config.n_kv_heads * config.head_dim, bias=config.qkv_bias) self.o_proj = nn.Linear(config.n_heads * config.head_dim, config.d_model, bias=False) self.rope = QuarkRotaryEmbedding(config.head_dim, config.max_seq_len, config.rope_theta) def forward(self, x): B, T, _ = x.shape q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) q, k = self.rope(q, k) if self.n_groups > 1: k = k.repeat_interleave(self.n_groups, dim=1) v = v.repeat_interleave(self.n_groups, dim=1) out = F.scaled_dot_product_attention(q, k, v, is_causal=True) out = out.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim) return self.o_proj(out) class QuarkFFN(nn.Module): """SwiGLU Feed-Forward Network.""" def __init__(self, config: QuarkConfig): super().__init__() self.gate_proj = nn.Linear(config.d_model, config.d_ff, bias=False) self.up_proj = nn.Linear(config.d_model, config.d_ff, bias=False) self.down_proj = nn.Linear(config.d_ff, config.d_model, bias=False) def forward(self, x): return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) class QuarkBlock(nn.Module): """Transformer block with pre-norm.""" def __init__(self, config: QuarkConfig): super().__init__() self.norm_attn = QuarkRMSNorm(config.d_model, config.rms_eps) self.attn = QuarkAttention(config) self.norm_ffn = QuarkRMSNorm(config.d_model, config.rms_eps) self.ffn = QuarkFFN(config) def forward(self, x): x = x + self.attn(self.norm_attn(x)) x = x + self.ffn(self.norm_ffn(x)) return x class QuarkPreTrainedModel(PreTrainedModel): config_class = QuarkConfig base_model_prefix = "model" supports_gradient_checkpointing = False def _init_weights(self, module): std = 0.02 if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) class QuarkForCausalLM(QuarkPreTrainedModel): """Quark model for causal language modeling.""" def __init__(self, config: QuarkConfig): super().__init__(config) self.config = config self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model) self.layers = nn.ModuleList([QuarkBlock(config) for _ in range(config.n_layers)]) self.norm = QuarkRMSNorm(config.d_model, config.rms_eps) self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) # Weight tying self.lm_head.weight = self.embed_tokens.weight self.post_init() def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def forward( self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, labels: Optional[torch.LongTensor] = None, **kwargs, ) -> CausalLMOutputWithPast: h = self.embed_tokens(input_ids) for layer in self.layers: h = layer(h) h = self.norm(h) logits = self.lm_head(h) loss = None if labels is not None: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss = F.cross_entropy( shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1), ignore_index=-100, ) return CausalLMOutputWithPast( loss=loss, logits=logits, ) def prepare_inputs_for_generation(self, input_ids, **kwargs): return {"input_ids": input_ids}