robinfaro committed on
Commit
38c0af3
·
verified ·
1 Parent(s): 449836c

Adding files from hf_modeling_btm_log_prob_mixing

Browse files
Files changed (1) hide show
  1. modeling.py +251 -0
modeling.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+ from .configuration import MoLMConfig
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+ from transformers.utils import ModelOutput
7
+ from .gpt import GPTBase
8
+ from .aux_losses import entropy_reg, load_balancing_loss, router_z_loss
9
+ from typing import Optional, List
10
+ from dataclasses import dataclass
11
+
12
+
13
@dataclass
class Output(ModelOutput):
    """Output container for ``MoLM.forward``.

    All fields are optional so partially populated outputs remain valid
    (e.g. no ``loss`` at inference time, no router fields when routing is
    disabled).
    """

    logits: torch.FloatTensor = None
    loss: Optional[torch.FloatTensor] = None
    expert_losses: Optional[List] = None
    loss_to_log: Optional[float] = None
    router_logits: Optional[torch.FloatTensor] = None
    selected_experts: Optional[torch.LongTensor] = None
    # Bug fix: forward() constructs Output(..., combined_log_probs=...); with
    # no matching field the generated dataclass __init__ raised TypeError on
    # every forward call.
    combined_log_probs: Optional[torch.FloatTensor] = None
21
+
22
+
23
class MoLM(PreTrainedModel):
    """Mixture of language models whose predictions are mixed in log-probability space."""

    config_class = MoLMConfig

    def __init__(self, config, expert_weights=None, dropout=0.1):
        """
        Build the expert ensemble and, optionally, load frozen pre-trained weights.

        :param config: MoLMConfig providing num_experts, expert_configs,
            use_router, n_embd and (optionally) top_k_experts.
        :param expert_weights: (Optional) list of state dicts, one per expert;
            when given, each expert is loaded and then frozen.
        :param dropout: Dropout rate (kept for interface compatibility).
        """
        super().__init__(config)

        self.num_experts = config.num_experts
        print(f"Number of experts: {self.num_experts}")
        print(f"Expert configurations: {config.expert_configs}")
        assert len(config.expert_configs) == self.num_experts, "Number of expert configurations must match num_experts in config."
        self.expert_configs = config.expert_configs

        self.use_router = config.use_router

        # Single linear layer scoring each expert from a pooled embedding.
        self.router = nn.Sequential(
            nn.Linear(config.n_embd, self.num_experts),
        )
        # Default to using every expert when top_k_experts is not configured.
        self.top_k = getattr(config, "top_k_experts", self.num_experts)

        # One GPTBase expert per supplied configuration.
        self.experts = nn.ModuleList(
            GPTBase(config=cfg) for cfg in self.expert_configs
        )

        # Optionally load checkpoints and freeze the experts.
        if expert_weights is not None:
            for weights, expert in zip(expert_weights, self.experts):
                expert.load_state_dict(weights, strict=False)
                # Re-wrap a clone of the embedding table so it is no longer
                # tied to any shared tensor from the loaded checkpoint.
                wte = expert.transformer.wte
                wte.weight = torch.nn.Parameter(wte.weight.clone())
                for p in expert.parameters():
                    p.requires_grad = False
