Quark-72M / modeling_quark.py
ThingsAI's picture
feat: repetition penalty in generate_text
3a388e1 verified
Raw
History Blame Contribute Delete
8.89 kB
"""
Quark-72M — HuggingFace wrapper that uses the original training architecture.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_quark import QuarkConfig
# ── Architecture identical to train.py ───────────────────────────────────────
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain.

    There is no bias and no mean-centering: ``y = x / sqrt(mean(x**2) + eps) * scale``.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        # Learned per-channel gain, initialised to 1.
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Compute the statistic in float32 for numerical stability, cast back
        # to the input dtype, and only then apply the learned gain.
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        normed = (x32 * inv_rms).to(x.dtype)
        return normed * self.scale
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) applied to query and key tensors.

    The cos/sin tables are built lazily and kept as plain attributes rather
    than registered buffers, so they are not part of the state dict.
    """

    def __init__(self, head_dim, max_seq_len, theta=10_000.0):
        super().__init__()
        # head_dim/theta are plain Python numbers, NOT HF-managed tensors:
        # this avoids corruption from meta-device init during from_pretrained().
        self.head_dim = head_dim
        self.theta = theta
        self.max_seq_len = max_seq_len
        self._max = 0
        self.cos_cache = None
        self.sin_cache = None

    def _build_cache(self, seq_len, device, dtype):
        """(Re)build cos/sin tables of shape (1, 1, seq_len, head_dim)."""
        # inv_freq is recomputed from scratch every time; no persisted state.
        exponents = torch.arange(0, self.head_dim, 2, device=device).float() / self.head_dim
        inv_freq = 1.0 / (self.theta ** exponents)
        positions = torch.arange(seq_len, device=device, dtype=torch.float32)
        angles = torch.outer(positions, inv_freq)
        table = torch.cat((angles, angles), dim=-1)
        self.cos_cache = table.cos()[None, None].to(dtype)
        self.sin_cache = table.sin()[None, None].to(dtype)
        self._max = seq_len

    @staticmethod
    def _rotate_half(x):
        """Map the two halves (a, b) of the last dim to (-b, a)."""
        lo, hi = x.chunk(2, dim=-1)
        return torch.cat((-hi, lo), dim=-1)

    def forward(self, q, k):
        """Rotate q and k; dim 2 is read as the sequence length.

        Positions always start at 0 for the given input.
        """
        seq_len = q.size(2)
        # Rebuild when the cache is missing, too short, or on a different
        # device/dtype than the incoming queries.
        stale = (
            self.cos_cache is None
            or seq_len > self._max
            or self.cos_cache.device != q.device
            or self.cos_cache.dtype != q.dtype
        )
        if stale:
            self._build_cache(max(seq_len, self.max_seq_len), q.device, q.dtype)
        cos = self.cos_cache[:, :, :seq_len, :]
        sin = self.sin_cache[:, :, :seq_len, :]
        rotated_q = q * cos + self._rotate_half(q) * sin
        rotated_k = k * cos + self._rotate_half(k) * sin
        return rotated_q, rotated_k
