"""Quark-72M — HuggingFace wrapper that uses the original training architecture."""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_quark import QuarkConfig


# ── Architecture identical to train.py ───────────────────────────────────────


class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dim with a learned scale (no bias)."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Statistics are computed in float32; the result is cast back to the
        # input dtype before applying the learned scale.
        rms = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return (x.float() * rms).to(x.dtype) * self.scale


class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) applied to query and key tensors.

    The cos/sin tables are kept as plain attributes (not registered buffers or
    parameters). They are (re)built lazily on first use, when a longer sequence
    arrives, or when the input device/dtype changes.
    """

    def __init__(self, head_dim, max_seq_len, theta=10_000.0):
        super().__init__()
        # head_dim/theta are stored as plain Python numbers, NOT tensors managed
        # by HF — avoids corruption from meta-device init during from_pretrained().
        self.head_dim = head_dim
        self.theta = theta
        self.max_seq_len = max_seq_len
        self._max = 0
        self.cos_cache = None
        self.sin_cache = None

    def _build_cache(self, seq_len, device, dtype):
        # Recompute inv_freq from scratch every time — no persisted state.
        inv_freq = 1.0 / (
            self.theta
            ** (torch.arange(0, self.head_dim, 2, device=device).float() / self.head_dim)
        )
        t = torch.arange(seq_len, device=device, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        # Leading [None, None] adds two singleton dims so the tables broadcast
        # against (batch, heads, T, head_dim) queries/keys.
        self.cos_cache = emb.cos()[None, None].to(dtype)
        self.sin_cache = emb.sin()[None, None].to(dtype)
        self._max = seq_len

    @staticmethod
    def _rotate_half(x):
        # Split the last dim in two halves (x1, x2) and return (-x2, x1).
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([-x2, x1], dim=-1)

    def forward(self, q, k):
        """Rotate ``q`` and ``k`` by their position; the sequence axis is dim 2."""
        T = q.size(2)
        # Rebuild the cache if: never built, too short, or device/dtype changed.
        needs_rebuild = (
            self.cos_cache is None
            or T > self._max
            or self.cos_cache.device != q.device
            or self.cos_cache.dtype != q.dtype
        )
        if needs_rebuild:
            self._build_cache(max(T, self.max_seq_len), q.device, q.dtype)
        cos = self.cos_cache[:, :, :T, :]
        sin = self.sin_cache[:, :, :T, :]
        q = q * cos + self._rotate_half(q) * sin
        k = k * cos + self._rotate_half(k) * sin
        return q, k


class GroupedQueryAttention(nn.Module):
    """Causal grouped-query self-attention with RoPE.

    ``n_kv_heads`` key/value heads are shared by ``n_heads // n_kv_heads``
    query heads each. No padding mask is applied: attention is purely causal.
    """

    def __init__(self, cfg):
        super().__init__()
        self.n_heads = cfg.n_heads
        self.n_kv_heads = cfg.n_kv_heads
        self.n_groups = cfg.n_heads // cfg.n_kv_heads
        self.head_dim = cfg.head_dim
        self.q_proj = nn.Linear(cfg.d_model, cfg.n_heads * cfg.head_dim, bias=cfg.qkv_bias)
        self.k_proj = nn.Linear(cfg.d_model, cfg.n_kv_heads * cfg.head_dim, bias=cfg.qkv_bias)
        self.v_proj = nn.Linear(cfg.d_model, cfg.n_kv_heads * cfg.head_dim, bias=cfg.qkv_bias)
        self.o_proj = nn.Linear(cfg.n_heads * cfg.head_dim, cfg.d_model, bias=False)
        self.rope = RotaryEmbedding(cfg.head_dim, cfg.max_seq_len, cfg.rope_theta)
        self.drop = cfg.dropout

    def forward(self, x, **kwargs):
        # x: (B, T, d_model); extra kwargs are accepted and ignored.
        B, T, _ = x.shape
        orig_dtype = x.dtype
        # Cast to float32 before anything else to avoid overflow in RoPE and SDPA.
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2).float()
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2).float()
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2).float()
        q, k = self.rope(q, k)
        if self.n_groups > 1:
            # Expand shared KV heads so each query head has a matching K/V head.
            k = k.repeat_interleave(self.n_groups, dim=1)
            v = v.repeat_interleave(self.n_groups, dim=1)
        out = F.scaled_dot_product_attention(
            q, k, v,
            is_causal=True,
            dropout_p=self.drop if self.training else 0.0,
        )
        out = out.to(orig_dtype)
        out = out.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim)
        return self.o_proj(out)


class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward: down(silu(gate(x)) * up(x)), no biases."""

    def __init__(self, cfg):
        super().__init__()
        self.gate_proj = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
        self.up_proj = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
        self.down_proj = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class TransformerBlock(nn.Module):
    """Pre-norm decoder block: residual attention followed by residual FFN."""

    def __init__(self, cfg):
        super().__init__()
        self.norm_attn = RMSNorm(cfg.d_model, cfg.rms_eps)
        self.attn = GroupedQueryAttention(cfg)
        self.norm_ffn = RMSNorm(cfg.d_model, cfg.rms_eps)
        self.ffn = SwiGLUFFN(cfg)

    def forward(self, x, **kwargs):
        x = x + self.attn(self.norm_attn(x))
        x = x + self.ffn(self.norm_ffn(x))
        return x


