| | from transformers import PreTrainedModel |
| | from .configuration import MoLMConfig |
| | import torch |
| | import torch.nn as nn |
| | from torch.nn import functional as F |
| | from transformers.utils import ModelOutput |
| | from .gpt import GPTBase |
| | from .aux_losses import entropy_reg, load_balancing_loss, router_z_loss |
| | from typing import Optional, List |
| | from dataclasses import dataclass |
| |
|
| |
|
@dataclass
class Output(ModelOutput):
    # Combined next-token logits over active experts; presumably (batch, seq_len, vocab) — TODO confirm against GPTBase.
    logits: torch.FloatTensor = None
    # Total training loss: cross-entropy, plus router auxiliary terms when routing is enabled during training.
    loss: Optional[torch.FloatTensor] = None
    # Per-expert `loss_to_log` values collected while running each expert in forward().
    expert_losses: Optional[List] = None
    # Plain cross-entropy scalar (via .item()) for logging, without auxiliary losses.
    loss_to_log: Optional[float] = None
    # Raw router scores, one per expert; None when the router is disabled.
    router_logits: Optional[torch.FloatTensor] = None
    # Indices of the top-k experts selected by the router; None when the router is disabled.
    selected_experts: Optional[torch.LongTensor] = None
| |
|
| |
|
| | class MoLM(PreTrainedModel): |
| | config_class = MoLMConfig |
| |
|
| | def __init__(self, config, expert_weights=None, dropout=0.1): |
| | """ |
| | Constructor for the MoLM (Mixture of Language Models) class. |
| | |
| | :param config: The configuration of the model (should be a PretrainedConfig object) |
| | :param expert_weights: (Optional) A list of weights for each expert to load pre-trained weights (should match the number of experts) |
| | :param dropout: Dropout rate for the model |
| | :param use_router: Flag to indicate whether to use routing (currently not implemented) |
| | """ |
| | super(MoLM, self).__init__(config) |
| | |
| | |
| | self.num_experts = config.num_experts |
| | print(f"Number of experts: {self.num_experts}") |
| | print(f"Expert configurations: {config.expert_configs}") |
| | assert len(config.expert_configs) == self.num_experts, "Number of expert configurations must match num_experts in config." |
| | self.expert_configs = config.expert_configs |
| |
|
| | |
| | self.use_router = config.use_router |
| | |
| | self.router = nn.Sequential( |
| | nn.Linear(config.n_embd, self.num_experts), |
| | ) |
| | self.top_k = config.top_k_experts if hasattr(config, "top_k_experts") else self.num_experts |
| |
|
| | |
| | self.experts = nn.ModuleList([GPTBase(config=self.expert_configs[i]) for i in range(self.num_experts)]) |
| | |
| | |
| | if expert_weights is not None: |
| | for i, expert in enumerate(self.experts): |
| | expert.load_state_dict(expert_weights[i], strict=False) |
| | expert.transformer.wte.weight = torch.nn.Parameter(expert.transformer.wte.weight.clone()) |
| | for param in expert.parameters(): |
| | param.requires_grad = False |
| |
|
| | def forward(self, input_ids, attention_mask=None, targets=None, date=None, masking_enabled=True, **kwargs): |
| | """ |
| | Forward pass for the MoLM model, passing input through all experts and averaging their outputs. |
| | |
| | :param input_ids: Input token IDs (batch_size, seq_len) |
| | :param attention_mask: Attention mask (batch_size, seq_len) |
| | :param targets: Target labels for calculating loss (batch_size, seq_len) |
| | :param date: A tensor indicating which experts to use. Each sample in the batch can have a different date. |
| | :param masking_enabled: Whether or not to perform expert masking (True/False) |
| | :param kwargs: Additional arguments |
| | :return: The averaged output of all active experts up to the specified date for each sample in the batch |
| | """ |
| | device = input_ids.device |
| | b, t = input_ids.size() |
| |
|
| | |
| | assert t <= self.config.sequence_length, f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}" |
| |
|
| | |
| | if date is None: |
| | date = torch.full((1, b), 6, dtype=torch.long, device=device).squeeze(0) |
| | elif isinstance(date, int): |
| | |
| | date = (date - 2013) // 2 + 1 |
| | date = torch.full((1, b), date, dtype=torch.long, device=device).squeeze(0) |
| | elif isinstance(date, torch.Tensor): |
| | |
| | assert date.size(0) == b, "The size of date tensor must match the batch size." |
| | date = date.to(device) |
| |
|
| | |
| | expert_outputs = [] |
| | expert_losses = [] |
| |
|
| | |
| | active_experts_count = torch.zeros(b, dtype=torch.long, device=device) |
| |
|
| | |
| | with torch.no_grad(): |
| | for i, expert in enumerate(self.experts): |
| | |
| | |
| | expert_mask = date <= i |
| | |
| | expert_mask_expanded = expert_mask.unsqueeze(-1).unsqueeze(-1).float() |
| |
|
| | expert_output = expert(input_ids, targets=targets, date=date, get_logits=True, **kwargs) |
| |
|
| | logits = expert_output["logits"] |
| | loss_to_log = expert_output["loss_to_log"] |
| |
|
| | |
| | logits = logits * expert_mask_expanded |
| |
|
| | |
| | expert_outputs.append(logits) |
| | expert_losses.append(loss_to_log) |
| |
|
| | |
| | active_experts_count += expert_mask.long() |
| |
|
| | |
| | expert_outputs = torch.stack(expert_outputs, dim=0) |
| | |
| | if self.use_router: |
| | hidden = self.experts[0].transformer.wte(input_ids) |
| | pooled_hidden = hidden.mean(dim=1) |
| | router_logits = self.router(pooled_hidden) |
| |
|
| | |
| | |
| | expert_ids = torch.arange(self.num_experts, device=input_ids.device) |
| | router_mask = date.unsqueeze(1) >= expert_ids.unsqueeze(0) |
| |
|
| | |
| | masked_logits = router_logits.masked_fill(~router_mask, float("-inf")) |
| | |
| | router_probs = F.softmax(masked_logits, dim=-1) |
| | |
| | topk_probs, topk_indices = torch.topk(router_probs, self.top_k, dim=-1) |
| | sparse_probs = torch.zeros_like(router_probs) |
| | sparse_probs.scatter_(1, topk_indices, topk_probs) |
| | |
| | sparse_probs = sparse_probs / sparse_probs.sum(dim=1, keepdim=True) |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | weighted_logits = None |
| | for i in range(self.num_experts): |
| | weight = sparse_probs[:, i].view(b, 1, 1) |
| | contrib = expert_outputs[i] * weight |
| | if weighted_logits is None: |
| | weighted_logits = contrib |
| | else: |
| | weighted_logits += contrib |
| | combined_logits = weighted_logits |
| |
|
| | |
| | else: |
| | |
| | summed_logits = torch.sum(expert_outputs, dim=0) |
| | combined_logits = summed_logits / active_experts_count.unsqueeze(-1).unsqueeze(-1) |
| |
|
| | |
| | if targets is not None: |
| | loss = F.cross_entropy(combined_logits.view(-1, combined_logits.size(-1)), targets.view(-1), ignore_index=-1) |
| | loss_to_log = loss.item() |
| |
|
| | |
| | if self.use_router and self.training: |
| | flat_router_logits = router_logits.view(-1, router_logits.size(-1)) |
| | flat_selected_experts = topk_indices.view(-1, topk_indices.size(-1)) |
| |
|
| | |
| | entropy = entropy_reg(flat_router_logits) |
| | lb_loss = load_balancing_loss(flat_router_logits, flat_selected_experts) |
| | zloss = router_z_loss(flat_router_logits) |
| |
|
| | |
| | loss = ( |
| | loss |
| | + 0.01 *entropy |
| | + 0.01 * lb_loss |
| | + 0.0001 * zloss |
| | ) |
| | else: |
| | loss = None |
| | loss_to_log = None |
| |
|
| | return Output( |
| | logits=combined_logits, |
| | loss=loss, |
| | loss_to_log=loss_to_log, |
| | expert_losses=expert_losses, |
| | router_logits=router_logits if self.use_router else None, |
| | selected_experts=topk_indices if self.use_router else None, |
| | ) |
| |
|
| |
|
| | @torch.no_grad() |
| | def generate(self, input_ids, max_new_tokens, date=None, temperature=1.0, top_k=None): |
| | """ |
| | Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete |
| | the sequence max_new_tokens times, feeding the predictions back into the model each time. |
| | Most likely you'll want to make sure to be in model.eval() mode of operation for this. |
| | """ |
| | idx = input_ids |
| | for _ in range(max_new_tokens): |
| | |
| | idx_cond = ( |
| | idx |
| | if idx.size(1) <= self.config.sequence_length |
| | else idx[:, -self.config.sequence_length :] |
| | ) |
| | |
| | logits = self(idx_cond, date, get_logits=True).logits |
| | |
| | logits = logits[:, -1, :] / temperature |
| | |
| | if top_k is not None: |
| | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) |
| | logits[logits < v[:, [-1]]] = -float("Inf") |
| | |
| | probs = F.softmax(logits, dim=-1) |
| | |
| | idx_next = torch.multinomial(probs, num_samples=1) |
| | |
| | idx = torch.cat((idx, idx_next), dim=1) |
| | |
| | if idx_next.item() == 50526: |
| | break |
| |
|
| | return idx |
| |
|
| | @torch.no_grad() |
| | def generate_from_string(self, in_str, max_new_tokens, date=None, temperature=1.0, top_k=None): |
| | idx = ( |
| | torch.tensor( |
| | self.tokenizer.encode(in_str, allowed_special={"<|endoftext|>"}) |
| | ) |
| | .view(1, -1) |
| | .to(self.lm_head.weight.device) |
| | ) |
| | out_idx = ( |
| | self.generate(idx, max_new_tokens, date, temperature, top_k) |
| | .view(-1) |
| | .to("cpu") |
| | .numpy() |
| | ) |
| | return self.tokenizer.decode(out_idx) |
| | |
| |
|