robinfaro committed on
Commit 1e61966 · verified · 1 Parent(s): e54881a

Adding files from hf_modeling_btm_reversed

Files changed (1)
  1. modeling.py +257 -0
modeling.py ADDED
@@ -0,0 +1,257 @@
+ from transformers import PreTrainedModel
+ from .configuration import MoLMConfig
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from transformers.utils import ModelOutput
+ from .gpt import GPTBase
+ from .aux_losses import entropy_reg, load_balancing_loss, router_z_loss
+ from typing import Optional, List
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class Output(ModelOutput):
+     logits: torch.FloatTensor = None
+     loss: Optional[torch.FloatTensor] = None
+     expert_losses: Optional[List] = None
+     loss_to_log: Optional[float] = None
+     router_logits: Optional[torch.FloatTensor] = None
+     selected_experts: Optional[torch.LongTensor] = None
+
+
+ class MoLM(PreTrainedModel):
+     config_class = MoLMConfig
+
+     def __init__(self, config, expert_weights=None, dropout=0.1):
+         """
+         Constructor for the MoLM (Mixture of Language Models) class.
+
+         :param config: The configuration of the model (should be a PretrainedConfig object)
+         :param expert_weights: (Optional) A list of pre-trained state dicts, one per expert (must match the number of experts)
+         :param dropout: Dropout rate for the model
+         """
+         super(MoLM, self).__init__(config)
+
+         # Number of experts
+         self.num_experts = config.num_experts
+         print(f"Number of experts: {self.num_experts}")
+         print(f"Expert configurations: {config.expert_configs}")
+         assert len(config.expert_configs) == self.num_experts, "Number of expert configurations must match num_experts in config."
+         self.expert_configs = config.expert_configs
+
+         # Whether expert outputs are combined by a learned router (taken from the config)
+         self.use_router = config.use_router
+
+         self.router = nn.Sequential(
+             nn.Linear(config.n_embd, self.num_experts),
+         )
+         self.top_k = config.top_k_experts if hasattr(config, "top_k_experts") else self.num_experts
+
+         # Initialize experts using the provided configurations
+         self.experts = nn.ModuleList([GPTBase(config=self.expert_configs[i]) for i in range(self.num_experts)])
+
+         # Load pre-trained weights if provided and freeze the experts
+         if expert_weights is not None:
+             for i, expert in enumerate(self.experts):
+                 expert.load_state_dict(expert_weights[i], strict=False)
+                 expert.transformer.wte.weight = torch.nn.Parameter(expert.transformer.wte.weight.clone())
+                 for param in expert.parameters():
+                     param.requires_grad = False
+
+     def forward(self, input_ids, attention_mask=None, targets=None, date=None, masking_enabled=True, **kwargs):
+         """
+         Forward pass for the MoLM model, passing the input through all experts and combining their
+         outputs (a routed weighted sum if use_router is set, otherwise an average over the active experts).
+
+         :param input_ids: Input token IDs (batch_size, seq_len)
+         :param attention_mask: Attention mask (batch_size, seq_len)
+         :param targets: Target labels for calculating the loss (batch_size, seq_len)
+         :param date: A tensor indicating which experts to use. Each sample in the batch can have a different date.
+         :param masking_enabled: Whether or not to perform expert masking (True/False)
+         :param kwargs: Additional arguments
+         :return: An Output containing the combined logits of the active experts for each sample (plus the loss when targets are provided)
+         """
+         device = input_ids.device
+         b, t = input_ids.size()
+
+         # Ensure the sequence length doesn't exceed the configured block size
+         assert t <= self.config.sequence_length, f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}"
+
+         # If date is None, use the default value 6 for all samples
+         if date is None:
+             date = torch.full((1, b), 6, dtype=torch.long, device=device).squeeze(0)
+         elif isinstance(date, int):
+             # If date is an integer, map it to an expert index and use it for all samples in the batch
+             date = (date - 2013) // 2 + 1
+             date = torch.full((1, b), date, dtype=torch.long, device=device).squeeze(0)
+         elif isinstance(date, torch.Tensor):
+             # Ensure the tensor has the correct shape (batch_size,)
+             assert date.size(0) == b, "The size of the date tensor must match the batch size."
+             date = date.to(device)
+
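+         # Worked example of the integer-date mapping above:
+         # date=2013 -> (2013 - 2013) // 2 + 1 = 1, date=2019 -> (2019 - 2013) // 2 + 1 = 4.
+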
+         # Get outputs from each expert
+         expert_outputs = []
+         expert_losses = []
+
+         # Track the number of active experts for each sample in the batch
+         active_experts_count = torch.zeros(b, dtype=torch.long, device=device)
+
+         # Pass the input through each expert
+         with torch.no_grad():
+             for i, expert in enumerate(self.experts):
+                 # Masking logic based on date (for each sample in the batch)
+                 # expert_mask = date >= i  # Mask experts where date < i (i.e., deactivate them)
+                 expert_mask = date <= i
+                 # Expand the expert_mask to match the logits shape (batch_size, 1, 1)
+                 expert_mask_expanded = expert_mask.unsqueeze(-1).unsqueeze(-1).float()
+
+                 expert_output = expert(input_ids, targets=targets, date=date, get_logits=True, **kwargs)
+
+                 logits = expert_output["logits"]
+                 loss_to_log = expert_output["loss_to_log"]
+
+                 # Mask out the outputs of deactivated experts (zero their logits)
+                 logits = logits * expert_mask_expanded
+
+                 # Append the (masked) logits and the per-expert loss
+                 expert_outputs.append(logits)
+                 expert_losses.append(loss_to_log)
+
+                 # Update the active-expert count for each sample (cast the mask to long for type consistency)
+                 active_experts_count += expert_mask.long()
+
+         # Stack the logits from all experts
+         expert_outputs = torch.stack(expert_outputs, dim=0)  # Shape: (num_experts, batch_size, seq_len, vocab_size)
+
+         if self.use_router:
+             hidden = self.experts[0].transformer.wte(input_ids)  # shape (B, T, D)
+             pooled_hidden = hidden.mean(dim=1)  # shape (B, D)
+             router_logits = self.router(pooled_hidden)  # shape (B, num_experts)
+
+             # Create the router mask from date: expert i is allowed for a sample if date >= i
+             expert_ids = torch.arange(self.num_experts, device=input_ids.device)
+             router_mask = date.unsqueeze(1) >= expert_ids.unsqueeze(0)  # (B, num_experts)
+
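+             # Worked example: with num_experts = 4 and date = [1, 3], router_mask is
+             # [[True, True, False, False], [True, True, True, True]].
+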
+             # Mask out inactive experts by setting their logits to -inf before the softmax
+             masked_logits = router_logits.masked_fill(~router_mask, float("-inf"))
+             router_probs = F.softmax(masked_logits, dim=-1)  # shape (B, num_experts)
+
+             # Top-k selection: keep only the top-k router probabilities and renormalize them
+             topk_probs, topk_indices = torch.topk(router_probs, self.top_k, dim=-1)  # (B, top_k)
+             sparse_probs = torch.zeros_like(router_probs)  # (B, num_experts)
+             sparse_probs.scatter_(1, topk_indices, topk_probs)  # only the top-k entries are kept
+             sparse_probs = sparse_probs / sparse_probs.sum(dim=1, keepdim=True)  # (B, num_experts)
+
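+             # Worked example: with num_experts = 4, top_k = 2 and router_probs = [0.1, 0.2, 0.3, 0.4],
+             # sparse_probs is first [0.0, 0.0, 0.3, 0.4] and renormalizes to [0.0, 0.0, 3/7, 4/7].
+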
+             # Accumulate the weighted logits expert by expert to save memory, instead of
+             # holding the full (num_experts, B, T, V) weighted tensor and summing it in one step
+             weighted_logits = None
+             for i in range(self.num_experts):
+                 weight = sparse_probs[:, i].view(b, 1, 1)  # shape: (B, 1, 1)
+                 contrib = expert_outputs[i] * weight  # shape: (B, T, V)
+                 if weighted_logits is None:
+                     weighted_logits = contrib
+                 else:
+                     weighted_logits += contrib
+             combined_logits = weighted_logits  # (B, T, V)
+         else:
+             # Sum across the active experts for each sample and divide by the number of active experts
+             summed_logits = torch.sum(expert_outputs, dim=0)
+             combined_logits = summed_logits / active_experts_count.unsqueeze(-1).unsqueeze(-1)
+
+         # Calculate the loss if targets are provided
+         if targets is not None:
+             loss = F.cross_entropy(combined_logits.view(-1, combined_logits.size(-1)), targets.view(-1), ignore_index=-1)
+             loss_to_log = loss.item()
+
+             # Add auxiliary router losses (only if routing is used and we're training)
+             if self.use_router and self.training:
+                 flat_router_logits = router_logits.view(-1, router_logits.size(-1))  # (B, num_experts)
+                 flat_selected_experts = topk_indices.view(-1, topk_indices.size(-1))  # (B, top_k)
+
+                 # Compute each auxiliary loss
+                 entropy = entropy_reg(flat_router_logits)
+                 lb_loss = load_balancing_loss(flat_router_logits, flat_selected_experts)
+                 zloss = router_z_loss(flat_router_logits)
+
+                 # Combine the auxiliary losses with the cross-entropy loss using fixed weights
+                 loss = (
+                     loss
+                     + 0.01 * entropy
+                     + 0.01 * lb_loss
+                     + 0.0001 * zloss
+                 )
+         else:
+             loss = None
+             loss_to_log = None
+
+         return Output(
+             logits=combined_logits,
+             loss=loss,
+             loss_to_log=loss_to_log,
+             expert_losses=expert_losses,
+             router_logits=router_logits if self.use_router else None,
+             selected_experts=topk_indices if self.use_router else None,
+         )
+
+
+     @torch.no_grad()
+     def generate(self, input_ids, max_new_tokens, date=None, temperature=1.0, top_k=None):
+         """
+         Take a conditioning sequence of indices idx (LongTensor of shape (b, t)) and complete
+         the sequence max_new_tokens times, feeding the predictions back into the model each time.
+         Most likely you'll want to make sure to be in model.eval() mode of operation for this.
+         """
+         idx = input_ids
+         for _ in range(max_new_tokens):
+             # If the sequence context is growing too long, crop it at sequence_length
+             idx_cond = (
+                 idx
+                 if idx.size(1) <= self.config.sequence_length
+                 else idx[:, -self.config.sequence_length :]
+             )
+             # Forward the model to get the logits for the next position
+             logits = self(idx_cond, date=date).logits
+             # Pluck the logits at the final step and scale by the desired temperature
+             logits = logits[:, -1, :] / temperature
+             # Optionally crop the logits to only the top k options
+             if top_k is not None:
+                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                 logits[logits < v[:, [-1]]] = -float("Inf")
+             # Apply softmax to convert logits to (normalized) probabilities
+             probs = F.softmax(logits, dim=-1)
+             # Sample from the distribution
+             idx_next = torch.multinomial(probs, num_samples=1)
+             # Append the sampled index to the running sequence and continue
+             idx = torch.cat((idx, idx_next), dim=1)
+             # Stop if we hit the end-of-sequence token
+             if idx_next.item() == 50526:
+                 break
+
+         return idx
+
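+     # Example call (illustrative; `prompt_ids` is a placeholder for a (1, t) LongTensor of token IDs):
+     #     model.generate(prompt_ids, max_new_tokens=50, date=2013, top_k=40)
+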
+     @torch.no_grad()
+     def generate_from_string(self, in_str, max_new_tokens, date=None, temperature=1.0, top_k=None):
+         idx = (
+             torch.tensor(
+                 self.tokenizer.encode(in_str, allowed_special={"<|endoftext|>"})
+             )
+             .view(1, -1)
+             .to(self.lm_head.weight.device)
+         )
+         out_idx = (
+             self.generate(idx, max_new_tokens, date, temperature, top_k)
+             .view(-1)
+             .to("cpu")
+             .numpy()
+         )
+         return self.tokenizer.decode(out_idx)
+
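A minimal usage sketch for the class above. This is illustrative only: `config` stands for a MoLMConfig whose expert_configs match num_experts, `expert_state_dicts` for an optional list of GPTBase state dicts, and the import path is a placeholder; none of these are defined in this commit.

    import torch
    from modeling import MoLM

    # Build the mixture from an existing MoLMConfig and (optionally) pre-trained expert weights.
    model = MoLM(config, expert_weights=expert_state_dicts)
    model.eval()

    # A (batch, seq_len) tensor of token IDs; an integer date is mapped to an expert index internally.
    prompt_ids = torch.randint(0, 100, (1, 16), dtype=torch.long)
    output = model(prompt_ids, date=2013)        # output.logits has shape (1, 16, vocab_size)
    completion = model.generate(prompt_ids, max_new_tokens=32, date=2013, top_k=40)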