|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from transformers import AutoModel |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FakeBERT(nn.Module):
    """Transformer encoder + multi-kernel CNN head for sequence classification.

    Pipeline: a pretrained transformer (loaded via ``AutoModel``) produces token
    embeddings; three parallel ``Conv1d`` branches (kernel sizes 3/4/5) extract
    n-gram features; two fusion conv layers combine them; a global max-pool and
    a small MLP produce the class logits.

    Sequences longer than the transformer's position limit are handled by
    splitting them into chunks, encoding each chunk independently, and
    re-assembling the hidden states along the sequence dimension.
    """

    def __init__(self, model_name="bert-base-uncased", num_classes=3, dropout=0.2):
        """
        Args:
            model_name: HuggingFace model id passed to ``AutoModel.from_pretrained``.
            num_classes: size of the output logit vector.
            dropout: dropout probability applied before the final linear layer.
        """
        super().__init__()

        self.bert = AutoModel.from_pretrained(model_name)
        hidden = self.bert.config.hidden_size
        out_channels = 128

        # Three parallel n-gram feature extractors over the token axis.
        # padding='same' keeps the sequence length unchanged (stride is 1).
        self.conv1 = nn.Conv1d(hidden, out_channels, kernel_size=3, padding='same')
        self.conv2 = nn.Conv1d(hidden, out_channels, kernel_size=4, padding='same')
        self.conv3 = nn.Conv1d(hidden, out_channels, kernel_size=5, padding='same')

        # Fuse the concatenated branch outputs back down to out_channels.
        self.conv_post1 = nn.Conv1d(out_channels * 3, out_channels, kernel_size=3, padding=1)
        self.conv_post2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)

        # Output size of the adaptive max-pool: 1 == global max over the sequence.
        self.final_pool_size = 1

        # Classification head.
        self.fc1 = nn.Linear(out_channels, 128)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(128, num_classes)
        self.relu = nn.ReLU()

        # Some architectures (e.g. DistilBERT) have no token-type embeddings and
        # reject token_type_ids; detect that once from the config.
        self._accepts_token_type_ids = getattr(self.bert.config, "type_vocab_size", None) is not None

    def _forward_transformer(self, input_ids, attention_mask=None, token_type_ids=None):
        """Encode ``input_ids``, chunking sequences longer than the position limit.

        Each chunk is encoded independently (no cross-chunk attention) and the
        per-chunk hidden states are concatenated back along the sequence axis.

        Args:
            input_ids: (B, L) token ids.
            attention_mask: optional (B, L) mask; 1 for real tokens, 0 for padding.
            token_type_ids: optional (B, L) segment ids; only forwarded when the
                underlying model supports them.

        Returns:
            last_hidden_state of shape (B, L, hidden) — chunked outputs are
            trimmed back to the original sequence length.
        """
        B, L = input_ids.size()
        max_len = getattr(self.bert.config, "max_position_embeddings", 512)

        def build_kwargs(ii, am=None, tt=None):
            # Assemble transformer kwargs, omitting unsupported/absent inputs.
            kwargs = {"input_ids": ii}
            if am is not None:
                kwargs["attention_mask"] = am
            if tt is not None and self._accepts_token_type_ids:
                kwargs["token_type_ids"] = tt
            return kwargs

        # Fast path: the whole sequence fits within the position embeddings.
        if L <= max_len:
            kwargs = build_kwargs(input_ids, attention_mask, token_type_ids)
            return self.bert(**kwargs).last_hidden_state

        # --- Long-sequence path: pad to a multiple of max_len, then chunk. ---
        # Only the tail needs padding; pad once up front instead of per chunk.
        pad_len = (-L) % max_len
        n_chunks = (L + pad_len) // max_len

        if pad_len > 0:
            input_ids = F.pad(input_ids, (0, pad_len), value=0)
            if attention_mask is None:
                # Without a caller-supplied mask the model would attend to the
                # synthetic pad tokens; build one so padding is masked out.
                attention_mask = torch.ones(B, L, dtype=torch.long, device=input_ids.device)
            attention_mask = F.pad(attention_mask, (0, pad_len), value=0)
            if token_type_ids is not None:
                token_type_ids = F.pad(token_type_ids, (0, pad_len), value=0)

        # Fold chunks into the batch dimension: (B, n*max_len) -> (B*n, max_len).
        input_chunks = input_ids.reshape(B * n_chunks, max_len)
        attention_chunks = (
            attention_mask.reshape(B * n_chunks, max_len) if attention_mask is not None else None
        )
        token_chunks = (
            token_type_ids.reshape(B * n_chunks, max_len) if token_type_ids is not None else None
        )

        kwargs = build_kwargs(input_chunks, attention_chunks, token_chunks)
        x_all = self.bert(**kwargs).last_hidden_state  # (B*n, max_len, hidden)

        # Unfold back to (B, n*max_len, hidden) and drop the padded tail so the
        # pad positions never reach the conv/max-pool head.
        x = x_all.reshape(B, n_chunks * max_len, -1)[:, :L, :]
        return x

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Compute class logits.

        Args:
            input_ids: (B, L) token ids.
            attention_mask: optional (B, L) attention mask.
            token_type_ids: optional (B, L) segment ids.

        Returns:
            (B, num_classes) unnormalized logits.
        """
        x = self._forward_transformer(input_ids, attention_mask, token_type_ids)

        # Conv1d expects channels-first: (B, L, hidden) -> (B, hidden, L).
        x = x.transpose(1, 2)

        # Parallel n-gram branches, concatenated along the channel axis.
        c1 = self.relu(self.conv1(x))
        c2 = self.relu(self.conv2(x))
        c3 = self.relu(self.conv3(x))
        x = torch.cat([c1, c2, c3], dim=1)

        # Fusion convs.
        x = self.relu(self.conv_post1(x))
        x = self.relu(self.conv_post2(x))

        # Global max over the sequence -> (B, out_channels).
        x = F.adaptive_max_pool1d(x, self.final_pool_size)
        x = x.squeeze(-1)

        # MLP head.
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        logits = self.fc2(x)

        return logits
|
|
|
|
|
|