File size: 6,296 Bytes
7a028db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel


# -------------------------------
# 1. Model Definition
# -------------------------------
class FakeBERT(nn.Module):
    """BERT backbone + multi-scale 1D-CNN head for text classification.

    Token hidden states from a pretrained transformer are fed through three
    parallel Conv1d branches (kernel sizes 3/4/5), concatenated channel-wise,
    refined by two further convs, globally max-pooled to a fixed-size vector,
    and classified by a small MLP. Sequences longer than the backbone's
    positional limit are processed in equal-length chunks whose hidden states
    are re-concatenated along the token dimension.
    """

    def __init__(self, model_name="bert-base-uncased", num_classes=3, dropout=0.2):
        """
        Args:
            model_name: HuggingFace model id for the backbone.
            num_classes: number of output classes (logit dimension).
            dropout: dropout probability applied before the final linear layer.
        """
        super().__init__()

        # Base transformer model (AutoModel is future-proof)
        self.bert = AutoModel.from_pretrained(model_name)
        hidden = self.bert.config.hidden_size
        out_channels = 128

        # Parallel 1D convs across the token dimension (in_channels = hidden).
        # padding='same' keeps all three branch outputs at the input length,
        # so they can be concatenated along channels.
        self.conv1 = nn.Conv1d(hidden, out_channels, kernel_size=3, padding='same')
        self.conv2 = nn.Conv1d(hidden, out_channels, kernel_size=4, padding='same')
        self.conv3 = nn.Conv1d(hidden, out_channels, kernel_size=5, padding='same')

        # Post-concatenation conv layers operate on the concatenated channels.
        self.conv_post1 = nn.Conv1d(out_channels * 3, out_channels, kernel_size=3, padding=1)
        self.conv_post2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)

        # Final adaptive pooling to length 1 -> deterministic flattened size
        # of exactly `out_channels`, independent of input sequence length.
        self.final_pool_size = 1

        # Fully connected head (in_features = out_channels after global pool).
        self.fc1 = nn.Linear(out_channels, 128)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(128, num_classes)
        self.relu = nn.ReLU()

        # Whether the backbone accepts token_type_ids (BERT does, DistilBERT
        # does not). Inferred from the config; fallback: assume not present.
        self._accepts_token_type_ids = getattr(self.bert.config, "type_vocab_size", None) is not None

    def _forward_transformer(self, input_ids, attention_mask=None, token_type_ids=None):
        """Run the backbone, chunking sequences longer than its position limit.

        Args:
            input_ids: (B, L) token ids.
            attention_mask: optional (B, L) mask; 1 = real token, 0 = padding.
            token_type_ids: optional (B, L) segment ids.

        Returns:
            last_hidden_state of shape (B, L', hidden), where L' == L on the
            fast path and L rounded up to a multiple of the backbone's max
            length on the chunked path (the extra tail positions are masked).
        """
        B, L = input_ids.size()
        max_len = getattr(self.bert.config, "max_position_embeddings", 512)

        # Build backbone kwargs robustly, omitting token_type_ids for models
        # that do not accept them.
        def build_kwargs(ii, am=None, tt=None):
            kwargs = {"input_ids": ii}
            if am is not None:
                kwargs["attention_mask"] = am
            if tt is not None and self._accepts_token_type_ids:
                kwargs["token_type_ids"] = tt
            return kwargs

        # --- Fast path: fits in a single forward pass ---
        if L <= max_len:
            kwargs = build_kwargs(input_ids, attention_mask, token_type_ids)
            return self.bert(**kwargs).last_hidden_state  # (B, L, hidden)

        # --- Long input: pad to a multiple of max_len, fold chunks into the
        # batch dimension, run one forward pass, then unfold. ---

        # BUGFIX: always materialize an attention mask before padding so the
        # zero-padded tail is never attended to. The original code left the
        # padded positions completely unmasked when the caller passed
        # attention_mask=None, letting the model attend to [PAD] tokens.
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        pad_len = (max_len - L % max_len) % max_len
        if pad_len:
            # value=0 matches the original zero-padding (BERT's pad token id).
            input_ids = F.pad(input_ids, (0, pad_len), value=0)
            attention_mask = F.pad(attention_mask, (0, pad_len), value=0)
            if token_type_ids is not None:
                token_type_ids = F.pad(token_type_ids, (0, pad_len), value=0)

        n_chunks = input_ids.size(1) // max_len

        # (B, n_chunks * max_len) -> (B * n_chunks, max_len); chunks of each
        # batch element stay contiguous, so a plain reshape recombines them.
        def fold(t):
            return t.reshape(B * n_chunks, max_len)

        kwargs = build_kwargs(
            fold(input_ids),
            fold(attention_mask),
            fold(token_type_ids) if token_type_ids is not None else None,
        )
        x_all = self.bert(**kwargs).last_hidden_state  # (B*n_chunks, max_len, hidden)

        # Unfold back to (B, n_chunks * max_len, hidden).
        return x_all.reshape(B, n_chunks * max_len, -1)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Compute class logits.

        Args:
            input_ids: (B, L) token ids.
            attention_mask: optional (B, L) mask.
            token_type_ids: optional (B, L) segment ids.

        Returns:
            Logits of shape (B, num_classes).
        """
        # Transformer forward (handles chunking of over-long inputs).
        x = self._forward_transformer(input_ids, attention_mask, token_type_ids)  # (B, seq_len, hidden)

        # --- Convolutional feature extraction ---
        x = x.transpose(1, 2)  # (B, hidden, seq_len) — Conv1d expects channels first

        # Parallel conv branches + ReLU; 'same' padding keeps lengths equal,
        # so the branch outputs can be concatenated along channels.
        c1 = self.relu(self.conv1(x))
        c2 = self.relu(self.conv2(x))
        c3 = self.relu(self.conv3(x))
        x = torch.cat([c1, c2, c3], dim=1)  # (B, 3*out_channels, seq_len)

        # Post-concatenation refinement convs.
        x = self.relu(self.conv_post1(x))
        x = self.relu(self.conv_post2(x))

        # Global max-pool to a fixed length of 1, then drop the length dim.
        x = F.adaptive_max_pool1d(x, self.final_pool_size)  # (B, out_channels, 1)
        x = x.squeeze(-1)  # (B, out_channels)

        # Fully connected classification head.
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        logits = self.fc2(x)

        return logits