class GroupedQueryAttention(nn.Module):
    """Causal self-attention with grouped key/value heads (GQA) and RoPE.

    ``n_heads // n_kv_heads`` query heads share each key/value head. Extra
    keyword arguments to ``forward`` (e.g. masks) are accepted and ignored;
    attention is always causal over the whole input.
    """

    def __init__(self, cfg):
        super().__init__()
        self.n_heads = cfg.n_heads
        self.n_kv_heads = cfg.n_kv_heads
        self.n_groups = cfg.n_heads // cfg.n_kv_heads
        self.head_dim = cfg.head_dim
        q_width = cfg.n_heads * cfg.head_dim
        kv_width = cfg.n_kv_heads * cfg.head_dim
        # Submodules are registered in the same order as the training code
        # (q, k, v, o, rope) so parameter naming/order is unchanged.
        self.q_proj = nn.Linear(cfg.d_model, q_width, bias=cfg.qkv_bias)
        self.k_proj = nn.Linear(cfg.d_model, kv_width, bias=cfg.qkv_bias)
        self.v_proj = nn.Linear(cfg.d_model, kv_width, bias=cfg.qkv_bias)
        self.o_proj = nn.Linear(q_width, cfg.d_model, bias=False)
        self.rope = RotaryEmbedding(cfg.head_dim, cfg.max_seq_len, cfg.rope_theta)
        self.drop = cfg.dropout

    def forward(self, x, **kwargs):
        """Attend over x of shape (batch, seq, d_model); returns the same shape."""
        batch, seq_len, _ = x.shape
        in_dtype = x.dtype

        def split_heads(proj, heads):
            # (B, T, heads*hd) -> (B, heads, T, hd), upcast to float32 to avoid
            # overflow in RoPE and SDPA.
            return proj(x).view(batch, seq_len, heads, self.head_dim).transpose(1, 2).float()

        q = split_heads(self.q_proj, self.n_heads)
        k = split_heads(self.k_proj, self.n_kv_heads)
        v = split_heads(self.v_proj, self.n_kv_heads)
        q, k = self.rope(q, k)

        if self.n_groups > 1:
            # Expand each KV head so it lines up with its group of query heads.
            k = k.repeat_interleave(self.n_groups, dim=1)
            v = v.repeat_interleave(self.n_groups, dim=1)

        p_drop = self.drop if self.training else 0.0
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=p_drop)
        merged = (
            attn.to(in_dtype)
            .transpose(1, 2)
            .contiguous()
            .view(batch, seq_len, self.n_heads * self.head_dim)
        )
        return self.o_proj(merged)
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward: ``down(silu(gate(x)) * up(x))``, all projections bias-free."""

    def __init__(self, cfg):
        super().__init__()
        d_model, d_ff = cfg.d_model, cfg.d_ff
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        gated = F.silu(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
class TransformerBlock(nn.Module):
    """Pre-norm decoder layer: residual attention, then residual SwiGLU FFN."""

    def __init__(self, cfg):
        super().__init__()
        # Registration order (norm_attn, attn, norm_ffn, ffn) matches training.
        self.norm_attn = RMSNorm(cfg.d_model, cfg.rms_eps)
        self.attn = GroupedQueryAttention(cfg)
        self.norm_ffn = RMSNorm(cfg.d_model, cfg.rms_eps)
        self.ffn = SwiGLUFFN(cfg)

    def forward(self, x, **kwargs):
        # Extra kwargs are accepted and ignored; nothing is threaded through.
        sublayers = ((self.norm_attn, self.attn), (self.norm_ffn, self.ffn))
        for norm, sublayer in sublayers:
            x = x + sublayer(norm(x))
        return x
# ── HuggingFace wrapper ───────────────────────────────────────────────────────
class QuarkPreTrainedModel(PreTrainedModel):
    """HuggingFace base class for Quark models: binds QuarkConfig and weight init."""

    config_class = QuarkConfig
    base_model_prefix = "model"
    # lm_head is tied to the embedding, so its weight may be absent on load.
    _keys_to_ignore_on_load_missing = ["lm_head.weight"]

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, 0.02) and zero any bias.

        Other module types (e.g. RMSNorm) are left as constructed.
        """
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        bias = getattr(module, "bias", None)
        if bias is not None:
            nn.init.zeros_(bias)
