File size: 6,296 Bytes
7a028db |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel
# -------------------------------
# 1. Model Definition
# -------------------------------
class FakeBERT(nn.Module):
    """Transformer + parallel-CNN text classifier (FakeBERT-style).

    A pretrained transformer backbone produces contextual token embeddings,
    which feed three parallel Conv1d branches (kernel sizes 3/4/5). The
    branch outputs are concatenated, refined by two further convs, globally
    max-pooled to a fixed-size vector, and classified by a small MLP head.

    Args:
        model_name: Hugging Face model identifier for the backbone.
        num_classes: size of the output logit vector.
        dropout: dropout probability applied before the final linear layer.
    """

    def __init__(self, model_name="bert-base-uncased", num_classes=3, dropout=0.2):
        super().__init__()
        # Base transformer model (AutoModel is future-proof across architectures).
        self.bert = AutoModel.from_pretrained(model_name)
        hidden = self.bert.config.hidden_size
        out_channels = 128
        # Parallel 1D convs across the token dimension (in_channels = hidden).
        # padding='same' keeps seq_len identical across branches so they concat.
        self.conv1 = nn.Conv1d(hidden, out_channels, kernel_size=3, padding='same')
        self.conv2 = nn.Conv1d(hidden, out_channels, kernel_size=4, padding='same')
        self.conv3 = nn.Conv1d(hidden, out_channels, kernel_size=5, padding='same')
        # Post-concatenation convs operate on the stacked branch channels.
        self.conv_post1 = nn.Conv1d(out_channels * 3, out_channels, kernel_size=3, padding=1)
        self.conv_post2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
        # Adaptive pooling to length 1 -> deterministic flattened size = out_channels.
        self.final_pool_size = 1
        # Classification head (in_features = out_channels after global pooling).
        self.fc1 = nn.Linear(out_channels, 128)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(128, num_classes)
        self.relu = nn.ReLU()
        # Whether the backbone accepts token_type_ids (BERT does, DistilBERT
        # does not). Inferred from the config; falls back to "not accepted".
        self._accepts_token_type_ids = getattr(self.bert.config, "type_vocab_size", None) is not None

    def _build_kwargs(self, input_ids, attention_mask=None, token_type_ids=None):
        """Assemble backbone kwargs, dropping token_type_ids if unsupported."""
        kwargs = {"input_ids": input_ids}
        if attention_mask is not None:
            kwargs["attention_mask"] = attention_mask
        if token_type_ids is not None and self._accepts_token_type_ids:
            kwargs["token_type_ids"] = token_type_ids
        return kwargs

    @staticmethod
    def _pad_to(t, target_len):
        """Right-pad a (B, L) integer tensor with zeros to target_len along dim 1."""
        pad_len = target_len - t.size(1)
        if pad_len <= 0:
            return t
        pad = torch.zeros(t.size(0), pad_len, dtype=t.dtype, device=t.device)
        return torch.cat([t, pad], dim=1)

    def _forward_transformer(self, input_ids, attention_mask=None, token_type_ids=None):
        """Run the backbone, chunking sequences longer than its position limit.

        Returns:
            last_hidden_state of shape (B, L, hidden), where L is the original
            input length (padding introduced by chunking is trimmed off).

        Note: chunks are encoded independently, so there is no cross-chunk
        attention and position embeddings restart at every chunk boundary.
        """
        B, L = input_ids.size()
        max_len = getattr(self.bert.config, "max_position_embeddings", 512)

        # --- Fast path: sequence fits in one forward pass ---
        if L <= max_len:
            kwargs = self._build_kwargs(input_ids, attention_mask, token_type_ids)
            return self.bert(**kwargs).last_hidden_state  # (B, L, hidden)

        # --- Long input: split into chunks of at most max_len tokens ---
        # FIX: synthesize an all-ones mask when none was given, so the zero
        # padding appended to the last chunk is excluded from attention.
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        starts = range(0, L, max_len)
        id_chunks = [self._pad_to(input_ids[:, s:s + max_len], max_len) for s in starts]
        mask_chunks = [self._pad_to(attention_mask[:, s:s + max_len], max_len) for s in starts]
        type_chunks = None
        if token_type_ids is not None:
            type_chunks = [self._pad_to(token_type_ids[:, s:s + max_len], max_len) for s in starts]

        # Batch every chunk into a single backbone call: (B * n_chunks, max_len).
        input_batch = torch.cat(id_chunks, dim=0)
        mask_batch = torch.cat(mask_chunks, dim=0)
        type_batch = torch.cat(type_chunks, dim=0) if type_chunks is not None else None
        kwargs = self._build_kwargs(input_batch, mask_batch, type_batch)
        x_all = self.bert(**kwargs).last_hidden_state  # (B * n_chunks, max_len, hidden)

        # Recombine: x_all is stacked [chunk0_batch; chunk1_batch; ...], so
        # split back into per-chunk batches of size B and concatenate along
        # the token dimension.
        split = torch.split(x_all, B, dim=0)
        x = torch.cat(list(split), dim=1)  # (B, n_chunks * max_len, hidden)
        # FIX: trim the last chunk's padding so the output length matches the
        # input length; otherwise pad positions leak into the conv/pool head.
        return x[:, :L, :]

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Compute class logits of shape (B, num_classes)."""
        # Contextual token embeddings from the (possibly chunked) backbone.
        x = self._forward_transformer(input_ids, attention_mask, token_type_ids)  # (B, L, hidden)
        x = x.transpose(1, 2)  # Conv1d expects (B, channels=hidden, L)
        # Parallel n-gram feature branches; 'same' padding keeps lengths equal
        # so the outputs can be concatenated on the channel dimension.
        c1 = self.relu(self.conv1(x))
        c2 = self.relu(self.conv2(x))
        c3 = self.relu(self.conv3(x))
        x = torch.cat([c1, c2, c3], dim=1)  # (B, 3*out_channels, L)
        # Post-concatenation refinement convs.
        x = self.relu(self.conv_post1(x))
        x = self.relu(self.conv_post2(x))
        # Global max pool to a fixed-size feature vector, independent of L.
        x = F.adaptive_max_pool1d(x, self.final_pool_size)  # (B, out_channels, 1)
        x = x.squeeze(-1)  # (B, out_channels)
        # MLP classification head.
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)
|