# monarch-bert-base-mnli / modeling_monarch_bert.py
# Uploaded by ykae in commit 2759894 ("Upload 7 files", verified).
import torch
import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.bert.modeling_bert import BertEncoder, BertLayer
# --- 1. Monarch Low-Level Operations ---
class MonarchUp(nn.Module):
    """Monarch-factored up-projection from ``d_model`` to ``hidden_dim`` features.

    The dense projection is replaced by two block-diagonal factors with a
    transpose between them:

    1. split the input into ``n_blocks`` chunks of ``in_block = d_model // n_blocks``
       features and mix each chunk with its own ``in_block x in_block`` matrix (``b1``);
    2. transpose the ``(n_blocks, in_block)`` grid;
    3. in each of the ``in_block`` resulting groups, map ``n_blocks`` inputs to
       ``out_block_exp = hidden_dim // in_block`` outputs (``b2``);
    4. flatten and add ``bias``.

    Parameter names, shapes and initialization order are kept fixed because
    saved checkpoints are keyed on them.

    Args:
        d_model: input feature size; must be divisible by ``n_blocks``.
        hidden_dim: output feature size; must be divisible by ``d_model // n_blocks``.
        n_blocks: number of diagonal blocks in the first factor.

    Raises:
        ValueError: if the sizes do not factor as described above.
    """

    def __init__(self, d_model=768, hidden_dim=3072, n_blocks=16):
        super().__init__()
        # Fail at construction time instead of with an opaque shape error in forward().
        if n_blocks <= 0 or d_model <= 0 or d_model % n_blocks != 0:
            raise ValueError(
                f"d_model ({d_model}) must be a positive multiple of n_blocks ({n_blocks})"
            )
        in_block = d_model // n_blocks
        if hidden_dim % in_block != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by d_model // n_blocks ({in_block})"
            )
        self.n_blocks = n_blocks
        self.in_block = in_block
        self.out_block_exp = hidden_dim // in_block
        # b1 is drawn before b2, as before, so seeded initialization is unchanged.
        self.b1 = nn.Parameter(torch.randn(n_blocks, self.in_block, self.in_block) * 0.02)
        self.b2 = nn.Parameter(torch.randn(self.in_block, self.out_block_exp, n_blocks) * 0.02)
        self.bias = nn.Parameter(torch.zeros(hidden_dim))

    def forward(self, x):
        """Project ``x`` of shape ``(..., d_model)`` to ``(..., hidden_dim)``."""
        *lead, _ = x.shape
        # (..., d_model) -> (..., n_blocks, in_block)
        x = x.reshape(*lead, self.n_blocks, self.in_block)
        # Factor 1: a separate in_block x in_block matrix per chunk.
        x = torch.einsum("...ni,noi->...no", x, self.b1)
        # (..., n_blocks, in_block) -> (..., in_block, n_blocks): regroup across chunks.
        x = x.transpose(-1, -2)  # Implicit fusion friendly for Triton later
        # Factor 2: per group, n_blocks inputs -> out_block_exp outputs.
        x = torch.einsum("...ni,noi->...no", x, self.b2)
        # Explicit output width (not -1) so zero-sized leading dims still reshape.
        return x.reshape(*lead, self.in_block * self.out_block_exp) + self.bias
class MonarchDown(nn.Module):
    """Monarch-factored down-projection from ``hidden_dim`` back to ``d_model`` features.

    With ``out_block = d_model // n_blocks`` and ``in_block_exp = hidden_dim // out_block``:

    1. split the input into ``out_block`` groups of ``in_block_exp`` features and map
       each group to ``n_blocks`` outputs with its own matrix (``b1``);
    2. transpose the ``(out_block, n_blocks)`` grid;
    3. mix each of the ``n_blocks`` chunks with its own ``out_block x out_block``
       matrix (``b2``);
    4. flatten and add ``bias``.

    Parameter names, shapes and initialization order are kept fixed because
    saved checkpoints are keyed on them.

    Args:
        d_model: output feature size; must be divisible by ``n_blocks``.
        hidden_dim: input feature size; must be divisible by ``d_model // n_blocks``.
        n_blocks: number of diagonal blocks in the second factor.

    Raises:
        ValueError: if the sizes do not factor as described above.
    """

    def __init__(self, d_model=768, hidden_dim=3072, n_blocks=16):
        super().__init__()
        # Fail at construction time instead of with an opaque shape error in forward().
        if n_blocks <= 0 or d_model <= 0 or d_model % n_blocks != 0:
            raise ValueError(
                f"d_model ({d_model}) must be a positive multiple of n_blocks ({n_blocks})"
            )
        out_block = d_model // n_blocks
        if hidden_dim % out_block != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by d_model // n_blocks ({out_block})"
            )
        self.n_blocks = n_blocks
        self.out_block = out_block
        self.in_block_exp = hidden_dim // out_block
        # b1 is drawn before b2, as before, so seeded initialization is unchanged.
        self.b1 = nn.Parameter(torch.randn(self.out_block, n_blocks, self.in_block_exp) * 0.02)
        self.b2 = nn.Parameter(torch.randn(n_blocks, self.out_block, self.out_block) * 0.02)
        self.bias = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        """Project ``x`` of shape ``(..., hidden_dim)`` to ``(..., d_model)``."""
        *lead, _ = x.shape
        # (..., hidden_dim) -> (..., out_block, in_block_exp)
        x = x.reshape(*lead, self.out_block, self.in_block_exp)
        # Factor 1: per group, in_block_exp inputs -> n_blocks outputs.
        x = torch.einsum("...ni,noi->...no", x, self.b1)
        # (..., out_block, n_blocks) -> (..., n_blocks, out_block)
        x = x.transpose(-1, -2)
        # Factor 2: a separate out_block x out_block matrix per chunk.
        x = torch.einsum("...ni,noi->...no", x, self.b2)
        # Explicit output width (not -1) so zero-sized leading dims still reshape.
        return x.reshape(*lead, self.n_blocks * self.out_block) + self.bias
