import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel


# -------------------------------
# 1. Model Definition
# -------------------------------
class FakeBERT(nn.Module):
    """BERT backbone + parallel 1D-conv head for sequence classification.

    Sequences longer than the backbone's positional-embedding limit are
    handled by chunking the input, running the backbone on all chunks in a
    single batched forward pass, and re-concatenating the hidden states
    along the token axis.
    """

    def __init__(self, model_name="bert-base-uncased", num_classes=3, dropout=0.2):
        super().__init__()
        # Base transformer model (AutoModel is future-proof w.r.t. architecture)
        self.bert = AutoModel.from_pretrained(model_name)
        hidden = self.bert.config.hidden_size
        out_channels = 128

        # Parallel 1D convs across the token dimension (in_channels = hidden).
        # NOTE: padding='same' requires stride=1; for the even kernel (4)
        # PyTorch pads asymmetrically so the output length is unchanged.
        self.conv1 = nn.Conv1d(hidden, out_channels, kernel_size=3, padding='same')
        self.conv2 = nn.Conv1d(hidden, out_channels, kernel_size=4, padding='same')
        self.conv3 = nn.Conv1d(hidden, out_channels, kernel_size=5, padding='same')

        # Post-concatenation conv layers operate on the concatenated channels.
        self.conv_post1 = nn.Conv1d(out_channels * 3, out_channels, kernel_size=3, padding=1)
        self.conv_post2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)

        # Final adaptive pooling to length 1 -> deterministic flattened
        # feature size = out_channels, independent of sequence length.
        self.final_pool_size = 1

        # Fully connected head (in_features = out_channels after global pool).
        self.fc1 = nn.Linear(out_channels, 128)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(128, num_classes)
        self.relu = nn.ReLU()

        # Whether the backbone expects token_type_ids (BERT does, DistilBERT
        # does not).  Derived from the config; fallback: assume not present.
        self._accepts_token_type_ids = (
            getattr(self.bert.config, "type_vocab_size", None) is not None
        )

    def _forward_transformer(self, input_ids, attention_mask=None, token_type_ids=None):
        """Run the backbone, chunking sequences longer than the positional limit.

        Args:
            input_ids: (B, L) token ids.
            attention_mask: optional (B, L) mask; 1 = real token, 0 = padding.
            token_type_ids: optional (B, L) segment ids (forwarded only if the
                backbone's config declares a ``type_vocab_size``).

        Returns:
            last_hidden_state of shape (B, L, hidden).  Chunk padding is
            stripped so the output length always matches the input length.
        """
        B, L = input_ids.size()
        max_len = getattr(self.bert.config, "max_position_embeddings", 512)

        def build_kwargs(ii, am=None, tt=None):
            # Only forward the arguments the backbone actually understands.
            kwargs = {"input_ids": ii}
            if am is not None:
                kwargs["attention_mask"] = am
            if tt is not None and self._accepts_token_type_ids:
                kwargs["token_type_ids"] = tt
            return kwargs

        # --- Fast path: sequence fits within the positional embeddings ---
        if L <= max_len:
            kwargs = build_kwargs(input_ids, attention_mask, token_type_ids)
            return self.bert(**kwargs).last_hidden_state  # (B, L, hidden)

        # --- Long input: chunk, batch all chunks, and recombine ---
        # BUGFIX: synthesize a mask when none was given so that the padding
        # appended below is ignored by attention instead of being treated as
        # real tokens.
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        chunks, masks, types = [], [], []
        for start in range(0, L, max_len):
            end = min(start + max_len, L)
            chunks.append(input_ids[:, start:end])
            masks.append(attention_mask[:, start:end])
            if token_type_ids is not None:
                types.append(token_type_ids[:, start:end])

        # Pad every chunk to the longest chunk length.  Only the final chunk
        # can actually be shorter, since all earlier chunks are exactly
        # max_len tokens long.
        max_chunk_len = max(c.size(1) for c in chunks)

        def pad_tail(t):
            # Right-pad a (B, chunk_len) tensor with zeros up to max_chunk_len.
            pad_len = max_chunk_len - t.size(1)
            if pad_len == 0:
                return t
            pad = torch.zeros(B, pad_len, dtype=t.dtype, device=t.device)
            return torch.cat([t, pad], dim=1)

        input_chunks = torch.cat([pad_tail(c) for c in chunks], dim=0)
        attention_chunks = torch.cat([pad_tail(m) for m in masks], dim=0)
        token_chunks = (
            torch.cat([pad_tail(t) for t in types], dim=0) if types else None
        )

        # Single batched forward pass over all chunks:
        # (B * n_chunks, max_chunk_len, hidden)
        kwargs = build_kwargs(input_chunks, attention_chunks, token_chunks)
        x_all = self.bert(**kwargs).last_hidden_state

        # x_all is stacked [chunk0_batch; chunk1_batch; ...]; split it back
        # into per-chunk (B, max_chunk_len, hidden) tensors and concatenate
        # along the token axis.
        x = torch.cat(torch.split(x_all, B, dim=0), dim=1)

        # BUGFIX: strip the chunk padding so the result is (B, L, hidden),
        # matching both the docstring and the fast path, and keeping the
        # padded positions' states out of the downstream max pooling.
        return x[:, :L, :]

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return classification logits of shape (B, num_classes)."""
        # Transformer forward (handles chunking of long inputs).
        x = self._forward_transformer(
            input_ids, attention_mask, token_type_ids
        )  # (B, L, hidden)

        # --- Convolutional feature extraction ---
        x = x.transpose(1, 2)  # (B, hidden, L): Conv1d expects channels first

        # Parallel convs + ReLU; 'same' padding keeps all lengths equal, so
        # the outputs can be concatenated along the channel dimension.
        c1 = self.relu(self.conv1(x))
        c2 = self.relu(self.conv2(x))
        c3 = self.relu(self.conv3(x))
        x = torch.cat([c1, c2, c3], dim=1)  # (B, 3*out_channels, L)

        # Post-concatenation convs.
        x = self.relu(self.conv_post1(x))
        x = self.relu(self.conv_post2(x))

        # Global adaptive max pooling to a fixed length of 1.
        x = F.adaptive_max_pool1d(x, self.final_pool_size)  # (B, out_channels, 1)
        x = x.squeeze(-1)  # (B, out_channels)

        # Fully connected head.
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)