File size: 4,320 Bytes
2759894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch
import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.bert.modeling_bert import BertEncoder, BertLayer

# --- 1. Monarch Low-Level Operations ---
class MonarchUp(nn.Module):
    """Block-structured (Monarch-style) up-projection ``d_model -> hidden_dim``.

    The input features are split into ``n_blocks`` chunks of size
    ``in_block = d_model // n_blocks``. A per-block matmul (``b1``) is applied,
    the block and within-block axes are swapped, and a second per-group matmul
    (``b2``) maps each of the ``in_block`` groups of size ``n_blocks`` to
    ``out_block_exp = hidden_dim // in_block`` outputs.

    Args:
        d_model: Input feature size; must be divisible by ``n_blocks``.
        hidden_dim: Output feature size; must be divisible by
            ``d_model // n_blocks``.
        n_blocks: Number of blocks in the first factor; must be positive.

    Raises:
        ValueError: If the sizes cannot be factored as described above.
    """

    def __init__(self, d_model=768, hidden_dim=3072, n_blocks=16):
        super().__init__()
        # Validate up front: otherwise bad sizes only surface later as an
        # obscure reshape / bias-broadcast error inside forward().
        if n_blocks <= 0:
            raise ValueError(f"n_blocks must be positive, got {n_blocks}")
        if d_model % n_blocks:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_blocks ({n_blocks})"
            )
        self.n_blocks = n_blocks
        self.in_block = d_model // n_blocks
        if hidden_dim % self.in_block:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by "
                f"d_model // n_blocks ({self.in_block})"
            )
        self.out_block_exp = hidden_dim // self.in_block

        # b1: one (out=in_block, in=in_block) matrix per block.
        self.b1 = nn.Parameter(torch.randn(n_blocks, self.in_block, self.in_block) * 0.02)
        # b2: one (out=out_block_exp, in=n_blocks) matrix per post-transpose group.
        self.b2 = nn.Parameter(torch.randn(self.in_block, self.out_block_exp, n_blocks) * 0.02)
        self.bias = nn.Parameter(torch.zeros(hidden_dim))

    def forward(self, x):
        """Map ``x`` of shape (B, S, d_model) to (B, S, hidden_dim)."""
        batch, seq, _ = x.shape
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x.reshape(batch, seq, self.n_blocks, self.in_block)
        x = torch.einsum('bsni,noi->bsno', x, self.b1)
        # Swap block / within-block axes: the permutation between the two factors.
        x = x.transpose(-1, -2)
        x = torch.einsum('bsni,noi->bsno', x, self.b2)
        return x.reshape(batch, seq, -1) + self.bias

class MonarchDown(nn.Module):
    """Block-structured (Monarch-style) down-projection ``hidden_dim -> d_model``.

    With ``out_block = d_model // n_blocks``, the input features are split into
    ``out_block`` chunks of size ``in_block_exp = hidden_dim // out_block``.
    A per-chunk matmul (``b1``) maps each chunk to ``n_blocks`` values, the two
    block axes are swapped, and a per-block matmul (``b2``) maps each of the
    ``n_blocks`` groups of size ``out_block`` to ``out_block`` outputs.

    Args:
        d_model: Output feature size; must be divisible by ``n_blocks``.
        hidden_dim: Input feature size; must be divisible by
            ``d_model // n_blocks``.
        n_blocks: Number of blocks in the second factor; must be positive.

    Raises:
        ValueError: If the sizes cannot be factored as described above.
    """

    def __init__(self, d_model=768, hidden_dim=3072, n_blocks=16):
        super().__init__()
        # Validate up front: otherwise bad sizes only surface later as an
        # obscure reshape / bias-broadcast error inside forward().
        if n_blocks <= 0:
            raise ValueError(f"n_blocks must be positive, got {n_blocks}")
        if d_model % n_blocks:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_blocks ({n_blocks})"
            )
        self.n_blocks = n_blocks
        self.out_block = d_model // n_blocks
        if hidden_dim % self.out_block:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by "
                f"d_model // n_blocks ({self.out_block})"
            )
        self.in_block_exp = hidden_dim // self.out_block

        # b1: one (out=n_blocks, in=in_block_exp) matrix per input chunk.
        self.b1 = nn.Parameter(torch.randn(self.out_block, n_blocks, self.in_block_exp) * 0.02)
        # b2: one (out=out_block, in=out_block) matrix per post-transpose block.
        self.b2 = nn.Parameter(torch.randn(n_blocks, self.out_block, self.out_block) * 0.02)
        self.bias = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        """Map ``x`` of shape (B, S, hidden_dim) to (B, S, d_model)."""
        batch, seq, _ = x.shape
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x.reshape(batch, seq, self.out_block, self.in_block_exp)
        x = torch.einsum('bsni,noi->bsno', x, self.b1)
        # Swap the two block axes: the permutation between the two factors.
        x = x.transpose(-1, -2)
        x = torch.einsum('bsni,noi->bsno', x, self.b2)
        return x.reshape(batch, seq, -1) + self.bias

