jrawa commited on
Commit
7a028db
·
verified ·
1 Parent(s): 89d2a9a

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. distilbert_best.pth +3 -0
  2. load.py +7 -0
  3. model.py +149 -0
distilbert_best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d6a11bf0b621c366ec36342d15e17964c3bf060ebf0ab7de53012e909baf89ae
3
+ size 271071638
load.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
import torch

from model import FakeBERT

# --- Configuration -------------------------------------------------------
# These values must match the training run that produced the checkpoint.
# NOTE(review): checkpoint filename suggests a DistilBERT backbone — confirm
# the exact model id and class count against the training configuration.
MODEL_NAME = "distilbert-base-uncased"
NUM_CLASSES = 3
MODEL_PATH = "distilbert_best.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the architecture, then restore the trained weights.
# map_location lets a GPU-trained checkpoint load on a CPU-only machine.
model = FakeBERT(model_name=MODEL_NAME, num_classes=NUM_CLASSES).to(DEVICE)
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(state_dict)
model.eval()  # inference mode: disables dropout
model.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from transformers import AutoModel
5
+
6
+
7
+ # -------------------------------
8
+ # 1. Model Definition
9
+ # -------------------------------
class FakeBERT(nn.Module):
    """Transformer encoder with a CNN classification head.

    A pretrained transformer (loaded via ``AutoModel``) produces token
    embeddings; three parallel 1D convolutions (kernel sizes 3/4/5) extract
    n-gram features, two further convolutions refine them, and a global
    adaptive max-pool plus a small MLP produces class logits.

    Inputs longer than the backbone's positional limit are split into
    chunks, encoded in a single batched forward pass, and re-concatenated
    along the token dimension.
    """

    def __init__(self, model_name="bert-base-uncased", num_classes=3, dropout=0.2):
        """
        Args:
            model_name: Hugging Face model id for the transformer backbone.
            num_classes: size of the output logits (number of classes).
            dropout: dropout probability applied before the final linear layer.
        """
        super().__init__()

        # Base transformer model (AutoModel keeps this backbone-agnostic).
        self.bert = AutoModel.from_pretrained(model_name)
        hidden = self.bert.config.hidden_size
        out_channels = 128

        # Parallel 1D convs across the token dimension (in_channels = hidden).
        # padding='same' keeps every branch at the input sequence length so
        # their outputs can be concatenated channel-wise.
        self.conv1 = nn.Conv1d(hidden, out_channels, kernel_size=3, padding='same')
        self.conv2 = nn.Conv1d(hidden, out_channels, kernel_size=4, padding='same')
        self.conv3 = nn.Conv1d(hidden, out_channels, kernel_size=5, padding='same')

        # Post-concatenation convs operate on the stacked channels.
        self.conv_post1 = nn.Conv1d(out_channels * 3, out_channels, kernel_size=3, padding=1)
        self.conv_post2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)

        # Final adaptive pooling to length 1 -> deterministic flattened size
        # of exactly out_channels, independent of input sequence length.
        self.final_pool_size = 1

        # Fully connected head (in_features = out_channels after global pool).
        self.fc1 = nn.Linear(out_channels, 128)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(128, num_classes)
        self.relu = nn.ReLU()

        # Whether the backbone accepts token_type_ids: BERT configs expose
        # type_vocab_size, DistilBERT configs do not, so getattr -> None -> False.
        self._accepts_token_type_ids = getattr(self.bert.config, "type_vocab_size", None) is not None

    def _forward_transformer(self, input_ids, attention_mask=None, token_type_ids=None):
        """Run the backbone, chunking inputs that exceed its positional limit.

        Args:
            input_ids: (B, L) token ids.
            attention_mask: optional (B, L) mask; zeros mark padding.
            token_type_ids: optional (B, L) segment ids (ignored when the
                backbone does not accept them).

        Returns:
            last_hidden_state shaped (B, L_padded, hidden); equals (B, L, hidden)
            on the short-sequence fast path.
        """
        B, L = input_ids.size()
        max_len = getattr(self.bert.config, "max_position_embeddings", 512)

        def build_kwargs(ii, am=None, tt=None):
            # Build the backbone call robustly: only pass what is present and
            # only pass token_type_ids to models that accept them.
            kwargs = {"input_ids": ii}
            if am is not None:
                kwargs["attention_mask"] = am
            if tt is not None and self._accepts_token_type_ids:
                kwargs["token_type_ids"] = tt
            return kwargs

        # --- Fast path: the whole sequence fits in one forward pass ---
        if L <= max_len:
            kwargs = build_kwargs(input_ids, attention_mask, token_type_ids)
            return self.bert(**kwargs).last_hidden_state  # (B, L, hidden)

        # --- Long input: slice into chunks of at most max_len tokens ---
        # NOTE(review): each chunk is encoded independently, so positional
        # embeddings restart at 0 for every chunk — confirm this matches how
        # the model was trained on long documents.
        chunks, masks, types = [], [], []
        for start in range(0, L, max_len):
            end = min(start + max_len, L)
            chunks.append(input_ids[:, start:end])
            if attention_mask is not None:
                masks.append(attention_mask[:, start:end])
            if token_type_ids is not None:
                types.append(token_type_ids[:, start:end])

        # Pad every chunk (right side) to the longest chunk length so they can
        # be stacked into a single batch.  Pads are zeros; when a mask is
        # provided the padded positions are masked out of attention.
        max_chunk_len = max(c.size(1) for c in chunks)
        device = input_ids.device

        def pad_right(t, pad_len):
            # Append pad_len zero columns to a (B, chunk_len) tensor.
            if pad_len <= 0:
                return t
            zeros = torch.zeros(B, pad_len, dtype=t.dtype, device=device)
            return torch.cat([t, zeros], dim=1)

        padded_chunks = []
        padded_masks = [] if masks else None
        padded_types = [] if types else None

        for i, c in enumerate(chunks):
            pad_len = max_chunk_len - c.size(1)
            padded_chunks.append(pad_right(c, pad_len))
            if masks:
                padded_masks.append(pad_right(masks[i], pad_len))
            if types:
                padded_types.append(pad_right(types[i], pad_len))

        # Stack all chunks along the batch axis for one backbone forward pass:
        # layout is [chunk0 rows for all B; chunk1 rows for all B; ...].
        input_chunks = torch.cat(padded_chunks, dim=0)  # (B * n_chunks, chunk_len)
        attention_chunks = torch.cat(padded_masks, dim=0) if padded_masks is not None else None
        token_chunks = torch.cat(padded_types, dim=0) if padded_types is not None else None

        kwargs = build_kwargs(input_chunks, attention_chunks, token_chunks)
        x_all = self.bert(**kwargs).last_hidden_state  # (B * n_chunks, chunk_len, hidden)

        # Undo the batch stacking: split back into n_chunks pieces of (B, ...)
        # and concatenate them along the token dimension.
        n_chunks = len(chunks)
        split = torch.split(x_all, input_chunks.size(0) // n_chunks, dim=0)
        x = torch.cat(list(split), dim=1)  # (B, total_padded_len, hidden)
        return x

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return class logits of shape (B, num_classes)."""
        # Transformer forward (handles chunking of over-long inputs).
        x = self._forward_transformer(input_ids, attention_mask, token_type_ids)  # (B, seq_len, hidden)

        # --- Convolutional feature extraction ---
        # Conv1d expects channels-first: (B, hidden, seq_len).
        x = x.transpose(1, 2)

        # Parallel conv branches + ReLU; 'same' padding keeps lengths equal,
        # so the branches concatenate cleanly along the channel axis.
        c1 = self.relu(self.conv1(x))
        c2 = self.relu(self.conv2(x))
        c3 = self.relu(self.conv3(x))
        x = torch.cat([c1, c2, c3], dim=1)  # (B, 3*out_channels, seq_len)

        # Refinement convs.
        x = self.relu(self.conv_post1(x))
        x = self.relu(self.conv_post2(x))

        # Global max-pool to a fixed length of 1, then drop the length axis.
        # NOTE(review): on the chunked path, embeddings of zero-padded tail
        # tokens participate in this max-pool — verify this is acceptable.
        x = F.adaptive_max_pool1d(x, self.final_pool_size)  # (B, out_channels, 1)
        x = x.squeeze(-1)  # (B, out_channels)

        # Classification head.
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        logits = self.fc2(x)

        return logits