import torch
import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
# --- 1. Monarch Low-Level Operations ---
class MonarchUp(nn.Module):
    """Block-diagonal factorized up-projection: (B, S, d_model) -> (B, S, hidden_dim)."""
    def __init__(self, d_model=768, hidden_dim=3072, n_blocks=16):
        super().__init__()
        self.n_blocks = n_blocks
        self.in_block = d_model // n_blocks               # e.g. 768 // 16 = 48
        self.out_block_exp = hidden_dim // self.in_block  # e.g. 3072 // 48 = 64
        self.b1 = nn.Parameter(torch.randn(n_blocks, self.in_block, self.in_block) * 0.02)
        self.b2 = nn.Parameter(torch.randn(self.in_block, self.out_block_exp, n_blocks) * 0.02)
        self.bias = nn.Parameter(torch.zeros(hidden_dim))

    def forward(self, x):
        B, S, D = x.shape
        x = x.view(B, S, self.n_blocks, self.in_block)  # (B, S, n_blocks, in_block)
        x = torch.einsum('bsni,noi->bsno', x, self.b1)  # per-block mix: (B, S, n_blocks, in_block)
        x = x.transpose(-1, -2)                         # regroup blocks; implicit-fusion friendly for Triton later
        x = torch.einsum('bsni,noi->bsno', x, self.b2)  # expand: (B, S, in_block, out_block_exp)
        return x.reshape(B, S, -1) + self.bias          # flatten to (B, S, hidden_dim)
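
# Back-of-envelope parameter count for the defaults above: a dense 768 x 3072
# up-projection holds 768*3072 (~2.36M) weights, while b1 (16*48*48 = 36,864) and
# b2 (48*64*16 = 49,152) together hold ~86K, roughly 3.6% of the dense layer,
# alongside the same 3,072-element bias.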

class MonarchDown(nn.Module):
    """Block-diagonal factorized down-projection: (B, S, hidden_dim) -> (B, S, d_model)."""
    def __init__(self, d_model=768, hidden_dim=3072, n_blocks=16):
        super().__init__()
        self.n_blocks = n_blocks
        self.out_block = d_model // n_blocks              # e.g. 768 // 16 = 48
        self.in_block_exp = hidden_dim // self.out_block  # e.g. 3072 // 48 = 64
        self.b1 = nn.Parameter(torch.randn(self.out_block, n_blocks, self.in_block_exp) * 0.02)
        self.b2 = nn.Parameter(torch.randn(n_blocks, self.out_block, self.out_block) * 0.02)
        self.bias = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        B, S, D = x.shape
        x = x.view(B, S, self.out_block, self.in_block_exp)  # (B, S, out_block, in_block_exp)
        x = torch.einsum('bsni,noi->bsno', x, self.b1)       # contract: (B, S, out_block, n_blocks)
        x = x.transpose(-1, -2)                              # regroup: (B, S, n_blocks, out_block)
        x = torch.einsum('bsni,noi->bsno', x, self.b2)       # per-block mix: (B, S, n_blocks, out_block)
        return x.reshape(B, S, -1) + self.bias               # flatten to (B, S, d_model)

class MonarchFFN(nn.Module):
    """Drop-in FFN: Monarch up-projection -> activation -> Monarch down-projection."""
    def __init__(self, d_model, hidden_dim, groups, act_fn):
        super().__init__()
        self.monarch_up = MonarchUp(d_model, hidden_dim, n_blocks=groups)
        self.act = act_fn
        self.monarch_down = MonarchDown(d_model, hidden_dim, n_blocks=groups)

    def forward(self, x):
        x = self.monarch_up(x)
        x = self.act(x)
        x = self.monarch_down(x)
        return x
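
# For scale: a standard BERT-base FFN (two dense 768<->3072 projections) carries
# roughly 4.7M weights, while the MonarchUp/MonarchDown pair above carries roughly
# 172K, so swapping it in removes about 96% of the FFN weights.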

class FFNWrapper(nn.Module):
    """Adapter exposing the single-argument forward that the BertIntermediate slot expects."""
    def __init__(self, new_ffn_module):
        super().__init__()
        self.ffn = new_ffn_module

    def forward(self, hidden_states):
        return self.ffn(hidden_states)
# --- 2. The Model Architecture ---
class MonarchBertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        # 1. Load standard BERT
        self.bert = BertModel(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # 2. Inject Monarch layers based on config.
        # This reconstructs the architecture exactly as it was during distillation.
        monarch_start_layer = getattr(config, "monarch_start_layer", 0)
        n_groups = getattr(config, "monarch_groups", 16)
        bert_layers = self.bert.encoder.layer
        # Backward replacement logic (last layer -> start_layer)
        for i in range(config.num_hidden_layers - 1, monarch_start_layer - 1, -1):
            d_model = config.hidden_size
            h_dim = config.intermediate_size
            # Create the Monarch module
            monarch_ffn = MonarchFFN(d_model, h_dim, n_groups, nn.GELU())
            # Surgery: the full Monarch FFN takes over the intermediate slot, and the
            # original down-projection is neutralized; BertOutput's dropout, residual
            # add, and LayerNorm stay intact.
            bert_layers[i].intermediate = FFNWrapper(monarch_ffn)
            bert_layers[i].output.dense = nn.Identity()
        self.init_weights()
    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        pooled_output = outputs[1]
        logits = self.classifier(pooled_output)
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return SequenceClassifierOutput(loss=loss, logits=logits)
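
# --- 3. Example usage (illustrative sketch; the config attribute names mirror the
# getattr lookups above, and the layer split / group count are arbitrary example values) ---
if __name__ == "__main__":
    from transformers import BertConfig
    config = BertConfig(num_labels=2)
    config.monarch_start_layer = 6   # replace the FFN in encoder layers 6..11
    config.monarch_groups = 16
    model = MonarchBertForSequenceClassification(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))   # dummy batch of 2 sequences
    out = model(input_ids=input_ids, labels=torch.tensor([0, 1]))
    print(out.logits.shape, out.loss.item())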