robinfaro commited on
Commit
88d2a93
·
verified ·
1 Parent(s): c4f6a89

Adding files from hf_modeling_btm

Browse files
Files changed (1) hide show
  1. modeling.py +249 -0
modeling.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+ from .configuration import MoLMConfig
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+ from transformers.utils import ModelOutput
7
+ from .gpt import GPTBase
8
+ from .aux_losses import entropy_reg, load_balancing_loss, router_z_loss
9
+ from typing import Optional, List
10
+ from dataclasses import dataclass
11
+
12
+
13
+ @dataclass
14
+ class Output(ModelOutput):
15
+ logits: torch.FloatTensor = None
16
+ loss: Optional[torch.FloatTensor] = None
17
+ expert_losses: Optional[List] = None
18
+ loss_to_log: Optional[float] = None
19
+ router_logits: Optional[torch.FloatTensor] = None
20
+ selected_experts: Optional[torch.LongTensor] = None
21
+ combined_log_probs: Optional[torch.FloatTensor] = None
22
+
23
+
24
+ class MoLM(PreTrainedModel):
25
+ config_class = MoLMConfig
26
+
27
+ def __init__(self, config, expert_weights=None, dropout=0.1):
28
+ """
29
+ Constructor for the MoLM (Mixture of Language Models) class.
30
+
31
+ :param config: The configuration of the model (should be a PretrainedConfig object)
32
+ :param expert_weights: (Optional) A list of weights for each expert to load pre-trained weights (should match the number of experts)
33
+ :param dropout: Dropout rate for the model
34
+ :param use_router: Flag to indicate whether to use routing (currently not implemented)
35
+ """
36
+ super(MoLM, self).__init__(config)
37
+
38
+ # Number of experts
39
+ self.num_experts = config.num_experts
40
+ print(f"Number of experts: {self.num_experts}")
41
+ print(f"Expert configurations: {config.expert_configs}")
42
+ assert len(config.expert_configs) == self.num_experts, "Number of expert configurations must match num_experts in config."
43
+ self.expert_configs = config.expert_configs
44
+
45
+
46
+ self.use_router = config.use_router
47
+
48
+ self.router = nn.Sequential(
49
+ nn.Linear(config.n_embd, self.num_experts),
50
+ )
51
+ self.top_k = config.top_k_experts if hasattr(config, "top_k_experts") else self.num_experts
52
+
53
+ # Initialize experts using the provided configurations
54
+ self.experts = nn.ModuleList([GPTBase(config=self.expert_configs[i]) for i in range(self.num_experts)])
55
+
56
+ # Load pre-trained weights if provided
57
+ if expert_weights is not None:
58
+ for i, expert in enumerate(self.experts):
59
+ expert.load_state_dict(expert_weights[i], strict=False)
60
+ expert.transformer.wte.weight = torch.nn.Parameter(expert.transformer.wte.weight.clone())
61
+ for param in expert.parameters():
62
+ param.requires_grad = False
63
+
64
+ def forward(self, input_ids, attention_mask=None, targets=None, date=None, masking_enabled=True, **kwargs):
65
+ """
66
+ Forward pass for the MoLM model, passing input through all experts and averaging their outputs.
67
+
68
+ :param input_ids: Input token IDs (batch_size, seq_len)
69
+ :param attention_mask: Attention mask (batch_size, seq_len)
70
+ :param targets: Target labels for calculating loss (batch_size, seq_len)
71
+ :param date: A tensor indicating which experts to use. Each sample in the batch can have a different date.
72
+ :param masking_enabled: Whether or not to perform expert masking (True/False)
73
+ :param kwargs: Additional arguments
74
+ :return: The averaged output of all active experts up to the specified date for each sample in the batch
75
+ """
76
+ device = input_ids.device
77
+ b, t = input_ids.size()
78
+
79
+ # Ensure the sequence length doesn't exceed the configured block size
80
+ assert t <= self.config.sequence_length, f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}"
81
+
82
+ # If date is None, set a default value (e.g., 6 for all samples)
83
+ if date is None:
84
+ date = torch.full((1, b), 6, dtype=torch.long, device=device).squeeze(0)
85
+ elif isinstance(date, int):
86
+ # If date is an integer, set it for all samples in the batch
87
+ date = (date - 2013) // 2 + 1
88
+ date = torch.full((1, b), date, dtype=torch.long, device=device).squeeze(0)
89
+ elif isinstance(date, torch.Tensor):
90
+ # Ensure the tensor has the correct shape (batch_size,)
91
+ assert date.size(0) == b, "The size of date tensor must match the batch size."
92
+ date = date.to(device)
93
+
94
+ # Get outputs from each expert
95
+ expert_outputs = []
96
+ expert_losses = []
97
+
98
+ # Track the number of active experts for each sample in the batch
99
+ active_experts_count = torch.zeros(b, dtype=torch.long, device=device)
100
+
101
+ # Pass input through each expert
102
+ with torch.no_grad():
103
+ for i, expert in enumerate(self.experts):
104
+ # Masking logic based on date (for each sample in the batch)
105
+ expert_mask = date >= i # Mask experts where date < i (i.e., deactivate them)
106
+ #expert_mask = date <= i
107
+ # Expand the expert_mask to match the logits shape (batch_size, 1, 1)
108
+ expert_mask_expanded = expert_mask.unsqueeze(-1).unsqueeze(-1).float()
109
+
110
+ expert_output = expert(input_ids, targets=targets, date=date, get_logits=True, **kwargs)
111
+
112
+ logits = expert_output["logits"]
113
+ loss_to_log = expert_output["loss_to_log"]
114
+
115
+ # Mask out the outputs for deactivated experts
116
+ logits = logits * expert_mask_expanded # Apply the mask (zero out logits for inactive experts)
117
+
118
+ # Only append logits from active experts
119
+ expert_outputs.append(logits)
120
+ expert_losses.append(loss_to_log)
121
+
122
+ # Update active expert count for each sample
123
+ active_experts_count += expert_mask.long() # Ensure type consistency by converting `expert_mask` to Long
124
+
125
+ # Stack the logits and calculate the mean for each sample across the active experts
126
+ expert_outputs = torch.stack(expert_outputs, dim=0) # Shape: (num_experts, batch_size, seq_len, vocab_size)
127
+
128
+ # Convert logits to log-probabilities for each expert
129
+ log_probs = F.log_softmax(expert_outputs, dim=-1)
130
+
131
+ if self.use_router:
132
+ hidden = self.experts[0].transformer.wte(input_ids) # (B, T, D)
133
+ pooled_hidden = hidden.mean(dim=1) # (B, D)
134
+ router_logits = self.router(pooled_hidden) # (B, E)
135
+
136
+ expert_ids = torch.arange(self.num_experts, device=input_ids.device)
137
+ router_mask = date.unsqueeze(1) >= expert_ids.unsqueeze(0) # (B, E)
138
+ masked_router_logits = router_logits.masked_fill(~router_mask, float("-inf"))
139
+
140
+ # Select top-k
141
+ topk_probs, topk_indices = torch.topk(F.softmax(masked_router_logits, dim=-1), self.top_k, dim=-1)
142
+ sparse_probs = torch.zeros_like(router_logits)
143
+ sparse_probs.scatter_(1, topk_indices, topk_probs)
144
+ sparse_probs = sparse_probs / sparse_probs.sum(dim=1, keepdim=True)
145
+
146
+ # Convert weights to log-space
147
+ log_weights = torch.log(sparse_probs + 1e-9) # (B, E)
148
+
149
+ # Broadcast for logsumexp: (E, B, T, V)
150
+ log_weights_exp = log_weights.transpose(0, 1).unsqueeze(-1).unsqueeze(-1) # (E, B, 1, 1)
151
+ weighted_log_probs = log_probs + log_weights_exp # (E, B, T, V)
152
+
153
+ combined_log_probs = torch.logsumexp(weighted_log_probs, dim=0) # (B, T, V)
154
+
155
+ else:
156
+ # Unweighted average in log-prob space across active experts (equal weights)
157
+ log_weights = torch.log(1.0 / active_experts_count.float().clamp(min=1.0)).view(1, -1, 1, 1) # (1, B, 1, 1)
158
+ weighted_log_probs = log_probs + log_weights
159
+ combined_log_probs = torch.logsumexp(weighted_log_probs, dim=0) # (B, T, V)
160
+
161
+ # Calculate the loss if targets are provided
162
+ if targets is not None:
163
+ loss = F.nll_loss(combined_log_probs.view(-1, combined_log_probs.size(-1)), targets.view(-1), ignore_index=-1)
164
+ loss_to_log = loss.item()
165
+
166
+ # Add auxiliary router losses (only if routing is used and we're training)
167
+ if self.use_router and self.training:
168
+ flat_router_logits = router_logits.view(-1, router_logits.size(-1)) # (B*T, E)
169
+ flat_selected_experts = topk_indices.view(-1, topk_indices.size(-1)) # (B*T, top_k)
170
+
171
+ # Compute each auxiliary loss
172
+ entropy = entropy_reg(flat_router_logits)
173
+ lb_loss = load_balancing_loss(flat_router_logits, flat_selected_experts)
174
+ zloss = router_z_loss(flat_router_logits)
175
+
176
+ # Combine them with your preferred weights
177
+ loss = (
178
+ loss
179
+ + 0.01 *entropy
180
+ + 0.01 * lb_loss
181
+ + 0.0001 * zloss
182
+ )
183
+ else:
184
+ loss = None
185
+ loss_to_log = None
186
+
187
+ return Output(
188
+ logits=expert_outputs,
189
+ loss=loss,
190
+ combined_log_probs=combined_log_probs,
191
+ loss_to_log=loss_to_log,
192
+ expert_losses=expert_losses,
193
+ router_logits=router_logits if self.use_router else None,
194
+ selected_experts=topk_indices if self.use_router else None,
195
+ )
196
+
197
+
198
+ @torch.no_grad()
199
+ def generate(self, input_ids, max_new_tokens, date=None, temperature=1.0, top_k=None):
200
+ """
201
+ Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
202
+ the sequence max_new_tokens times, feeding the predictions back into the model each time.
203
+ Most likely you'll want to make sure to be in model.eval() mode of operation for this.
204
+ """
205
+ idx = input_ids
206
+ for _ in range(max_new_tokens):
207
+ # if the sequence context is growing too long we must crop it at sequence_length
208
+ idx_cond = (
209
+ idx
210
+ if idx.size(1) <= self.config.sequence_length
211
+ else idx[:, -self.config.sequence_length :]
212
+ )
213
+ # forward the model to get the logits for the index in the sequence
214
+ logits = self(idx_cond, date, get_logits=True).logits
215
+ # pluck the logits at the final step and scale by desired temperature
216
+ logits = logits[:, -1, :] / temperature
217
+ # optionally crop the logits to only the top k options
218
+ if top_k is not None:
219
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
220
+ logits[logits < v[:, [-1]]] = -float("Inf")
221
+ # apply softmax to convert logits to (normalized) probabilities
222
+ probs = F.softmax(logits, dim=-1)
223
+ # sample from the distribution
224
+ idx_next = torch.multinomial(probs, num_samples=1)
225
+ # append sampled index to the running sequence and continue
226
+ idx = torch.cat((idx, idx_next), dim=1)
227
+ # check if we hit the end of the sequence
228
+ if idx_next.item() == 50526:
229
+ break
230
+
231
+ return idx
232
+
233
+ @torch.no_grad()
234
+ def generate_from_string(self, in_str, max_new_tokens, date=None, temperature=1.0, top_k=None):
235
+ idx = (
236
+ torch.tensor(
237
+ self.tokenizer.encode(in_str, allowed_special={"<|endoftext|>"})
238
+ )
239
+ .view(1, -1)
240
+ .to(self.lm_head.weight.device)
241
+ )
242
+ out_idx = (
243
+ self.generate(idx, max_new_tokens, date, temperature, top_k)
244
+ .view(-1)
245
+ .to("cpu")
246
+ .numpy()
247
+ )
248
+ return self.tokenizer.decode(out_idx)
249
+