# Provenance: uploaded by "jrawa" to a Hugging Face model repo via
# huggingface_hub (commit 7a028db, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel
# -------------------------------
# 1. Model Definition
# -------------------------------
class FakeBERT(nn.Module):
    """Transformer backbone + parallel 1D-CNN head for sequence classification.

    Pipeline: backbone hidden states -> three parallel Conv1d branches
    (kernel sizes 3/4/5) -> channel concat -> two post-concat convs ->
    global adaptive max pool -> two-layer MLP producing ``num_classes``
    logits. Inputs longer than the backbone's position limit are handled
    by chunking (see :meth:`_forward_transformer`).
    """

    def __init__(self, model_name="bert-base-uncased", num_classes=3, dropout=0.2):
        """
        Args:
            model_name: Hugging Face checkpoint name for the backbone.
            num_classes: number of output logits.
            dropout: dropout probability applied before the final linear layer.
        """
        super().__init__()
        # Base transformer model (AutoModel is future-proof).
        self.bert = AutoModel.from_pretrained(model_name)
        hidden = self.bert.config.hidden_size
        out_channels = 128
        # Parallel 1D convs across the token dimension (in_channels = hidden).
        # padding='same' keeps every branch's output length equal to the input
        # length, so the branches can be concatenated along channels.
        self.conv1 = nn.Conv1d(hidden, out_channels, kernel_size=3, padding='same')
        self.conv2 = nn.Conv1d(hidden, out_channels, kernel_size=4, padding='same')
        self.conv3 = nn.Conv1d(hidden, out_channels, kernel_size=5, padding='same')
        # Post-concatenation conv layers operate on the concatenated channels.
        self.conv_post1 = nn.Conv1d(out_channels * 3, out_channels, kernel_size=3, padding=1)
        self.conv_post2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
        # Final adaptive pooling to length 1 -> deterministic flattened size
        # equal to out_channels, regardless of input sequence length.
        self.final_pool_size = 1
        # Fully connected head (in_features = out_channels after global pool).
        self.fc1 = nn.Linear(out_channels, 128)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(128, num_classes)
        self.relu = nn.ReLU()
        # Whether the backbone expects token_type_ids (BERT does, DistilBERT
        # does not). Inferred from the config; falls back to "not present".
        self._accepts_token_type_ids = getattr(self.bert.config, "type_vocab_size", None) is not None

    def _forward_transformer(self, input_ids, attention_mask=None, token_type_ids=None):
        """Run the backbone, chunking sequences longer than its position limit.

        Args:
            input_ids: (B, L) token ids.
            attention_mask: optional (B, L) 0/1 mask.
            token_type_ids: optional (B, L) segment ids.

        Returns:
            Hidden states of shape (B, L, hidden). Long inputs are processed
            in ``max_position_embeddings``-sized chunks, re-joined along the
            token dimension, and trimmed back to the original length L.
        """
        B, L = input_ids.size()
        max_len = getattr(self.bert.config, "max_position_embeddings", 512)

        def build_kwargs(ii, am=None, tt=None):
            # Only forward token_type_ids when the backbone supports them.
            kwargs = {"input_ids": ii}
            if am is not None:
                kwargs["attention_mask"] = am
            if tt is not None and self._accepts_token_type_ids:
                kwargs["token_type_ids"] = tt
            return kwargs

        # --- Fast path: the whole sequence fits in one forward pass. ---
        if L <= max_len:
            kwargs = build_kwargs(input_ids, attention_mask, token_type_ids)
            return self.bert(**kwargs).last_hidden_state  # (B, L, hidden)

        # --- Long input: chunk, run one batched forward pass, recombine. ---
        # FIX: synthesize an all-ones mask when none was given, so the pad
        # tokens appended to the last chunk below are masked out instead of
        # being attended to as real tokens (pad id 0 is a real vocab entry).
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        chunks, masks, types = [], [], []
        for start in range(0, L, max_len):
            end = min(start + max_len, L)
            chunks.append(input_ids[:, start:end])
            masks.append(attention_mask[:, start:end])
            if token_type_ids is not None:
                types.append(token_type_ids[:, start:end])

        max_chunk_len = max(c.size(1) for c in chunks)
        device = input_ids.device

        def pad_right(t):
            # Right-pad a (B, chunk_len) tensor with zeros to max_chunk_len.
            pad_len = max_chunk_len - t.size(1)
            if pad_len == 0:
                return t
            pad = torch.zeros(B, pad_len, dtype=t.dtype, device=device)
            return torch.cat([t, pad], dim=1)

        # Stack all chunks along the batch dim for a single forward pass.
        input_chunks = torch.cat([pad_right(c) for c in chunks], dim=0)   # (B*n_chunks, chunk_len)
        attention_chunks = torch.cat([pad_right(m) for m in masks], dim=0)
        token_chunks = torch.cat([pad_right(t) for t in types], dim=0) if types else None

        kwargs = build_kwargs(input_chunks, attention_chunks, token_chunks)
        x_all = self.bert(**kwargs).last_hidden_state  # (B*n_chunks, chunk_len, hidden)

        # Chunks are stacked as [chunk0_batch; chunk1_batch; ...]; split them
        # back out and concatenate along the token dimension.
        n_chunks = len(chunks)
        split = torch.split(x_all, input_chunks.size(0) // n_chunks, dim=0)
        x = torch.cat(list(split), dim=1)  # (B, n_chunks * max_chunk_len, hidden)
        # FIX: drop the positions introduced by padding the last chunk so the
        # output length matches the input length L (previously the padded
        # garbage positions leaked into the conv/max-pool head).
        return x[:, :L, :]

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return (B, num_classes) logits for a batch of token-id sequences."""
        # Transformer forward (handles chunking of long sequences).
        x = self._forward_transformer(input_ids, attention_mask, token_type_ids)  # (B, L, hidden)
        # --- Convolutional feature extraction ---
        x = x.transpose(1, 2)  # Conv1d wants (B, channels=hidden, L)
        # Parallel conv branches + ReLU; 'same' padding keeps all branch
        # lengths equal, so concatenation along channels is safe.
        c1 = self.relu(self.conv1(x))
        c2 = self.relu(self.conv2(x))
        c3 = self.relu(self.conv3(x))
        x = torch.cat([c1, c2, c3], dim=1)  # (B, 3*out_channels, L)
        # Post-concatenation convs.
        x = self.relu(self.conv_post1(x))
        x = self.relu(self.conv_post2(x))
        # Global adaptive max pool to a fixed length of 1.
        x = F.adaptive_max_pool1d(x, self.final_pool_size)  # (B, out_channels, 1)
        x = x.squeeze(-1)  # (B, out_channels)
        # Fully connected head.
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)