import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput

from .configuration_minimamba import MiniMambaConfig
from .model import Mamba2, Mamba2Config



class MiniMamba(PreTrainedModel):
    """
    A Hugging Face–style wrapper around a Mamba2 model, providing:
      • forward(...) returning a CausalLMOutput
      • support for HF training loops
      • a naive generate(...) method with top-k/top-p sampling
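
    Example (a minimal sketch; assumes MiniMambaConfig exposes the
    vocab_size / num_layers / dim fields read in __init__ below):

        config = MiniMambaConfig(vocab_size=32000, num_layers=12, dim=768)
        model = MiniMamba(config)
        ids = torch.randint(0, config.vocab_size, (1, 16))
        out = model(ids, labels=ids)  # CausalLMOutput with .loss and .logits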
    """
    config_class = MiniMambaConfig  # Tells HF which config class to use

    def __init__(self, config: MiniMambaConfig) -> None:
        """
        Initialize the MiniMamba model, bridging Mamba2 with HF's PreTrainedModel.
        """
        super().__init__(config)

        # If your config includes Mamba2-like parameters, you can build a Mamba2Config from it:
        mamba2_args = Mamba2Config(
            vocab_size=config.vocab_size,
            num_layers=config.num_layers,
            dim=config.dim,
            use_mem_eff_path=True,
            weight_tying=getattr(config, "weight_tying", False),
            torch_dtype=getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype,
        )

        # Internally hold a Mamba2 model
        self.mamba = Mamba2(config=mamba2_args)

        # HF usually expects the embedding and LM head to live on this top-level
        # model, but Mamba2 already provides both (self.tok_emb and self.output),
        # so we typically do NOT need a separate embedding or lm_head here.
        #
        # We would only add them to follow the "HF standard" tie-weights
        # approach explicitly:
        #    self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        #    self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        #    self.lm_head.weight = self.tok_emb.weight
        #
        # Mamba2 does the same internally when config.weight_tying == True.

        # Optional: record a preferred device/dtype; the trailing underscores
        # avoid clashing with PreTrainedModel's .device / .dtype properties.
        self.device_ = 'cuda' if torch.cuda.is_available() else 'cpu'
        if isinstance(config.torch_dtype, str):
            self.dtype_ = getattr(torch, config.torch_dtype)
        else:
            self.dtype_ = config.torch_dtype
        
        # Parameter initialization (HF also invokes _init_weights via post_init in some flows).
        self.apply(self._init_weights)

        print("MiniMamba Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,))

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: torch.LongTensor = None,
        **kwargs
    ) -> CausalLMOutput:
        """
        Forward pass for causal language modeling. 
        Returns a CausalLMOutput that includes loss (if labels is provided) and logits.
        """
        # Mamba2's forward expects (x: torch.Tensor, target: torch.Tensor|None, ...)
        # but we only need the logits from the simple call:
        logits = self.mamba(input_ids)  # shape: [batch, seq_len, vocab_size]

        loss = None
        if labels is not None:
            # Standard causal-LM shift (as in HF GPT-style models): logits at
            # position i predict token i+1, so drop the last logit and the
            # first label before computing cross-entropy.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
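            # shift_logits: [batch, seq_len-1, vocab_size]; shift_labels: [batch, seq_len-1]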
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), 
                shift_labels.view(-1)
            )

        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: int = 50,
        temperature: float = 0.5,
        top_k: int = 50,
        top_p: float = 0.95,
        eos_token_id: int = None,
        pad_token_id: int = 0,
        **kwargs
    ):
        """
        A naive token-by-token generation loop (greedy + top-k/top-p + temperature).
        """
        # We'll accumulate new tokens in generated_ids
        generated_ids = input_ids.clone()

        for _ in range(max_new_tokens):
            # Full forward pass over the sequence so far; keep only the
            # logits at the last position
            outputs = self.forward(generated_ids)
            logits = outputs.logits[:, -1, :]  # shape: (batch_size, vocab_size)

            # Scale by temperature
            if temperature != 1.0:
                logits = logits / temperature

            # Filter
            logits = self.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

            # Sample next token
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # shape: (batch, 1)

            # Append
            generated_ids = torch.cat([generated_ids, next_token], dim=1)

            # Break early only once every sequence in the batch emits EOS at
            # the same step (a per-sequence finished mask would be more precise)
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break

        return generated_ids

    @staticmethod
    def top_k_top_p_filtering(
        logits: torch.Tensor,
        top_k: int = 50,
        top_p: float = 0.95,
        filter_value: float = float("-inf"),
    ):
        """
        Filters logits using top-k and/or nucleus (top-p) filtering.
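
        E.g. with top_k=2 on logits [[1.0, 3.0, 2.0, 0.5]], every entry below
        the 2nd-largest value (here 1.0 and 0.5) is set to -inf, so only the
        tokens with logits 3.0 and 2.0 remain sampleable.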
        """
        # top_k
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))
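            # Mask every logit strictly below the k-th largest value in its row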
            indices_to_remove = logits < torch.topk(logits, top_k, dim=-1).values[:, -1, None]
            logits[indices_to_remove] = filter_value

        # top_p (nucleus)
        if 0 < top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > top_p

            # Shift right so the first token above the threshold is kept as well
            sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
            sorted_indices_to_remove[:, 0] = False

            # Scatter to get back to original indexing
            indices_to_remove = sorted_indices_to_remove.scatter(
                dim=1, index=sorted_indices, src=sorted_indices_to_remove
            )
            logits[indices_to_remove] = filter_value

        return logits

    def _init_weights(self, module):
        """
        HF calls _init_weights to initialize parameters.
        If you prefer Mamba's own init approach, you can call model.mamba.init_weights().
        """
        # We call Mamba2's init routine for the whole submodel and use standard
        # PyTorch inits for other linear/embedding layers. Note that
        # nn.Module.apply visits children before parents, so Mamba2.init_weights()
        # runs last and overrides the generic inits for Mamba2's own submodules.
        if isinstance(module, Mamba2):
            module.init_weights()  # Mamba2's internal init
        elif isinstance(module, nn.Linear):
            # e.g. standard xavier or normal init
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        # If needed, do your specialized inits for other modules

    def _get_num_params(self):
        # Count trainable params, subtract duplicates if tying weights, etc.
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
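

if __name__ == "__main__":
    # Minimal smoke test: a sketch, assuming MiniMambaConfig accepts the
    # constructor fields read in __init__ above (vocab_size, num_layers, dim)
    # and that the relative imports (.model, .configuration_minimamba) resolve,
    # i.e. this module is run as part of its package.
    cfg = MiniMambaConfig(vocab_size=256, num_layers=2, dim=64)
    model = MiniMamba(cfg)

    prompt = torch.randint(0, cfg.vocab_size, (1, 8))
    out = model(prompt, labels=prompt)
    print("loss:", out.loss.item(), "| logits shape:", tuple(out.logits.shape))

    generated = model.generate(prompt, max_new_tokens=4)
    print("generated shape:", tuple(generated.shape))  # prompt length + 4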