File size: 8,893 Bytes
92a7d9d
434cb12
92a7d9d
 
 
 
 
 
 
 
 
 
434cb12
 
92a7d9d
 
 
 
 
 
 
 
 
 
 
 
 
 
38feb38
 
 
 
4e1e026
38feb38
 
 
92a7d9d
4e1e026
38feb38
 
 
 
 
4e1e026
 
92a7d9d
 
 
 
 
 
 
 
 
4e1e026
 
 
 
 
 
 
 
 
5c0671e
 
 
 
 
92a7d9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434cb12
92a7d9d
9716697
 
 
 
 
434cb12
92a7d9d
 
 
 
434cb12
92a7d9d
 
23eadd7
92a7d9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c0671e
92a7d9d
 
 
 
434cb12
 
92a7d9d
cd6716a
92a7d9d
5c0671e
92a7d9d
 
434cb12
 
 
92a7d9d
 
 
 
91da9fd
5c0671e
 
92a7d9d
 
 
 
434cb12
 
92a7d9d
 
5c0671e
 
434cb12
 
 
5c0671e
 
cd6716a
92a7d9d
5c0671e
92a7d9d
434cb12
5c0671e
92a7d9d
 
434cb12
92a7d9d
 
 
 
 
 
 
3a388e1
 
 
92a7d9d
 
 
3a388e1
 
 
 
 
 
 
 
 
 
434cb12
3a388e1
91da9fd
3a388e1
 
 
92a7d9d
3a388e1
434cb12
3a388e1
 
 
 
 
 
 
92a7d9d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
"""
Quark-72M — HuggingFace wrapper that uses the original training architecture.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_quark import QuarkConfig


# ── Architettura identica a train.py ─────────────────────────────────────────

class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learnable per-feature gain.

    The statistics are computed in float32; the normalised activations are
    cast back to the input dtype before the gain is applied.

    Args:
        dim: Size of the last (feature) dimension.
        eps: Constant added to the mean square before the reciprocal square
            root is taken.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        # Gain vector, initialised to ones.
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Normalise ``x`` over its last dimension and multiply by ``scale``."""
        x32 = x.float()
        mean_square = x32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalised = (x32 * inv_rms).to(x.dtype)
        return normalised * self.scale


class RotaryEmbedding(nn.Module):
    """Rotary position embedding applied to query and key tensors.

    The cos/sin tables are kept as plain Python attributes (not parameters or
    buffers) and are rebuilt lazily inside ``forward``.

    Args:
        head_dim: Size of the last dimension of the tensors being rotated.
        max_seq_len: Minimum number of positions to precompute on a rebuild.
        theta: Base of the geometric frequency progression.
    """

    def __init__(self, head_dim, max_seq_len, theta=10_000.0):
        super().__init__()
        # head_dim/theta are stored as plain Python numbers rather than
        # tensors managed by HF, to avoid corruption from the meta-device
        # initialisation performed during from_pretrained().
        self.head_dim = head_dim
        self.theta = theta
        self.max_seq_len = max_seq_len
        self._max = 0  # number of positions currently held in the cache
        self.cos_cache = None
        self.sin_cache = None

    def _build_cache(self, seq_len, device, dtype):
        """Compute cos/sin tables for ``seq_len`` positions and store them."""
        # The inverse frequencies are recomputed from scratch on every call,
        # so no persistent state is involved.
        exponents = torch.arange(0, self.head_dim, 2, device=device).float() / self.head_dim
        inv_freq = 1.0 / (self.theta ** exponents)
        positions = torch.arange(seq_len, device=device, dtype=torch.float32)
        angles = torch.outer(positions, inv_freq)
        table = torch.cat([angles, angles], dim=-1)
        # Two leading singleton axes so the tables broadcast against
        # four-dimensional q/k tensors.
        self.cos_cache = table.cos()[None, None].to(dtype)
        self.sin_cache = table.sin()[None, None].to(dtype)
        self._max = seq_len

    @staticmethod
    def _rotate_half(x):
        """Return ``x`` with the halves of its last axis swapped, first half negated."""
        first, second = torch.chunk(x, 2, dim=-1)
        return torch.cat([-second, first], dim=-1)

    def forward(self, q, k):
        """Rotate ``q`` and ``k``; positions are indexed along dimension 2 of ``q``."""
        seq_len = q.size(2)
        # The cache is reusable only if it exists, is long enough, and lives
        # on the same device with the same dtype as the incoming queries.
        cache_usable = (
            self.cos_cache is not None
            and seq_len <= self._max
            and self.cos_cache.device == q.device
            and self.cos_cache.dtype == q.dtype
        )
        if not cache_usable:
            self._build_cache(max(seq_len, self.max_seq_len), q.device, q.dtype)

        cos = self.cos_cache[:, :, :seq_len, :]
        sin = self.sin_cache[:, :, :seq_len, :]
        rotated_q = q * cos + self._rotate_half(q) * sin
        rotated_k = k * cos + self._rotate_half(k) * sin
        return rotated_q, rotated_k


