import torch
import torch.nn as nn

from transformers import BertModel, BertPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput


class MonarchUp(nn.Module):
    """Monarch-factorized up-projection: d_model -> hidden_dim via two block-diagonal matmuls."""

    def __init__(self, d_model=768, hidden_dim=3072, n_blocks=16):
        super().__init__()
        self.n_blocks = n_blocks
        self.in_block = d_model // n_blocks               # block size along the input dimension
        self.out_block_exp = hidden_dim // self.in_block  # number of output blocks after expansion

        # Two block-diagonal factors replace the dense d_model x hidden_dim weight.
        self.b1 = nn.Parameter(torch.randn(n_blocks, self.in_block, self.in_block) * 0.02)
        self.b2 = nn.Parameter(torch.randn(self.in_block, self.out_block_exp, n_blocks) * 0.02)
        self.bias = nn.Parameter(torch.zeros(hidden_dim))

    def forward(self, x):
        B, S, D = x.shape
        x = x.view(B, S, self.n_blocks, self.in_block)  # (B, S, n_blocks, in_block)
        x = torch.einsum('bsni,noi->bsno', x, self.b1)  # block-diagonal matmul with b1
        x = x.transpose(-1, -2)                         # permute blocks between the two factors
        x = torch.einsum('bsni,noi->bsno', x, self.b2)  # block-diagonal matmul with b2
        return x.reshape(B, S, -1) + self.bias          # (B, S, hidden_dim)


class MonarchDown(nn.Module):
    """Monarch-factorized down-projection: hidden_dim -> d_model via two block-diagonal matmuls."""

    def __init__(self, d_model=768, hidden_dim=3072, n_blocks=16):
        super().__init__()
        self.n_blocks = n_blocks
        self.out_block = d_model // n_blocks              # block size along the output dimension
        self.in_block_exp = hidden_dim // self.out_block  # number of input blocks before contraction

        self.b1 = nn.Parameter(torch.randn(self.out_block, n_blocks, self.in_block_exp) * 0.02)
        self.b2 = nn.Parameter(torch.randn(n_blocks, self.out_block, self.out_block) * 0.02)
        self.bias = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        B, S, D = x.shape
        x = x.view(B, S, self.out_block, self.in_block_exp)  # (B, S, out_block, in_block_exp)
        x = torch.einsum('bsni,noi->bsno', x, self.b1)       # contract the expanded blocks with b1
        x = x.transpose(-1, -2)                              # permute blocks between the two factors
        x = torch.einsum('bsni,noi->bsno', x, self.b2)       # block-diagonal matmul with b2
        return x.reshape(B, S, -1) + self.bias               # (B, S, d_model)


class MonarchFFN(nn.Module):
    """Drop-in replacement for BERT's feed-forward block: up-project, activate, down-project."""

    def __init__(self, d_model, hidden_dim, groups, act_fn):
        super().__init__()
        self.monarch_up = MonarchUp(d_model, hidden_dim, n_blocks=groups)
        self.act = act_fn
        self.monarch_down = MonarchDown(d_model, hidden_dim, n_blocks=groups)

    def forward(self, x):
        x = self.monarch_up(x)
        x = self.act(x)
        x = self.monarch_down(x)
        return x


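# A minimal sanity-check sketch (assuming the default BERT-base sizes used above:
# d_model=768, hidden_dim=3072, n_blocks=16). Each Monarch projection then stores
# 16*48*48 + 48*64*16 = 86,016 weights, versus 768*3072 = 2,359,296 for the dense
# nn.Linear it replaces (~27x fewer weights per projection). Quick shape check:
#
#   ffn = MonarchFFN(768, 3072, groups=16, act_fn=nn.GELU())
#   x = torch.randn(2, 128, 768)
#   assert ffn(x).shape == (2, 128, 768)

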
class FFNWrapper(nn.Module):
    """Adapter so a full FFN module can stand in for BertIntermediate, which takes only hidden_states."""

    def __init__(self, new_ffn_module):
        super().__init__()
        self.ffn = new_ffn_module

    def forward(self, hidden_states):
        return self.ffn(hidden_states)


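# Why the wrapper and the nn.Identity swap below fit together: inside a Hugging Face
# BertLayer, the feed-forward path is roughly
#
#   intermediate_output = self.intermediate(attention_output)          # normally 768 -> 3072
#   layer_output = self.output(intermediate_output, attention_output)  # dense 3072 -> 768, dropout,
#                                                                      # LayerNorm(x + attention_output)
#
# Replacing `intermediate` with FFNWrapper(MonarchFFN) makes that first call return a
# d_model-sized tensor directly, so `output.dense` must become an Identity; the residual
# add, dropout, and LayerNorm in BertOutput are kept as-is.

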
class MonarchBertForSequenceClassification(BertPreTrainedModel):
    """BERT for sequence classification with the FFN of selected layers replaced by MonarchFFN."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.bert = BertModel(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Swap in Monarch FFNs from `monarch_start_layer` up to the last encoder layer.
        monarch_start_layer = getattr(config, "monarch_start_layer", 0)
        n_groups = getattr(config, "monarch_groups", 16)

        bert_layers = self.bert.encoder.layer
        d_model = config.hidden_size
        h_dim = config.intermediate_size

        for i in range(config.num_hidden_layers - 1, monarch_start_layer - 1, -1):
            monarch_ffn = MonarchFFN(d_model, h_dim, n_groups, nn.GELU())

            # MonarchFFN already maps d_model -> d_model, so the intermediate slot receives the
            # whole FFN and the dense 3072 -> 768 projection in BertOutput is bypassed.
            bert_layers[i].intermediate = FFNWrapper(monarch_ffn)
            bert_layers[i].output.dense = nn.Identity()

        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        pooled_output = outputs[1]
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(loss=loss, logits=logits)
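

# A minimal usage sketch (assumptions: `transformers` is installed, the model is built
# from a fresh BertConfig with random weights, and the extra attributes
# `monarch_start_layer` / `monarch_groups` are simply carried on the config as shown;
# the getattr defaults above also cover the case where they are absent).
if __name__ == "__main__":
    from transformers import BertConfig

    config = BertConfig(num_labels=2, monarch_start_layer=6, monarch_groups=16)
    model = MonarchBertForSequenceClassification(config)  # randomly initialized, no pretrained weights

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    labels = torch.tensor([0, 1])

    out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    print(out.logits.shape, out.loss.item())  # torch.Size([2, 2]) and a scalar loss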