class QuarkForCausalLM(QuarkPreTrainedModel):
    """Quark causal language model in HuggingFace ``PreTrainedModel`` form.

    Pipeline: token embedding -> ``config.n_layers`` TransformerBlocks ->
    final RMSNorm -> bias-free linear LM head. The LM head shares its weight
    with the input embedding. No key/value cache is kept: every call re-runs
    attention over the full input sequence.
    """

    # The tied LM-head weight may be missing from checkpoints; tying restores it.
    _keys_to_ignore_on_load_missing = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.n_layers)]
        )
        self.norm = RMSNorm(config.d_model, config.rms_eps)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying: output projection and input embedding share one matrix.
        self.lm_head.weight = self.embed_tokens.weight
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, v):
        self.embed_tokens = v

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, v):
        self.lm_head = v

    def tie_weights(self, **kwargs):
        """Re-point lm_head.weight at the current input embedding weight."""
        self.lm_head.weight = self.embed_tokens.weight

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Run the model and optionally compute the next-token loss.

        Args:
            input_ids: token ids of shape (batch, seq).
            attention_mask: accepted for HF API compatibility but ignored —
                attention is always causal over the full input, so padding
                tokens are NOT masked.
            labels: optional ids aligned with ``input_ids``. They are shifted
                internally (logits at position t are scored against
                labels[t + 1]); entries equal to -100 are ignored.

        Returns:
            CausalLMOutputWithPast with ``logits`` of shape
            (batch, seq, vocab) and ``loss`` (None when no labels are given).
            ``past_key_values`` is never populated.
        """
        x = self.embed_tokens(input_ids)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Take the vocab width from the logits themselves (there is no
            # bare ``config`` in scope here) so it stays correct after resizes.
            vocab_size = logits.size(-1)
            loss = F.cross_entropy(
                logits[:, :-1].reshape(-1, vocab_size),
                labels[:, 1:].reshape(-1),
                ignore_index=-100,
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits)

    @staticmethod
    def _penalize_seen_tokens(logits, history, penalty):
        """Return a copy of 1-D ``logits`` with every id in ``history`` penalised.

        Positive logits are divided by ``penalty``; non-positive ones are
        multiplied by it, so ``penalty > 1`` always makes a seen token less likely.
        """
        seen = torch.unique(history).to(device=logits.device, dtype=torch.long)
        picked = logits[seen]
        penalized = logits.clone()
        penalized[seen] = torch.where(picked > 0, picked / penalty, picked * penalty)
        return penalized

    @staticmethod
    def _select_next_token(logits, temperature, top_p):
        """Pick one token id from 1-D ``logits``.

        The choice is greedy when ``temperature <= 0`` or the logits contain
        NaN. Otherwise it uses nucleus (top-p) sampling at the given temperature.
        """
        nan_mask = logits.isnan()
        if temperature <= 0 or nan_mask.any():
            # torch.argmax treats NaN as the maximum; mask NaNs out so the
            # greedy fallback picks the best finite logit instead.
            return logits.masked_fill(nan_mask, float("-inf")).argmax().item()

        # Subtract the max before scaling for a numerically stable softmax.
        probs = F.softmax((logits - logits.max()) / temperature, dim=-1)
        sorted_p, sorted_i = torch.sort(probs, descending=True)
        cum_p = torch.cumsum(sorted_p, dim=-1)
        # Drop tokens whose preceding cumulative mass already exceeds top_p;
        # the most likely token is always kept for top_p >= 0.
        sorted_p[(cum_p - sorted_p) > top_p] = 0.0
        total = sorted_p.sum()
        choice = torch.multinomial(sorted_p / (total if total > 0 else 1), 1)
        return sorted_i[choice].item()

    @torch.no_grad()
    def generate_text(self, input_ids, max_new_tokens=200, temperature=0.7,
                      top_p=0.9, rep_penalty=1.0, eos_token_id=None):
        """Autoregressively extend ``input_ids`` by up to ``max_new_tokens`` tokens.

        Only row 0 of the logits drives sampling, and each new token is
        appended as a (1, 1) tensor, so this expects a batch of size 1.

        Args:
            input_ids: prompt ids of shape (1, seq).
            max_new_tokens: maximum number of tokens to append.
            temperature: softmax temperature; ``<= 0`` selects greedy decoding.
            top_p: nucleus-sampling probability mass.
            rep_penalty: repetition penalty applied to every token id already
                present in the context; ``1.0`` disables it.
            eos_token_id: if given, stop right after this id is generated.

        Returns:
            The prompt with generated tokens appended (including the EOS
            token when generation stopped on it).
        """
        ctx = input_ids.clone()
        for _ in range(max_new_tokens):
            # Sliding window: feed at most the last max_seq_len tokens.
            out = self(ctx[:, -self.config.max_seq_len:])
            logits = out.logits[0, -1, :].float()
            if rep_penalty != 1.0:
                # ctx already contains the prompt and all generated tokens.
                logits = self._penalize_seen_tokens(logits, ctx[0], rep_penalty)
            token_id = self._select_next_token(logits, temperature, top_p)
            token = torch.tensor([[token_id]], device=ctx.device, dtype=ctx.dtype)
            ctx = torch.cat([ctx, token], dim=1)
            if eos_token_id is not None and token_id == eos_token_id:
                break
        return ctx