class MonarchFFN(nn.Module):
    """Feed-forward block built from Monarch projections: up, activation, down.

    Args:
        d_model: width of the block's input and output.
        hidden_dim: width of the intermediate representation.
        groups: block count handed to both projections as ``n_blocks``.
        act_fn: callable applied between the up- and down-projection.
    """

    def __init__(self, d_model, hidden_dim, groups, act_fn):
        super().__init__()
        # Attribute names (and their registration order) shape the state-dict keys.
        self.monarch_up = MonarchUp(d_model=d_model, hidden_dim=hidden_dim, n_blocks=groups)
        self.act = act_fn
        self.monarch_down = MonarchDown(d_model=d_model, hidden_dim=hidden_dim, n_blocks=groups)

    def forward(self, x):
        """Return ``monarch_down(act(monarch_up(x)))``."""
        expanded = self.monarch_up(x)
        activated = self.act(expanded)
        return self.monarch_down(activated)
class FFNWrapper(nn.Module):
    """Thin adapter so a custom FFN module can be installed as ``BertLayer.intermediate``.

    The wrapped module lives under the attribute ``ffn``, so its parameters
    appear in the state dict as ``ffn.<name>``.
    """

    def __init__(self, new_ffn_module):
        super().__init__()
        self.ffn = new_ffn_module  # attribute name is part of the checkpoint key layout

    def forward(self, hidden_states):
        """Delegate straight to the wrapped module."""
        inner = self.ffn
        return inner(hidden_states)
# --- 2. The Model Architecture ---
class MonarchBertForSequenceClassification(BertPreTrainedModel):
    """BERT sequence classifier whose feed-forward sublayers are Monarch-factored.

    Starting from a standard ``BertModel``, each encoder layer from
    ``config.monarch_start_layer`` through the last one is modified:

    * its ``intermediate`` module is replaced by ``FFNWrapper(MonarchFFN(...))``;
    * its ``output.dense`` projection is replaced by ``nn.Identity``.

    This reconstructs the module layout used during distillation, so the
    checkpoint's state-dict keys line up.

    Extra config attributes read, beyond the standard BERT ones:
        monarch_start_layer: first encoder layer index to convert (default 0).
        monarch_groups: block count passed to ``MonarchFFN`` (default 16).
    """

    def __init__(self, config):
        super().__init__(config)  # also sets self.config
        self.num_labels = config.num_labels

        # 1. Standard BERT backbone plus a linear classification head.
        self.bert = BertModel(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # 2. Inject Monarch layers based on the config.
        monarch_start_layer = getattr(config, "monarch_start_layer", 0)
        n_groups = getattr(config, "monarch_groups", 16)
        bert_layers = self.bert.encoder.layer
        # Walk from the last layer backwards, as the distillation code did, so that
        # seeded random init is reproducible. The layer count comes from the model
        # instead of assuming 12 layers, and a negative start is treated as 0.
        for i in reversed(range(max(monarch_start_layer, 0), len(bert_layers))):
            monarch_ffn = MonarchFFN(
                config.hidden_size, config.intermediate_size, n_groups, nn.GELU()
            )
            # Surgery: MonarchFFN already maps back to hidden_size, so the layer's
            # own output projection becomes a pass-through.
            bert_layers[i].intermediate = FFNWrapper(monarch_ffn)
            bert_layers[i].output.dense = nn.Identity()

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        position_ids=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        """Run the encoder and the classification head.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids,
            inputs_embeds, output_attentions, output_hidden_states:
                forwarded unchanged to ``self.bert``.
            labels: optional targets. The loss follows the ``BertForSequenceClassification``
                conventions: ``config.problem_type`` if set; otherwise MSE when
                ``num_labels == 1``, cross-entropy for integer labels, and
                BCE-with-logits for floating-point (multi-label) targets.

        Returns:
            ``SequenceClassifierOutput`` with these fields:

            * ``loss`` — ``None`` when no labels are given;
            * ``logits`` — the linear head applied to the backbone's pooled output;
            * ``hidden_states`` / ``attentions`` — the backbone's values, when available.

            It supports attribute, key and index access, as HF ``Trainer`` expects.
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        pooled_output = outputs[1]
        logits = self.classifier(pooled_output)

        loss = None if labels is None else self._compute_loss(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=getattr(outputs, "hidden_states", None),
            attentions=getattr(outputs, "attentions", None),
        )

    def _compute_loss(self, logits, labels):
        """Pick and apply the loss the way ``BertForSequenceClassification`` does."""
        problem_type = getattr(self.config, "problem_type", None)
        if problem_type is None:
            if self.num_labels == 1:
                problem_type = "regression"
            elif labels.dtype in (torch.long, torch.int):
                problem_type = "single_label_classification"
            else:
                problem_type = "multi_label_classification"

        if problem_type == "regression":
            # Cross-entropy over a single logit is identically zero, so regress instead.
            target = labels.to(logits.dtype)
            if self.num_labels == 1:
                return nn.MSELoss()(logits.squeeze(), target.squeeze())
            return nn.MSELoss()(logits, target)
        if problem_type == "single_label_classification":
            return nn.CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
        return nn.BCEWithLogitsLoss()(logits, labels.to(logits.dtype))