# ── HuggingFace wrapper ───────────────────────────────────────────────────────


class QuarkPreTrainedModel(PreTrainedModel):
    """Base class wiring QuarkConfig and weight initialization into HF."""

    config_class = QuarkConfig
    base_model_prefix = "model"
    # lm_head shares its weight with embed_tokens, so it may be absent from checkpoints.
    _keys_to_ignore_on_load_missing = ["lm_head.weight"]

    def _init_weights(self, module):
        # Linear/Embedding weights ~ N(0, 0.02); biases (when present) zeroed.
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, 0.0, 0.02)
            if hasattr(module, "bias") and module.bias is not None:
                nn.init.zeros_(module.bias)


class QuarkForCausalLM(QuarkPreTrainedModel):
    """Decoder-only causal LM: embeddings → N TransformerBlocks → RMSNorm → tied lm_head."""

    _keys_to_ignore_on_load_missing = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        self.norm = RMSNorm(config.d_model, config.rms_eps)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Output projection shares the embedding matrix (weight tying).
        self.lm_head.weight = self.embed_tokens.weight
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, v):
        self.embed_tokens = v

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, v):
        self.lm_head = v

    def tie_weights(self, **kwargs):
        self.lm_head.weight = self.embed_tokens.weight

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Run the decoder and optionally compute the next-token cross-entropy loss.

        Args:
            input_ids: token ids of shape (batch, seq_len). Required.
            attention_mask: accepted for HF API compatibility but NOT applied —
                attention is purely causal, so padding tokens are attended to.
            labels: optional targets with the same layout as ``input_ids``; they
                are shifted by one internally and positions equal to -100 are
                ignored.
            **kwargs: extra HF arguments, ignored.

        Returns:
            CausalLMOutputWithPast with ``logits`` (batch, seq_len, vocab) and
            ``loss`` (None when ``labels`` is None).

        Raises:
            ValueError: if ``input_ids`` is None.
        """
        if input_ids is None:
            raise ValueError("QuarkForCausalLM.forward requires input_ids")

        x = self.embed_tokens(input_ids)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Position t predicts token t+1. The vocab size comes from the logits
            # themselves (there is no module-level `config` in scope here).
            vocab_size = logits.size(-1)
            loss = F.cross_entropy(
                logits[:, :-1].contiguous().view(-1, vocab_size),
                labels[:, 1:].contiguous().view(-1),
                ignore_index=-100,
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits)

    @torch.no_grad()
    def generate_text(self, input_ids, max_new_tokens=200, temperature=0.7,
                      top_p=0.9, rep_penalty=1.0, eos_token_id=None):
        """Autoregressively extend a single prompt by up to ``max_new_tokens`` tokens.

        No KV cache is used: every step re-runs the full model on the last
        ``config.max_seq_len`` tokens of the running context. Dropout stays
        active if the module is in training mode, so call ``.eval()`` first.

        Args:
            input_ids: prompt token ids of shape (1, seq_len); only batch size 1
                is supported.
            max_new_tokens: maximum number of tokens to append.
            temperature: softmax temperature; ``<= 0`` selects greedy argmax.
            top_p: nucleus-sampling threshold on cumulative probability.
            rep_penalty: repetition penalty applied to every token id already
                present in the context (positive logits divided, others
                multiplied); ``1.0`` disables it.
            eos_token_id: optional id that stops generation once emitted.

        Returns:
            Tensor of shape (1, seq_len + n_generated): the prompt followed by
            the sampled tokens (including EOS if it was produced).

        Raises:
            ValueError: if ``input_ids`` is not 2-D with batch size 1.
        """
        if input_ids.dim() != 2 or input_ids.size(0) != 1:
            raise ValueError(
                "generate_text supports a single sequence of shape (1, T); "
                f"got {tuple(input_ids.shape)}"
            )

        ctx = input_ids.clone()
        for _ in range(max_new_tokens):
            out = self(ctx[:, -self.config.max_seq_len:])
            logits = out.logits[0, -1, :].float()

            if rep_penalty != 1.0:
                # Every sampled token is appended to ctx, so ctx[0] already holds
                # the prompt + generation history. One vectorized update replaces
                # a per-id Python loop (which forced a device sync per token id).
                seen = torch.unique(ctx[0])
                vals = logits[seen]
                logits[seen] = torch.where(vals > 0, vals / rep_penalty, vals * rep_penalty)

            if temperature <= 0 or logits.isnan().any():
                token_id = logits.argmax().item()
            else:
                logits = (logits - logits.max()) / temperature
                probs = F.softmax(logits, dim=-1)
                sorted_p, sorted_i = torch.sort(probs, descending=True)
                cum_p = torch.cumsum(sorted_p, dim=-1)
                # Nucleus filter: drop tokens whose preceding cumulative mass
                # already exceeds top_p (the top token survives for top_p >= 0).
                sorted_p[(cum_p - sorted_p) > top_p] = 0.0
                total = sorted_p.sum()
                choice = torch.multinomial(sorted_p / (total if total > 0 else 1), 1)
                token_id = sorted_i[choice].item()

            token = torch.tensor([[token_id]], device=ctx.device)
            ctx = torch.cat([ctx, token], dim=1)
            if eos_token_id is not None and token_id == eos_token_id:
                break
        return ctx