+
63
+ def forward(self, input_ids, attention_mask=None, targets=None, date=None, masking_enabled=True, **kwargs):
64
+ """
65
+ Forward pass for the MoLM model, passing input through all experts and averaging their outputs.
66
+
67
+ :param input_ids: Input token IDs (batch_size, seq_len)
68
+ :param attention_mask: Attention mask (batch_size, seq_len)
69
+ :param targets: Target labels for calculating loss (batch_size, seq_len)
70
+ :param date: A tensor indicating which experts to use. Each sample in the batch can have a different date.
71
+ :param masking_enabled: Whether or not to perform expert masking (True/False)
72
+ :param kwargs: Additional arguments
73
+ :return: The averaged output of all active experts up to the specified date for each sample in the batch
74
+ """
75
+ device = input_ids.device
76
+ b, t = input_ids.size()
77
+
78
+ # Ensure the sequence length doesn't exceed the configured block size
79
+ assert t <= self.config.sequence_length, f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}"
80
+
81
+ # If date is None, set a default value (e.g., 6 for all samples)
82
+ if date is None:
83
+ date = torch.full((1, b), 6, dtype=torch.long, device=device).squeeze(0)
84
+ elif isinstance(date, int):
85
+ # If date is an integer, set it for all samples in the batch
86
+ date = (date - 2013) // 2 + 1
87
+ date = torch.full((1, b), date, dtype=torch.long, device=device).squeeze(0)
88
+ elif isinstance(date, torch.Tensor):
89
+ # Ensure the tensor has the correct shape (batch_size,)
90
+ assert date.size(0) == b, "The size of date tensor must match the batch size."
91
+ date = date.to(device)
92
+
93
+ # Get outputs from each expert
94
+ expert_outputs = []
95
+ expert_losses = []
96
+
97
+ # Track the number of active experts for each sample in the batch
98
+ active_experts_count = torch.zeros(b, dtype=torch.long, device=device)
99
+
100
+ # Pass input through each expert
101
+ with torch.no_grad():
102
+ for i, expert in enumerate(self.experts):
103
+ # Masking logic based on date (for each sample in the batch)
104
+ expert_mask = date >= i # Mask experts where date < i (i.e., deactivate them)
105
+ #expert_mask = date <= i
106
+ # Expand the expert_mask to match the logits shape (batch_size, 1, 1)
107
+ expert_mask_expanded = expert_mask.unsqueeze(-1).unsqueeze(-1).float()
108
+
109
+ expert_output = expert(input_ids, targets=targets, date=date, get_logits=True, **kwargs)
110
+
111
+ logits = expert_output["logits"]
112
+ loss_to_log = expert_output["loss_to_log"]
113
+
114
+ # Mask out the outputs for deactivated experts
115
+ logits = logits * expert_mask_expanded # Apply the mask (zero out logits for inactive experts)
116
+
117
+ # Only append logits from active experts
118
+ expert_outputs.append(logits)
119
+ expert_losses.append(loss_to_log)
120
+
121
+ # Update active expert count for each sample
122
+ active_experts_count += expert_mask.long() # Ensure type consistency by converting `expert_mask` to Long
123
+
124
+ # Stack the logits and calculate the mean for each sample across the active experts
125
+ expert_outputs = torch.stack(expert_outputs, dim=0) # Shape: (num_experts, batch_size, seq_len, vocab_size)
126
+
127
+ # Convert logits to log-probabilities for each expert
128
+ log_probs = F.log_softmax(expert_outputs, dim=-1)
129
+
130
+ if self.use_router:
131
+ hidden = self.experts[0].transformer.wte(input_ids) # (B, T, D)
132
+ pooled_hidden = hidden.mean(dim=1) # (B, D)
133
+ router_logits = self.router(pooled_hidden) # (B, E)
134
+
135
+ expert_ids = torch.arange(self.num_experts, device=input_ids.device)
136
+ router_mask = date.unsqueeze(1) >= expert_ids.unsqueeze(0) # (B, E)
137
+ masked_router_logits = router_logits.masked_fill(~router_mask, float("-inf"))
138
+
139
+ # Select top-k
140
+ topk_probs, topk_indices = torch.topk(F.softmax(masked_router_logits, dim=-1), self.top_k, dim=-1)
141
+ sparse_probs = torch.zeros_like(router_logits)
142
+ sparse_probs.scatter_(1, topk_indices, topk_probs)
143
+ sparse_probs = sparse_probs / sparse_probs.sum(dim=1, keepdim=True)
144
+
145
+ # Convert weights to log-space
146
+ log_weights = torch.log(sparse_probs + 1e-9) # (B, E)
147
+
148
+ # Broadcast for logsumexp: (E, B, T, V)
149
+ log_weights_exp = log_weights.transpose(0, 1).unsqueeze(-1).unsqueeze(-1) # (E, B, 1, 1)
150
+ weighted_log_probs = log_probs + log_weights_exp # (E, B, T, V)
151
+
152
+ combined_log_probs = torch.logsumexp(weighted_log_probs, dim=0) # (B, T, V)
153
+ combined_logits = combined_log_probs # because loss works with log-probs if used properly
154
+
155
+ else:
156
+ # Unweighted average in log-prob space across active experts (equal weights)
157
+ log_weights = torch.log(1.0 / active_experts_count.float().clamp(min=1.0)).view(1, -1, 1, 1) # (1, B, 1, 1)
158
+ weighted_log_probs = log_probs + log_weights
159
+ combined_log_probs = torch.logsumexp(weighted_log_probs, dim=0) # (B, T, V)
160
+ combined_logits = combined_log_probs # because loss works with log-probs if used properly
161
+
162
+ # Calculate the loss if targets are provided
163
+ if targets is not None:
164
+ #loss = F.cross_entropy(combined_logits.view(-1, combined_logits.size(-1)), targets.view(-1), ignore_index=-1)
165
+ loss = F.nll_loss(combined_logits.view(-1, combined_logits.size(-1)), targets.view(-1), ignore_index=-1)
166
+ loss_to_log = loss.item()
167
+
168
+ # Add auxiliary router losses (only if routing is used and we're training)
169
+ if self.use_router and self.training:
170
+ flat_router_logits = router_logits.view(-1, router_logits.size(-1)) # (B*T, E)
171
+ flat_selected_experts = topk_indices.view(-1, topk_indices.size(-1)) # (B*T, top_k)
172
+
173
+ # Compute each auxiliary loss
174
+ entropy = entropy_reg(flat_router_logits)
175
+ lb_loss = load_balancing_loss(flat_router_logits, flat_selected_experts)
176
+ zloss = router_z_loss(flat_router_logits)
177
+
178
+ # Combine them with your preferred weights
179
+ loss = (
180
+ loss
181
+ + 0.01 *entropy
182
+ + 0.01 * lb_loss
183
+ + 0.0001 * zloss
184
+ )
185
+ else:
186
+ loss = None
187
+ loss_to_log = None
188
+
189
+ return Output(
190
+ logits=combined_logits,
191
+ loss=loss,
192
+ combined_log_probs=combined_log_probs,
193
+ loss_to_log=loss_to_log,
194
+ expert_losses=expert_losses,
195
+ router_logits=router_logits if self.use_router else None,
196
+ selected_experts=topk_indices if self.use_router else None,
197
+ )
198
+
199
+
200
+ @torch.no_grad()
201
+ def generate(self, input_ids, max_new_tokens, date=None, temperature=1.0, top_k=None):
202
+ """
203
+ Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
204
+ the sequence max_new_tokens times, feeding the predictions back into the model each time.
205
+ Most likely you'll want to make sure to be in model.eval() mode of operation for this.
206
+ """
207
+ idx = input_ids
208
+ for _ in range(max_new_tokens):
209
+ # if the sequence context is growing too long we must crop it at sequence_length
210
+ idx_cond = (
211
+ idx
212
+ if idx.size(1) <= self.config.sequence_length
213
+ else idx[:, -self.config.sequence_length :]
214
+ )
215
+ # forward the model to get the logits for the index in the sequence
216
+ logits = self(idx_cond, date, get_logits=True).logits
217
+ # pluck the logits at the final step and scale by desired temperature
218
+ logits = logits[:, -1, :] / temperature
219
+ # optionally crop the logits to only the top k options
220
+ if top_k is not None:
221
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
222
+ logits[logits < v[:, [-1]]] = -float("Inf")
223
+ # apply softmax to convert logits to (normalized) probabilities
224
+ probs = F.softmax(logits, dim=-1)
225
+ # sample from the distribution
226
+ idx_next = torch.multinomial(probs, num_samples=1)
227
+ # append sampled index to the running sequence and continue
228
+ idx = torch.cat((idx, idx_next), dim=1)
229
+ # check if we hit the end of the sequence
230
+ if idx_next.item() == 50526:
231
+ break
232
+
233
+ return idx
234
+
235
+ @torch.no_grad()
236
+ def generate_from_string(self, in_str, max_new_tokens, date=None, temperature=1.0, top_k=None):
237
+ idx = (
238
+ torch.tensor(
239
+ self.tokenizer.encode(in_str, allowed_special={"<|endoftext|>"})
240
+ )
241
+ .view(1, -1)
242
+ .to(self.lm_head.weight.device)
243
+ )
244
+ out_idx = (
245
+ self.generate(idx, max_new_tokens, date, temperature, top_k)
246
+ .view(-1)
247
+ .to("cpu")
248
+ .numpy()
249
+ )
250
+ return self.tokenizer.decode(out_idx)
251
+