class GroupedQueryAttention(nn.Module):
    """Causal self-attention in which groups of query heads share key/value heads.

    Reads ``n_heads``, ``n_kv_heads``, ``head_dim``, ``d_model``, ``qkv_bias``,
    ``max_seq_len``, ``rope_theta`` and ``dropout`` from ``cfg``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.n_heads = cfg.n_heads
        self.n_kv_heads = cfg.n_kv_heads
        # Number of query heads served by each key/value head.
        self.n_groups = cfg.n_heads // cfg.n_kv_heads
        self.head_dim = cfg.head_dim

        query_width = cfg.n_heads * cfg.head_dim
        kv_width = cfg.n_kv_heads * cfg.head_dim
        self.q_proj = nn.Linear(cfg.d_model, query_width, bias=cfg.qkv_bias)
        self.k_proj = nn.Linear(cfg.d_model, kv_width, bias=cfg.qkv_bias)
        self.v_proj = nn.Linear(cfg.d_model, kv_width, bias=cfg.qkv_bias)
        self.o_proj = nn.Linear(query_width, cfg.d_model, bias=False)
        self.rope = RotaryEmbedding(cfg.head_dim, cfg.max_seq_len, cfg.rope_theta)
        self.drop = cfg.dropout

    def forward(self, x, **kwargs):
        """Attend over ``x`` of shape (batch, seq, d_model); extra kwargs are ignored."""
        batch, seq_len, _ = x.shape
        input_dtype = x.dtype

        def split_heads(projected, heads):
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim).
            # Cast to float32 up front to avoid overflow in RoPE and SDPA.
            return projected.view(batch, seq_len, heads, self.head_dim).transpose(1, 2).float()

        q = split_heads(self.q_proj(x), self.n_heads)
        k = split_heads(self.k_proj(x), self.n_kv_heads)
        v = split_heads(self.v_proj(x), self.n_kv_heads)
        q, k = self.rope(q, k)

        if self.n_groups > 1:
            # Repeat each key/value head so the head count matches the queries.
            k = torch.repeat_interleave(k, self.n_groups, dim=1)
            v = torch.repeat_interleave(v, self.n_groups, dim=1)

        dropout_p = self.drop if self.training else 0.0
        attended = F.scaled_dot_product_attention(
            q, k, v, is_causal=True, dropout_p=dropout_p,
        )
        attended = attended.to(input_dtype).transpose(1, 2).contiguous()
        merged = attended.view(batch, seq_len, self.n_heads * self.head_dim)
        return self.o_proj(merged)


class SwiGLUFFN(nn.Module):
    """Gated feed-forward block: ``down(silu(gate(x)) * up(x))``.

    Reads ``d_model`` and ``d_ff`` from ``cfg``; all three projections are
    bias-free.
    """

    def __init__(self, cfg):
        super().__init__()
        model_dim, hidden_dim = cfg.d_model, cfg.d_ff
        self.gate_proj = nn.Linear(model_dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(model_dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, model_dim, bias=False)

    def forward(self, x):
        """Apply the gated feed-forward transformation to ``x``."""
        gate = F.silu(self.gate_proj(x))
        value = self.up_proj(x)
        return self.down_proj(gate * value)


class TransformerBlock(nn.Module):
    """Pre-norm decoder block: attention and feed-forward, each with a residual.

    Reads ``d_model`` and ``rms_eps`` from ``cfg`` and passes ``cfg`` through
    to the attention and feed-forward sub-modules.
    """

    def __init__(self, cfg):
        super().__init__()
        self.norm_attn = RMSNorm(cfg.d_model, cfg.rms_eps)
        self.attn = GroupedQueryAttention(cfg)
        self.norm_ffn = RMSNorm(cfg.d_model, cfg.rms_eps)
        self.ffn = SwiGLUFFN(cfg)

    def forward(self, x, **kwargs):
        """Run both residual sub-layers on ``x``; extra kwargs are ignored."""
        attn_out = self.attn(self.norm_attn(x))
        x = x + attn_out
        ffn_out = self.ffn(self.norm_ffn(x))
        return x + ffn_out


# ── HuggingFace wrapper ───────────────────────────────────────────────────────

class QuarkPreTrainedModel(PreTrainedModel):
    """Base class binding ``QuarkConfig`` to the HuggingFace loading machinery."""

    config_class = QuarkConfig
    base_model_prefix = "model"
    _keys_to_ignore_on_load_missing = ["lm_head.weight"]

    def _init_weights(self, module):
        """Initialise linear and embedding weights from N(0, 0.02); zero any bias."""
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        nn.init.normal_(module.weight, 0.0, 0.02)
        # nn.Embedding has no ``bias`` attribute; nn.Linear may have bias=None.
        bias = getattr(module, "bias", None)
        if bias is not None:
            nn.init.zeros_(bias)


class QuarkForCausalLM(QuarkPreTrainedModel):
    """Quark decoder stack with a language-modelling head tied to the embeddings.

    Reads ``vocab_size``, ``d_model``, ``n_layers`` and ``rms_eps`` from
    ``config`` (plus whatever the transformer blocks read), and
    ``max_seq_len`` during generation.
    """

    _keys_to_ignore_on_load_missing = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.n_layers)]
        )
        self.norm = RMSNorm(config.d_model, config.rms_eps)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Share one weight matrix between the input embedding and output head.
        self.lm_head.weight = self.embed_tokens.weight
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, v):
        """Replace the token embedding module."""
        self.embed_tokens = v

    def get_output_embeddings(self):
        """Return the language-modelling head."""
        return self.lm_head

    def set_output_embeddings(self, v):
        """Replace the language-modelling head."""
        self.lm_head = v

    def tie_weights(self, **kwargs):
        """Point the head's weight at the embedding weight; kwargs are ignored."""
        self.lm_head.weight = self.embed_tokens.weight

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Compute next-token logits and, when ``labels`` is given, the LM loss.

        Args:
            input_ids: Token ids fed to the embedding layer.
            attention_mask: Accepted for API compatibility but not used here;
                the layers are called with the hidden states only.
            labels: Optional target ids. They are shifted by one position
                inside this method; entries equal to -100 are ignored.
            **kwargs: Ignored.

        Returns:
            ``CausalLMOutputWithPast`` carrying ``loss`` (or ``None``) and
            ``logits``.
        """
        x = self.embed_tokens(input_ids)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Take the class count from the logits themselves: there is no
            # free ``config`` name in this scope, only ``self.config``.
            num_classes = logits.size(-1)
            loss = F.cross_entropy(
                logits[:, :-1].contiguous().view(-1, num_classes),
                labels[:, 1:].contiguous().view(-1),
                ignore_index=-100,
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits)

    @torch.no_grad()
    def generate_text(self, input_ids, max_new_tokens=200, temperature=0.7,
                      top_p=0.9, rep_penalty=1.0, eos_token_id=None):
        """Sample a continuation token by token (first batch row only).

        Args:
            input_ids: Prompt ids; only row 0 is used for sampling and the
                appended tokens have shape (1, 1).
            max_new_tokens: Upper bound on the number of generated tokens.
            temperature: Softmax temperature; ``<= 0`` selects greedy argmax.
            top_p: Nucleus threshold on the cumulative probability mass.
            rep_penalty: Penalty applied to logits of tokens already present
                in the context; ``1.0`` disables it.
            eos_token_id: If given, generation stops after this id is emitted.

        Returns:
            The prompt with the generated tokens concatenated along dim 1.
        """
        ctx = input_ids.clone()
        generated = []
        for _ in range(max_new_tokens):
            # Only the most recent ``max_seq_len`` tokens are fed to the model.
            out = self(ctx[:, -self.config.max_seq_len:])
            logits = out.logits[0, -1, :].float()

            # Repetition penalty: penalise every token already in the context
            # (prompt plus generated tokens, which are appended to ``ctx``).
            if rep_penalty != 1.0:
                seen = torch.unique(ctx[0])
                seen_logits = logits[seen]
                logits[seen] = torch.where(
                    seen_logits > 0,
                    seen_logits / rep_penalty,
                    seen_logits * rep_penalty,
                )

            if temperature <= 0 or logits.isnan().any():
                # Greedy choice; also the fallback when the logits contain NaN.
                token_id = logits.argmax().item()
            else:
                logits = logits - logits.max()
                logits = logits / temperature
                probs = F.softmax(logits, dim=-1)
                sorted_p, sorted_i = torch.sort(probs, descending=True)
                cum_p = torch.cumsum(sorted_p, dim=-1)
                # Drop tokens whose preceding cumulative mass already exceeds
                # top_p; the highest-probability token is always kept.
                sorted_p[(cum_p - sorted_p) > top_p] = 0.0
                total = sorted_p.sum()
                choice = torch.multinomial(sorted_p / (total if total > 0 else 1), 1)
                token_id = sorted_i[choice].item()

            generated.append(token_id)
            token = torch.tensor([[token_id]], device=ctx.device)
            ctx = torch.cat([ctx, token], dim=1)
            if eos_token_id is not None and token_id == eos_token_id:
                break
        return ctx