class MonarchFFN(nn.Module):
    """Feed-forward block: Monarch up-projection, activation, Monarch down-projection.

    Args:
        d_model: Width of the block's input and output features.
        hidden_dim: Intermediate width between the two projections.
        groups: Block count, passed as ``n_blocks`` to both projections.
        act_fn: Module/callable applied between the two projections.
    """

    def __init__(self, d_model, hidden_dim, groups, act_fn):
        super().__init__()
        # Keep the up -> act -> down registration order: both projections draw
        # random initial weights, so this order fixes the RNG consumption.
        self.monarch_up = MonarchUp(d_model, hidden_dim, n_blocks=groups)
        self.act = act_fn
        self.monarch_down = MonarchDown(d_model, hidden_dim, n_blocks=groups)

    def forward(self, x):
        """Return ``monarch_down(act(monarch_up(x)))``."""
        return self.monarch_down(self.act(self.monarch_up(x)))

class FFNWrapper(nn.Module):
    """Thin adapter that forwards ``hidden_states`` to a wrapped FFN module.

    In this file it is installed as a BERT layer's ``intermediate`` submodule.
    """

    def __init__(self, new_ffn_module):
        super().__init__()
        # The attribute name "ffn" is part of every wrapped parameter's
        # state_dict key, so it must stay stable for checkpoint loading.
        self.ffn = new_ffn_module

    def forward(self, hidden_states):
        """Return the wrapped module's output for ``hidden_states``."""
        wrapped = self.ffn
        return wrapped(hidden_states)

# --- 2. The Model Architecture ---
class MonarchBertForSequenceClassification(BertPreTrainedModel):
    """BERT sequence classifier whose FFN sub-layers are replaced by Monarch FFNs.

    Encoder layers ``monarch_start_layer`` through the last layer get their
    ``intermediate`` module replaced by ``FFNWrapper(MonarchFFN(...))``. Their
    ``output.dense`` is replaced by ``nn.Identity``, because the Monarch FFN
    already maps ``hidden_size -> intermediate_size -> hidden_size``.

    Config attributes read:
        num_labels, hidden_size, intermediate_size: standard BERT config fields.
        monarch_start_layer (default 0): first encoder layer to convert.
        monarch_groups (default 16): ``groups`` passed to each ``MonarchFFN``.
    """

    def __init__(self, config):
        super().__init__(config)  # also stores `config` on self.config
        self.num_labels = config.num_labels

        # 1. Standard BERT backbone + linear classification head.
        self.bert = BertModel(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # 2. Inject Monarch layers based on config. This reconstructs the
        #    architecture exactly as it was during distillation.
        monarch_start_layer = getattr(config, "monarch_start_layer", 0)
        n_groups = getattr(config, "monarch_groups", 16)
        d_model = config.hidden_size
        h_dim = config.intermediate_size

        bert_layers = self.bert.encoder.layer
        # Replace backwards from the last layer down to monarch_start_layer.
        # The last index comes from the actual encoder depth rather than a
        # hard-coded 11, so models with other than 12 layers work correctly.
        last_layer = len(bert_layers) - 1
        for i in range(last_layer, monarch_start_layer - 1, -1):
            monarch_ffn = MonarchFFN(d_model, h_dim, n_groups, nn.GELU())

            # Surgery: the Monarch FFN performs both projections, so the
            # layer's own output projection becomes a pass-through.
            bert_layers[i].intermediate = FFNWrapper(monarch_ffn)
            bert_layers[i].output.dense = nn.Identity()

        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None):
        """Classify a batch of sequences.

        Args:
            input_ids, attention_mask, token_type_ids: Passed through to
                ``self.bert``.
            labels: Optional class indices. When given, a cross-entropy loss
                over ``num_labels`` classes is computed.

        Returns:
            ``SequenceClassifierOutput`` with ``loss`` (``None`` when no labels
            are given) and ``logits``.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        pooled_output = outputs[1]  # BertModel output index 1: pooler output
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        # A real ModelOutput keeps `.loss` / `.logits` attribute access and also
        # supports `outputs["loss"]` / `outputs[0]`, as used by generic HF
        # tooling such as Trainer. The previous ad-hoc `type(...)` object
        # (rebuilt on every call) supported neither.
        return SequenceClassifierOutput(loss=loss, logits=logits)