from dataclasses import dataclass
from typing import List, Optional

import tiktoken
import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import PreTrainedModel
from transformers.utils import ModelOutput

from .aux_losses import entropy_reg, load_balancing_loss, router_z_loss
from .configuration import MoLMConfig
from .gpt import GPTBase


@dataclass
class Output(ModelOutput):
logits: torch.FloatTensor = None
loss: Optional[torch.FloatTensor] = None
expert_losses: Optional[List] = None
loss_to_log: Optional[float] = None
router_logits: Optional[torch.FloatTensor] = None
selected_experts: Optional[torch.LongTensor] = None
combined_log_probs: Optional[torch.FloatTensor] = None


class MoLM(PreTrainedModel):
config_class = MoLMConfig

    def __init__(self, config, expert_weights=None, dropout=0.1):
        """
        Constructor for the MoLM (Mixture of Language Models) class.

        :param config: The configuration of the model (a MoLMConfig object). It provides,
            among other fields, num_experts, expert_configs, use_router, and optionally
            top_k_experts.
        :param expert_weights: (Optional) A list of state dicts, one per expert, used to
            load pre-trained weights; its length must match the number of experts.
        :param dropout: Dropout rate (currently unused; kept for interface compatibility).
        """
        super().__init__(config)
# Number of experts
self.num_experts = config.num_experts
# print(f"Number of experts: {self.num_experts}")
# print(f"Expert configurations: {config.expert_configs}")
assert len(config.expert_configs) == self.num_experts, "Number of expert configurations must match num_experts in config."
self.expert_configs = config.expert_configs
self.use_router = config.use_router
        # Linear router that produces one score per expert
        self.router = nn.Sequential(
            nn.Linear(config.n_embd, self.num_experts),
        )
        self.top_k = getattr(config, "top_k_experts", self.num_experts)
# Initialize experts using the provided configurations
self.experts = nn.ModuleList([GPTBase(config=self.expert_configs[i]) for i in range(self.num_experts)])
self.tokenizer = tiktoken.get_encoding("gpt2")
        # Load pre-trained weights if provided, then freeze the experts
        if expert_weights is not None:
            for i, expert in enumerate(self.experts):
                expert.load_state_dict(expert_weights[i], strict=False)
                # Re-wrap the token embeddings in a fresh Parameter so the loaded
                # tensor is not shared with other modules (e.g., via weight tying)
                expert.transformer.wte.weight = torch.nn.Parameter(expert.transformer.wte.weight.clone())
                # Freeze the expert; only the router (if enabled) remains trainable
                for param in expert.parameters():
                    param.requires_grad = False
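    # Sketch of how `expert_weights` might be prepared from standalone GPTBase
    # checkpoints (the file names here are hypothetical, not part of this repo):
    #
    #   expert_weights = [torch.load(f"expert_{i}.pt", map_location="cpu")
    #                     for i in range(config.num_experts)]
    #   model = MoLM(config, expert_weights=expert_weights)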

    def forward(self, input_ids, attention_mask=None, targets=None, date=None, masking_enabled=True, **kwargs):
        """
        Forward pass for the MoLM model: the input is passed through every expert and the
        per-expert distributions are combined, either router-weighted or uniformly averaged.

        :param input_ids: Input token IDs (batch_size, seq_len)
        :param attention_mask: Attention mask (batch_size, seq_len)
        :param targets: Target labels for calculating loss (batch_size, seq_len)
        :param date: Per-sample expert cutoff: expert i is active for a sample iff date >= i.
            May be an int (a calendar year, mapped to an index below), a tensor of shape
            (batch_size,), or None (defaults to 6 for all samples).
        :param masking_enabled: Whether to perform expert masking (currently unused)
        :param kwargs: Additional arguments forwarded to each expert
        :return: An Output holding the combined log-probabilities over the active experts
            for each sample in the batch, plus losses and router diagnostics
        """
device = input_ids.device
b, t = input_ids.size()
# Ensure the sequence length doesn't exceed the configured block size
assert t <= self.config.sequence_length, f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}"
        # If date is None, default to 6 for all samples
        if date is None:
            date = torch.full((b,), 6, dtype=torch.long, device=device)
        elif isinstance(date, int):
            # Map a calendar year to an expert index (two-year buckets starting at
            # 2013: 2013-2014 -> 1, 2015-2016 -> 2, ...) and broadcast over the batch
            date = (date - 2013) // 2 + 1
            date = torch.full((b,), date, dtype=torch.long, device=device)
elif isinstance(date, torch.Tensor):
# Ensure the tensor has the correct shape (batch_size,)
assert date.size(0) == b, "The size of date tensor must match the batch size."
date = date.to(device)
# Get outputs from each expert
expert_outputs = []
expert_losses = []
# Track the number of active experts for each sample in the batch
active_experts_count = torch.zeros(b, dtype=torch.long, device=device)
        # Pass input through each expert. No gradients flow to the experts here;
        # only the router (if enabled) is trained, via the mixture weights below.
        with torch.no_grad():
for i, expert in enumerate(self.experts):
                # Masking logic based on date: expert i is active for a sample iff date >= i
                expert_mask = date >= i
# Expand the expert_mask to match the logits shape (batch_size, 1, 1)
expert_mask_expanded = expert_mask.unsqueeze(-1).unsqueeze(-1).float()
expert_output = expert(input_ids, targets=targets, date=date, **kwargs, get_logits=True)
logits = expert_output["logits"]
loss_to_log = expert_output["loss_to_log"]
# Mask out the outputs for deactivated experts
logits = logits * expert_mask_expanded # Apply the mask (zero out logits for inactive experts)
                # Append logits from every expert (inactive experts' logits are zeroed above)
                expert_outputs.append(logits)
expert_losses.append(loss_to_log)
# Update active expert count for each sample
active_experts_count += expert_mask.long() # Ensure type consistency by converting `expert_mask` to Long
        # Stack the per-expert logits
        expert_outputs = torch.stack(expert_outputs, dim=0)  # (num_experts, batch_size, seq_len, vocab_size)
        # Convert logits to log-probabilities for each expert. Note that the zeroed
        # logits of inactive experts turn into *uniform* distributions here, so
        # inactive experts must be suppressed via their mixture weights below.
        log_probs = F.log_softmax(expert_outputs, dim=-1)
        if self.use_router:
            # Route on mean-pooled token embeddings taken from the first expert
            hidden = self.experts[0].transformer.wte(input_ids)  # (B, T, D)
            pooled_hidden = hidden.mean(dim=1)  # (B, D)
            router_logits = self.router(pooled_hidden)  # (B, E)
expert_ids = torch.arange(self.num_experts, device=input_ids.device)
router_mask = date.unsqueeze(1) >= expert_ids.unsqueeze(0) # (B, E)
masked_router_logits = router_logits.masked_fill(~router_mask, float("-inf"))
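            # Illustrative example (values are hypothetical, not fixed by the config):
            # with 4 experts and date = 2 for some sample, router_mask for that row is
            # [True, True, True, False], so expert 3 gets a -inf logit and can never
            # be selected for that sample.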
            # Select the top-k experts per sample and renormalize their probabilities
            topk_probs, topk_indices = torch.topk(F.softmax(masked_router_logits, dim=-1), self.top_k, dim=-1)
            sparse_probs = torch.zeros_like(router_logits)
            sparse_probs.scatter_(1, topk_indices, topk_probs)
            sparse_probs = sparse_probs / sparse_probs.sum(dim=1, keepdim=True)
            # Convert weights to log-space; the epsilon avoids log(0) for unselected experts
            log_weights = torch.log(sparse_probs + 1e-9)  # (B, E)
# Broadcast for logsumexp: (E, B, T, V)
log_weights_exp = log_weights.transpose(0, 1).unsqueeze(-1).unsqueeze(-1) # (E, B, 1, 1)
weighted_log_probs = log_probs + log_weights_exp # (E, B, T, V)
combined_log_probs = torch.logsumexp(weighted_log_probs, dim=0) # (B, T, V)
        else:
            # Uniform average in log-prob space over the *active* experts only. The
            # zeroed logits of inactive experts became uniform distributions above,
            # so those experts are weighted with -inf to keep them out of the mixture.
            expert_ids = torch.arange(self.num_experts, device=device)
            active_mask = (date.unsqueeze(1) >= expert_ids.unsqueeze(0)).t()  # (E, B)
            uniform_log_w = -torch.log(active_experts_count.float().clamp(min=1.0))  # (B,)
            log_weights = torch.where(active_mask, uniform_log_w.unsqueeze(0), torch.tensor(float("-inf"), device=device)).view(self.num_experts, b, 1, 1)
            weighted_log_probs = log_probs + log_weights
            combined_log_probs = torch.logsumexp(weighted_log_probs, dim=0)  # (B, T, V)
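        # In both branches the mixture is formed in log space:
        #   log p(x) = logsumexp_e( log w_e + log p_e(x) ) = log( sum_e w_e * p_e(x) ),
        # which stays numerically stable even when individual p_e(x) underflow.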
# Calculate the loss if targets are provided
if targets is not None:
loss = F.nll_loss(combined_log_probs.view(-1, combined_log_probs.size(-1)), targets.view(-1), ignore_index=-1)
loss_to_log = loss.item()
            # Add auxiliary router losses (only if routing is used and we are training).
            # Routing here is per sequence, not per token, so these are (B, E) / (B, top_k).
            if self.use_router and self.training:
                flat_router_logits = router_logits.view(-1, router_logits.size(-1))  # (B, E)
                flat_selected_experts = topk_indices.view(-1, topk_indices.size(-1))  # (B, top_k)
# Compute each auxiliary loss
entropy = entropy_reg(flat_router_logits)
lb_loss = load_balancing_loss(flat_router_logits, flat_selected_experts)
zloss = router_z_loss(flat_router_logits)
                # Combine the task loss with fixed auxiliary-loss weights
                loss = (
                    loss
                    + 0.01 * entropy
                    + 0.01 * lb_loss
                    + 0.0001 * zloss
                )
else:
loss = None
loss_to_log = None
return Output(
logits=expert_outputs,
loss=loss,
combined_log_probs=combined_log_probs,
loss_to_log=loss_to_log,
expert_losses=expert_losses,
router_logits=router_logits if self.use_router else None,
selected_experts=topk_indices if self.use_router else None,
)

    @torch.no_grad()
def generate(self, input_ids, max_new_tokens, date=None, temperature=1.0, top_k=None):
"""
Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
the sequence max_new_tokens times, feeding the predictions back into the model each time.
Most likely you'll want to make sure to be in model.eval() mode of operation for this.
"""
idx = input_ids
for _ in range(max_new_tokens):
# if the sequence context is growing too long we must crop it at sequence_length
idx_cond = (
idx
if idx.size(1) <= self.config.sequence_length
else idx[:, -self.config.sequence_length :]
)
            # forward the model and take the combined log-probabilities at the final step
            log_probs = self(idx_cond, date=date).combined_log_probs[:, -1, :]
            # optionally crop to the top-k options (mirrors the usual logits-based sampler)
            if top_k is not None:
                v, _ = torch.topk(log_probs, min(top_k, log_probs.size(-1)))
                log_probs[log_probs < v[:, [-1]]] = -float("Inf")
            # Sample the next token from the log-probabilities
            if temperature == 0:
                # Greedy decoding: take the argmax
                idx_next = torch.argmax(log_probs, dim=-1, keepdim=True)
            else:
                # Apply temperature scaling and exponentiate; the resulting weights
                # need not sum to one, since torch.multinomial normalizes internally
                scaled_log_probs = log_probs / temperature
                probs = torch.exp(scaled_log_probs)
                idx_next = torch.multinomial(probs, num_samples=1)
# append sampled index to the running sequence and continue
idx = torch.cat((idx, idx_next), dim=1)
            # stop at the GPT-2 end-of-text token, 50256 (the original 50526 is out of
            # range for the GPT-2 vocabulary); note that .item() assumes batch size 1
            if idx_next.item() == 50256:
                break
return idx

    @torch.no_grad()
def generate_from_string(self, in_str, max_new_tokens, date=None, temperature=1.0, top_k=None):
idx = (
torch.tensor(
self.tokenizer.encode(in_str)
)
.view(1, -1)
.to(next(self.parameters()).device)
)
out_idx = (
self.generate(idx, max_new_tokens, date, temperature, top_k)
.view(-1)
.to("cpu")
.numpy()
)
return self.tokenizer.decode(out_idx)
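
# A minimal usage sketch (hypothetical values; the MoLMConfig fields must match
# whatever checkpoint is actually loaded):
#
#   config = MoLMConfig(num_experts=7, expert_configs=[...], use_router=True)
#   model = MoLM(config).eval()
#   out = model(input_ids, targets=targets, date=2019)  # out.loss, out.combined_log_probs
#   text = model.generate_from_string("The history of", max_new_tokens=32, date=2019)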