aagamjtdev committed
Commit f1d3547 · 1 Parent(s): 9b410c3

training script

Files changed (1): train_model.py (+777, -0)
train_model.py ADDED
@@ -0,0 +1,777 @@
+
+ import os
+ import re
+ import json
+ import pickle
+ import argparse
+ from collections import Counter
+ from typing import List, Tuple, Dict, Any
+ from tqdm import tqdm
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset, DataLoader
+
+ try:
+     from TorchCRF import CRF
+ except ImportError:
+     print("Error: The 'TorchCRF' library is required. Please install it using 'pip install TorchCRF'.")
+     exit()
+
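+ # NOTE: the CRF API used below (CRF(n_labels), a forward pass that returns a
+ # per-sequence log-likelihood, and viterbi_decode) matches the 'TorchCRF'
+ # package. The similarly named 'pytorch-crf' package exposes a different
+ # interface (crf.decode, a batch_first flag) and would need small changes here.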
+ # ========== CONFIG ==========
+ # Using the user's saved path information for DATA_DIR and model/vocab file names
+ DATA_DIR = "output_data"
+ MODEL_FILE = "model_enhanced.pt"  # Using user's saved model filename
+ VOCAB_FILE = "vocabs_enhanced.pkl"  # Using user's saved vocab filename
+ CHECKPOINT_FILE = "checkpoint_enhanced.pt"  # New file for full checkpoint (incl. optimizer, epoch, etc.)
+
+ os.makedirs(DATA_DIR, exist_ok=True)
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ MAX_CHAR_LEN = 16
+ EMBED_DIM = 128  # Increased from 100
+ CHAR_EMBED_DIM = 50  # Increased from 30
+ CHAR_CNN_OUT = 50  # Increased from 30
+ BBOX_DIM = 128  # Increased from 100
+ HIDDEN_SIZE = 768  # Increased from 512 to match LayoutLM dimension
+ BATCH_SIZE = 8
+ EPOCHS = 10  # Reduced from 30
+ LR = 5e-4  # Decreased from 1e-3 for more stable training
+ BBOX_NORM_CONSTANT = 1000.0
+ CHUNK_SIZE = 450
+
+ # Enhanced feature dimensions
+ SPATIAL_FEATURE_DIM = 64  # Increased from 32
+ POSITIONAL_DIM = 128  # New: for learnable positional embeddings
+
+ # ========== LABELS ==========
+ LABELS = [
+     "O",
+     "B-QUESTION", "I-QUESTION",
+     "B-OPTION", "I-OPTION",
+     "B-ANSWER", "I-ANSWER",
+     "B-IMAGE", "I-IMAGE",
+     "B-SECTION HEADING", "I-SECTION HEADING",
+     "B-PASSAGE", "I-PASSAGE"
+ ]
+ LABEL2IDX = {l: i for i, l in enumerate(LABELS)}
+ IDX2LABEL = {i: l for l, i in LABEL2IDX.items()}
+
+
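+ # BIO tagging example over a token stream (illustrative):
+ #   "Q1." -> B-QUESTION, "What" -> I-QUESTION, ..., "a)" -> B-OPTION,
+ #   "Paris" -> I-OPTION, stray tokens -> O.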
+ # ========== ENHANCED FEATURE EXTRACTION (PATTERN FUNCTION REMOVED) ==========
+
+ def extract_spatial_features(tokens: List[Dict], idx: int) -> List[float]:
+     """Enhanced spatial features with relative positioning."""
+     current = tokens[idx]
+     features = []
+
+     # Vertical spacing with next token (look-ahead)
+     if idx < len(tokens) - 1:
+         next_tok = tokens[idx + 1]
+         forward_gap = next_tok['y0'] - current['y1']
+         features.append(min(forward_gap / 100.0, 1.0))
+     else:
+         features.append(0.0)
+
+     # Vertical spacing with previous token
+     if idx > 0:
+         prev = tokens[idx - 1]
+         vertical_gap = current['y0'] - prev['y1']
+         features.append(min(vertical_gap / 100.0, 1.0))
+     else:
+         features.append(0.0)
+
+     # Horizontal offset (indentation)
+     features.append(current['x0'] / BBOX_NORM_CONSTANT)
+
+     # Token dimensions
+     width = current['x1'] - current['x0']
+     height = current['y1'] - current['y0']
+     features.append(width / BBOX_NORM_CONSTANT)
+     features.append(height / BBOX_NORM_CONSTANT)
+
+     # Position in line
+     x_center = (current['x0'] + current['x1']) / 2
+     y_center = (current['y0'] + current['y1']) / 2
+     features.append(x_center / BBOX_NORM_CONSTANT)
+     features.append(y_center / BBOX_NORM_CONSTANT)
+
+     # Distance from left margin (duplicates the indentation feature above;
+     # kept so the spatial vector stays at 11 dimensions)
+     features.append(current['x0'] / BBOX_NORM_CONSTANT)
+
+     # Aspect ratio
+     aspect = width / max(height, 1.0)
+     features.append(min(aspect / 10.0, 1.0))
+
+     # Alignment feature (detect if left-aligned with the previous token)
+     if idx > 0:
+         prev = tokens[idx - 1]
+         x_alignment = abs(current['x0'] - prev['x0']) < 5  # Within 5 units
+         features.append(float(x_alignment))
+     else:
+         features.append(0.0)
+
+     # Area (normalized)
+     area = width * height
+     features.append(min(area / (BBOX_NORM_CONSTANT ** 2), 1.0))
+
+     return features
+
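+ # The 11 spatial features, in order: forward gap, backward gap, indentation,
+ # width, height, x-center, y-center, left margin, aspect ratio, alignment flag,
+ # area. This count must match the Linear(11, ...) projection in the model and
+ # the hard-coded padding width in collate_batch.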
+
+ def extract_context_features(tokens: List[Dict], idx: int, window: int = 3) -> List[float]:
+     """Enhanced context with larger window and more patterns."""
+     context_features = []
+
+     # Previous context
+     prev_has_q = 0.0
+     prev_has_opt = 0.0
+     prev_has_caps = 0.0
+     for i in range(max(0, idx - window), idx):
+         text = tokens[i]['text'].lower().strip()
+         if re.match(r'^q?\.?\d+[.:]', text):
+             prev_has_q = 1.0
+         if re.match(r'^[a-dA-D][.)]', text):
+             prev_has_opt = 1.0
+         if tokens[i]['text'].strip().isupper() and len(tokens[i]['text'].strip()) > 2:
+             prev_has_caps = 1.0
+
+     context_features.extend([prev_has_q, prev_has_opt, prev_has_caps])
+
+     # Next context
+     next_has_q = 0.0
+     next_has_opt = 0.0
+     next_has_caps = 0.0
+     for i in range(idx + 1, min(len(tokens), idx + window + 1)):
+         text = tokens[i]['text'].lower().strip()
+         if re.match(r'^q?\.?\d+[.:]', text):
+             next_has_q = 1.0
+         if re.match(r'^[a-dA-D][.)]', text):
+             next_has_opt = 1.0
+         if tokens[i]['text'].strip().isupper() and len(tokens[i]['text'].strip()) > 2:
+             next_has_caps = 1.0
+
+     context_features.extend([next_has_q, next_has_opt, next_has_caps])
+
+     # Distance features: how far to the next question/option marker
+     dist_to_next_q = window + 1
+     dist_to_next_opt = window + 1
+     for i in range(idx + 1, min(len(tokens), idx + window + 1)):
+         text = tokens[i]['text'].lower().strip()
+         if re.match(r'^q?\.?\d+[.:]', text) and dist_to_next_q > (i - idx):
+             dist_to_next_q = i - idx
+         if re.match(r'^[a-dA-D][.)]', text) and dist_to_next_opt > (i - idx):
+             dist_to_next_opt = i - idx
+
+     context_features.append(dist_to_next_q / window)
+     context_features.append(dist_to_next_opt / window)
+
+     return context_features
+
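+ # The 8 context features: question/option/uppercase markers in the previous
+ # window (3), the same markers in the following window (3), and normalized
+ # distances to the next question and option markers (2). This count must match
+ # the Linear(8, 32) projection in the model.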
+
+ # ========== Vocab Class ==========
+ class Vocab:
+     def __init__(self, min_freq=1, unk_token="<UNK>", pad_token="<PAD>"):
+         self.min_freq = min_freq
+         self.unk_token = unk_token
+         self.pad_token = pad_token
+         self.freq = Counter()
+         self.itos = []
+         self.stoi = {}
+
+     def add_sentence(self, toks):
+         self.freq.update(toks)
+
+     def build(self):
+         items = [tok for tok, c in self.freq.items() if c >= self.min_freq]
+         items = [self.pad_token, self.unk_token] + sorted(items)
+         self.itos = items
+         self.stoi = {s: i for i, s in enumerate(self.itos)}
+
+     def __len__(self):
+         return len(self.itos)
+
+     def __getitem__(self, token: str) -> int:
+         return self.stoi.get(token, self.stoi[self.unk_token])
+
+     def __getstate__(self):
+         return {
+             'min_freq': self.min_freq,
+             'unk_token': self.unk_token,
+             'pad_token': self.pad_token,
+             'itos': self.itos,
+             'stoi': self.stoi,
+         }
+
+     def __setstate__(self, state):
+         self.min_freq = state['min_freq']
+         self.unk_token = state['unk_token']
+         self.pad_token = state['pad_token']
+         self.itos = state['itos']
+         self.stoi = state['stoi']
+         self.freq = Counter()
+
+
+ # ========== Data Loading ==========
+ def load_unified_data(unified_json_path: str) -> Tuple[List[Dict[str, Any]], List[List[str]]]:
+     """Loads data and extracts enhanced features."""
+     if not os.path.exists(unified_json_path):
+         raise FileNotFoundError(f"Unified JSON data not found at: {unified_json_path}")
+
+     with open(unified_json_path, 'r', encoding='utf-8') as f:
+         flat_tokens = json.load(f)
+
+     pages_tokens = []
+     labels_per_token = []
+
+     print(":mag: Extracting spatial and context features (patterns removed)...")
+
+     for i in tqdm(range(0, len(flat_tokens), CHUNK_SIZE), desc="Processing chunks"):
+         chunk = flat_tokens[i:i + CHUNK_SIZE]
+         if not chunk:
+             continue
+
+         tokens_list = []
+         for j, t in enumerate(chunk):
+             token_dict = {
+                 "text": t["token"],
+                 "x0": t["bbox"][0], "y0": t["bbox"][1],
+                 "x1": t["bbox"][2], "y1": t["bbox"][3],
+                 "page_no": 0, "block_idx": 0
+             }
+             # Pattern feature extraction removed
+             tokens_list.append(token_dict)
+
+         for j, token_dict in enumerate(tokens_list):
+             token_dict["spatial_features"] = extract_spatial_features(tokens_list, j)
+             token_dict["context_features"] = extract_context_features(tokens_list, j, window=3)
+
+         pages_tokens.append({
+             "tokens": tokens_list,
+             "width": BBOX_NORM_CONSTANT,
+             "height": BBOX_NORM_CONSTANT
+         })
+         labels_per_token.append([t["label"] for t in chunk])
+
+     return pages_tokens, labels_per_token
+
+
+ # ========== Dataset ==========
+ class MCQTokenDataset(Dataset):
+     def __init__(self, pages_tokens, word_vocab, char_vocab, labels_per_token=None):
+         self.samples = []
+         self.bbox_norm_factor = BBOX_NORM_CONSTANT
+
+         # NOTE: assumes labels_per_token aligns 1:1 with pages_tokens; empty
+         # pages are filtered from samples but not from the labels list
+         for page_data in pages_tokens:
+             if len(page_data["tokens"]) == 0:
+                 continue
+             self.samples.append(page_data)
+
+         self.labels = labels_per_token
+         self.word_vocab = word_vocab
+         self.char_vocab = char_vocab
+
+     def __len__(self):
+         return len(self.samples)
+
+     def __getitem__(self, idx):
+         page_data = self.samples[idx]
+         toks = page_data["tokens"]
+
+         words = [t["text"] for t in toks]
+         word_ids = [self.word_vocab.stoi.get(w, self.word_vocab.stoi[self.word_vocab.unk_token]) for w in words]
+
+         char_ids = []
+         for w in words:
+             chs = [self.char_vocab.stoi.get(ch, self.char_vocab.stoi[self.char_vocab.unk_token])
+                    for ch in w[:MAX_CHAR_LEN]]
+             if len(chs) < MAX_CHAR_LEN:
+                 chs += [self.char_vocab.stoi[self.char_vocab.pad_token]] * (MAX_CHAR_LEN - len(chs))
+             char_ids.append(chs)
+
+         bboxes = []
+         for t in toks:
+             normalized_bbox = [
+                 t["x0"] / self.bbox_norm_factor,
+                 t["y0"] / self.bbox_norm_factor,
+                 t["x1"] / self.bbox_norm_factor,
+                 t["y1"] / self.bbox_norm_factor,
+             ]
+             bboxes.append(normalized_bbox)
+
+         # Pattern features removed
+         spatial_features = [t["spatial_features"] for t in toks]
+         context_features = [t["context_features"] for t in toks]
+
+         labels = None
+         if self.labels:
+             lbls = self.labels[idx]
+             labels = [LABEL2IDX[l] for l in lbls]
+
+         return {
+             "word_ids": torch.LongTensor(word_ids),
+             "char_ids": torch.LongTensor(char_ids),
+             "bboxes": torch.FloatTensor(bboxes),
+             # "pattern_features" removed
+             "spatial_features": torch.FloatTensor(spatial_features),
+             "context_features": torch.FloatTensor(context_features),
+             "labels": torch.LongTensor(labels) if labels is not None else None,
+             "tokens": toks
+         }
+
+
+ def collate_batch(batch):
+     max_len = max(item["word_ids"].size(0) for item in batch)
+     batch_size = len(batch)
+
+     word_pad = torch.zeros((batch_size, max_len), dtype=torch.long)
+     char_pad = torch.zeros((batch_size, max_len, MAX_CHAR_LEN), dtype=torch.long)
+     bbox_pad = torch.zeros((batch_size, max_len, 4), dtype=torch.float)
+     # pattern_pad removed
+     spatial_pad = torch.zeros((batch_size, max_len, 11), dtype=torch.float)  # Note: 11 spatial features
+     context_pad = torch.zeros((batch_size, max_len, 8), dtype=torch.float)  # Note: 8 context features
+     mask = torch.zeros((batch_size, max_len), dtype=torch.bool)
+     label_pad = torch.full((batch_size, max_len), -1, dtype=torch.long)
+     tokens_list = []
+
+     for i, item in enumerate(batch):
+         L = item["word_ids"].size(0)
+         word_pad[i, :L] = item["word_ids"]
+         char_pad[i, :L, :] = item["char_ids"]
+         bbox_pad[i, :L, :] = item["bboxes"]
+         # pattern_pad removed
+         spatial_pad[i, :L, :] = item["spatial_features"]
+         context_pad[i, :L, :] = item["context_features"]
+         mask[i, :L] = 1
+         if item["labels"] is not None:
+             label_pad[i, :L] = item["labels"]
+         tokens_list.append(item["tokens"])
+
+     return {
+         "words": word_pad,
+         "chars": char_pad,
+         "bboxes": bbox_pad,
+         # "pattern_features" removed
+         "spatial_features": spatial_pad,
+         "context_features": context_pad,
+         "mask": mask,
+         "labels": label_pad,
+         "tokens": tokens_list
+     }
+
+
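+ # Shape sketch (illustrative): for a batch of two chunks with 450 and 320
+ # tokens, everything is padded to max_len=450: words (2, 450), chars
+ # (2, 450, 16), bboxes (2, 450, 4), spatial (2, 450, 11), context (2, 450, 8).
+ # mask flags the real tokens; padded label positions hold -1, which the
+ # cross-entropy loss ignores.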
+ # ========== ENHANCED MODEL ==========
+ class CharCNNEncoder(nn.Module):
+     def __init__(self, char_vocab_size, char_emb_dim, out_dim, kernel_sizes=(2, 3, 4, 5)):
+         super().__init__()
+         self.char_emb = nn.Embedding(char_vocab_size, char_emb_dim, padding_idx=0)
+         convs = [nn.Conv1d(char_emb_dim, out_dim, kernel_size=k) for k in kernel_sizes]
+         self.convs = nn.ModuleList(convs)
+         self.out_dim = out_dim * len(convs)
+
+     def forward(self, char_ids):
+         # char_ids: (B, L, C); fold tokens into the batch dim for the convs
+         B, L, C = char_ids.size()
+         emb = self.char_emb(char_ids.view(B * L, C)).transpose(1, 2)
+         # Each conv + max-pool over character positions yields one out_dim vector
+         outs = [torch.max(torch.relu(conv(emb)), dim=2)[0] for conv in self.convs]
+         res = torch.cat(outs, dim=1)
+         return res.view(B, L, -1)
+
+
+ class SpatialAttention(nn.Module):
+     """Attention mechanism for spatial relationships."""
+
+     def __init__(self, hidden_dim):
+         super().__init__()
+         self.query = nn.Linear(hidden_dim, hidden_dim)
+         self.key = nn.Linear(hidden_dim, hidden_dim)
+         self.value = nn.Linear(hidden_dim, hidden_dim)
+         self.scale = hidden_dim ** 0.5
+
+     def forward(self, x, mask):
+         Q = self.query(x)
+         K = self.key(x)
+         V = self.value(x)
+
+         # Scaled dot-product attention: softmax(QK^T / sqrt(d)) V
+         scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
+
+         # Mask out padded key positions
+         mask_expanded = mask.unsqueeze(1).expand_as(scores)
+         scores = scores.masked_fill(~mask_expanded, float('-inf'))
+
+         attn_weights = F.softmax(scores, dim=-1)
+         # Handle NaN from softmax over all -inf scores (shouldn't happen with a proper mask, but for safety)
+         attn_weights = attn_weights.masked_fill(torch.isnan(attn_weights), 0.0)
+
+         output = torch.matmul(attn_weights, V)
+         return output
+
+
+ class MCQTagger(nn.Module):
+     def __init__(self, vocab_size, char_vocab_size, n_labels, bbox_dim=BBOX_DIM):
+         super().__init__()
+         self.word_emb = nn.Embedding(vocab_size, EMBED_DIM, padding_idx=0)
+         self.char_enc = CharCNNEncoder(char_vocab_size, CHAR_EMBED_DIM, CHAR_CNN_OUT)
+
+         # Enhanced bbox encoding with MLP
+         self.bbox_proj = nn.Sequential(
+             nn.Linear(4, bbox_dim),
+             nn.ReLU(),
+             nn.Dropout(0.1),
+             nn.Linear(bbox_dim, bbox_dim)
+         )
+
+         # Feature projections (pattern projection removed)
+         self.spatial_proj = nn.Sequential(
+             nn.Linear(11, SPATIAL_FEATURE_DIM),  # 11 spatial features
+             nn.ReLU(),
+             nn.Dropout(0.1)
+         )
+         self.context_proj = nn.Sequential(
+             nn.Linear(8, 32),  # 8 context features
+             nn.ReLU(),
+             nn.Dropout(0.1)
+         )
+
+         # Positional encoding for sequence position awareness
+         self.positional_encoding = nn.Embedding(512, POSITIONAL_DIM)
+
+         # Input dimension updated (PATTERN_FEATURE_DIM removed)
+         in_dim = (EMBED_DIM + self.char_enc.out_dim + bbox_dim +
+                   SPATIAL_FEATURE_DIM + 32 + POSITIONAL_DIM)
+
+         # Deeper BiLSTM
+         self.bilstm = nn.LSTM(in_dim, HIDDEN_SIZE // 2, num_layers=3,
+                               batch_first=True, bidirectional=True, dropout=0.3)
+
+         # Spatial attention layer
+         self.spatial_attention = SpatialAttention(HIDDEN_SIZE)
+
+         # Layer normalization (currently unused in the forward pass; kept so
+         # existing checkpoints with layer_norm.* keys still load)
+         self.layer_norm = nn.LayerNorm(HIDDEN_SIZE)
+
+         # Final projection
+         self.ff = nn.Sequential(
+             nn.Linear(HIDDEN_SIZE * 2, HIDDEN_SIZE),  # *2 for attention concat
+             nn.ReLU(),
+             nn.Dropout(0.3),
+             nn.Linear(HIDDEN_SIZE, n_labels)
+         )
+
+         self.crf = CRF(n_labels)
+         self.dropout = nn.Dropout(p=0.5)
+
+     def forward_emissions(self, words, chars, bboxes, spatial_feats, context_feats, mask):
+         B, L = words.size()
+
+         # Embeddings
+         wemb = self.word_emb(words)
+         cenc = self.char_enc(chars)
+         benc = self.bbox_proj(bboxes)
+         # penc removed
+         senc = self.spatial_proj(spatial_feats)
+         cxt_enc = self.context_proj(context_feats)
+
+         # Positional encoding
+         positions = torch.arange(L, device=words.device).unsqueeze(0).expand(B, -1)
+         pos_enc = self.positional_encoding(positions.clamp(max=511))
+
+         # Concatenate all features (penc removed)
+         enc_in = torch.cat([wemb, cenc, benc, senc, cxt_enc, pos_enc], dim=-1)
+         enc_in = self.dropout(enc_in)
+
+         # BiLSTM
+         lengths = mask.sum(dim=1).cpu()
+         packed_in = nn.utils.rnn.pack_padded_sequence(enc_in, lengths, batch_first=True, enforce_sorted=False)
+         packed_out, _ = self.bilstm(packed_in)
+         lstm_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
+
+         # Spatial attention
+         attn_out = self.spatial_attention(lstm_out, mask)
+
+         # Concatenate LSTM and attention outputs for the final projection
+         combined = torch.cat([lstm_out, attn_out], dim=-1)
+
+         # Final projection
+         emissions = self.ff(combined)
+         return emissions
+
+     def forward(self, words, chars, bboxes, spatial_feats, context_feats, mask, labels=None,
+                 class_weights=None, alpha=0.8):
+         # pattern_feats removed from arguments
+         emissions = self.forward_emissions(words, chars, bboxes, spatial_feats, context_feats, mask)
+
+         if labels is not None:
+             crf_loss = -self.crf(emissions, labels, mask=mask).sum()
+             if class_weights is not None:
+                 # Combined loss: CRF negative log-likelihood plus class-weighted
+                 # cross-entropy on the emissions, to counter label imbalance
+                 ce_loss_fn = nn.CrossEntropyLoss(weight=class_weights.to(emissions.device), ignore_index=-1)
+                 ce_loss = ce_loss_fn(emissions.view(-1, emissions.size(-1)), labels.view(-1))
+                 return alpha * crf_loss + (1 - alpha) * ce_loss
+             return crf_loss
+
+         return self.crf.viterbi_decode(emissions, mask=mask)
+
+
+ # ========== Training/Eval ==========
+ def compute_class_weights(labels_list, num_labels):
+     all_labels_flat = [lbl for page in labels_list for lbl in page]
+     counts = Counter(all_labels_flat)
+     total = sum(counts.values())
+     weights = []
+
+     for i in range(num_labels):
+         count = counts.get(i, 0)
+         w = total / (num_labels * count) if count > 0 else 1.0
+         weights.append(w)
+
+     return torch.tensor(weights, dtype=torch.float)
+
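+ # Worked example (illustrative counts): with 1,000 labeled tokens across the
+ # 13 classes, a label seen 700 times gets weight 1000 / (13 * 700) ≈ 0.11,
+ # while a label seen 10 times gets 1000 / (13 * 10) ≈ 7.7, so rare labels are
+ # boosted in the cross-entropy term.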
+
+ def eval_model(model, data_loader):
+     model.eval()
+     all_true, all_pred = [], []
+     with torch.no_grad():
+         for batch in data_loader:
+             words = batch["words"].to(DEVICE)
+             chars = batch["chars"].to(DEVICE)
+             bboxes = batch["bboxes"].to(DEVICE)
+             # pattern_feats removed
+             spatial_feats = batch["spatial_features"].to(DEVICE)
+             context_feats = batch["context_features"].to(DEVICE)
+             mask = batch["mask"].to(DEVICE)
+             labels = batch["labels"].to(DEVICE)
+
+             # pattern_feats removed from model call
+             preds_batch = model(words, chars, bboxes, spatial_feats, context_feats, mask, labels=None)
+
+             for i in range(len(preds_batch)):
+                 L = len(preds_batch[i])
+                 all_pred.extend(preds_batch[i])
+                 all_true.extend(labels[i][:L].cpu().numpy().tolist())
+
+     # Local import so sklearn is only required for evaluation
+     from sklearn.metrics import precision_recall_fscore_support
+     # NOTE: the labels list excludes 'O' (0) for task-specific F1
+     p, r, f1, _ = precision_recall_fscore_support(all_true, all_pred, average='micro', zero_division=0,
+                                                   labels=list(range(1, len(LABELS))))
+     return p, r, f1
+
+
+ # MODIFIED: saves/loads scheduler._step_count so OneCycleLR resumes correctly
+ def train_model(model, train_loader, val_loader, epochs=EPOCHS, class_weights=None,
+                 initial_best_f1=0.0, start_epoch=1, model_path=None, checkpoint_path=None):
+     model.to(DEVICE)
+
+     # Use AdamW with weight decay for better generalization
+     optim = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
+
+     # Learning rate scheduler with warmup (must be initialized BEFORE loading state_dict)
+     scheduler = torch.optim.lr_scheduler.OneCycleLR(
+         optim, max_lr=LR, epochs=epochs, steps_per_epoch=len(train_loader),
+         pct_start=0.1, anneal_strategy='cos'
+     )
+
+     # --- CHECKPOINT LOADING ---
+     best_val_f1 = initial_best_f1
+     if checkpoint_path and os.path.exists(checkpoint_path):
+         checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
+
+         # NOTE: model weights were loaded in train_from_json, but we load again for safety
+         model.load_state_dict(checkpoint['model_state_dict'])
+
+         optim.load_state_dict(checkpoint['optimizer_state_dict'])
+         scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+
+         # Explicitly restore the scheduler's internal step count
+         # (the scheduler's state_dict already includes it, so this is a safety net)
+         if '_step_count' in checkpoint:
+             scheduler._step_count = checkpoint['_step_count']
+
+         best_val_f1 = checkpoint['best_val_f1']
+         start_epoch = checkpoint['epoch'] + 1
+
+         print(f":floppy_disk: Resuming training from Epoch {start_epoch} with F1: {best_val_f1:.4f}")
+     # --- END CHECKPOINT LOADING ---
+
+     patience = 10  # NOTE: with EPOCHS = 10, early stopping can only fire if epochs is raised
+     patience_counter = 0
+
+     for ep in range(start_epoch, epochs + 1):
+         model.train()
+         running_loss = 0.0
+         for batch in tqdm(train_loader, desc=f"Train E{ep}"):
+             optim.zero_grad()
+
+             words = batch["words"].to(DEVICE)
+             chars = batch["chars"].to(DEVICE)
+             bboxes = batch["bboxes"].to(DEVICE)
+             # pattern_feats removed
+             spatial_feats = batch["spatial_features"].to(DEVICE)
+             context_feats = batch["context_features"].to(DEVICE)
+             mask = batch["mask"].to(DEVICE)
+             labels = batch["labels"].to(DEVICE)
+
+             # pattern_feats removed from model call
+             loss = model(words, chars, bboxes, spatial_feats, context_feats, mask, labels,
+                          class_weights=class_weights)
+             loss.backward()
+             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+             optim.step()
+             scheduler.step()  # This step is now correctly tracked across resumes
+             running_loss += loss.item()
+
+         avg_loss = running_loss / max(1, len(train_loader))
+         print(f"Epoch {ep} train loss {avg_loss:.4f} | LR: {scheduler.get_last_lr()[0]:.6f}")
+
+         # Evaluate on validation set
+         p, r, f1 = eval_model(model, val_loader)
+         print(f"VAL p={p:.4f} r={r:.4f} f1={f1:.4f}")
+
+         if f1 > best_val_f1:
+             best_val_f1 = f1
+             patience_counter = 0
+
+             # Save the BEST MODEL (just the state_dict, for deployment)
+             torch.save(model.state_dict(), model_path)
+
+             # Save the FULL CHECKPOINT for resuming training (includes _step_count)
+             torch.save({
+                 'epoch': ep,
+                 'model_state_dict': model.state_dict(),
+                 'optimizer_state_dict': optim.state_dict(),
+                 'scheduler_state_dict': scheduler.state_dict(),
+                 'best_val_f1': best_val_f1,
+                 '_step_count': scheduler._step_count  # Saved explicitly for the resume path above
+             }, checkpoint_path)
+
+             print(f":white_check_mark: New best model and checkpoint saved! F1: {best_val_f1:.4f}")
+         else:
+             patience_counter += 1
+             if patience_counter >= patience:
+                 print(f"Early stopping triggered after {ep} epochs")
+                 break
+
+     print("Training complete. Best val F1:", best_val_f1)
+
+
+ # ========== Helpers ==========
+ def build_vocabs(train_pages_tokens):
+     word_vocab = Vocab(min_freq=1)
+     char_vocab = Vocab(min_freq=1, unk_token="<CUNK>", pad_token="<CPAD>")
+
+     for p in train_pages_tokens:
+         for tok in p["tokens"]:
+             text_value = tok["text"]
+             word_vocab.add_sentence([text_value])
+             char_vocab.add_sentence(list(text_value[:MAX_CHAR_LEN]))
+
+     word_vocab.build()
+     char_vocab.build()
+
+     if len(word_vocab) <= 2:
+         raise ValueError(f"FATAL: Word vocabulary size is only {len(word_vocab)}.")
+
+     return word_vocab, char_vocab
+
+
+ def save_vocabs(path, word_vocab, char_vocab):
+     with open(path, "wb") as f:
+         pickle.dump((word_vocab, char_vocab), f)
+
+
+ def convert_labels_to_indices(all_labels):
+     return [[LABEL2IDX[l] for l in page] for page in all_labels]
+
+
+ # MODIFIED: train_from_json handles checkpoint loading setup
+ def train_from_json(unified_json_path: str):
+     print(":fire: Loading unified layout-aware labeled data...")
+     all_pages_tokens, all_labels = load_unified_data(unified_json_path)
+
+     if not all_labels:
+         raise RuntimeError(":x: No labeled data found. Please check your unified JSON file.")
+
+     print(f":bar_chart: Total dataset size: {len(all_labels)} samples (chunks)")
+
+     # Data splitting (sequential 80/20 split, no shuffling)
+     split_idx = int(len(all_pages_tokens) * 0.8)
+     train_pages_tokens = all_pages_tokens[:split_idx]
+     train_labels = all_labels[:split_idx]
+     val_pages_tokens = all_pages_tokens[split_idx:]
+     val_labels = all_labels[split_idx:]
+
+     print(f":white_check_mark: Training on {len(train_labels)} samples, validating on {len(val_labels)} samples")
+
+     # Class weights calculation (computed over the full dataset, train + val)
+     all_labels_indices = convert_labels_to_indices(all_labels)
+     class_weights = compute_class_weights(all_labels_indices, len(LABELS)).to(DEVICE)
+     print(":1234: Class weights:", class_weights)
+
+     # Vocab building
+     vocab_path = os.path.join(DATA_DIR, VOCAB_FILE)
+     word_vocab, char_vocab = build_vocabs(train_pages_tokens)
+     print(f"DEBUG: Final word vocab size: {len(word_vocab)}")
+     save_vocabs(vocab_path, word_vocab, char_vocab)
+
+     # Dataloaders
+     dataset_train = MCQTokenDataset(train_pages_tokens, word_vocab, char_vocab, labels_per_token=train_labels)
+     dataset_val = MCQTokenDataset(val_pages_tokens, word_vocab, char_vocab, labels_per_token=val_labels)
+     train_loader = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
+     val_loader = DataLoader(dataset_val, batch_size=BATCH_SIZE, collate_fn=collate_batch)
+
+     # Initialize model
+     model = MCQTagger(len(word_vocab), len(char_vocab), len(LABELS))
+
+     # --- CHECKPOINT SETUP ---
+     model_path = os.path.join(DATA_DIR, MODEL_FILE)
+     checkpoint_path = os.path.join(DATA_DIR, CHECKPOINT_FILE)
+     initial_best_f1 = 0.0
+     start_epoch = 1
+
+     # Load only model weights if a checkpoint exists (to initialize the model before passing to train)
+     if os.path.exists(checkpoint_path):
+         try:
+             checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
+             model.load_state_dict(checkpoint['model_state_dict'])
+             initial_best_f1 = checkpoint['best_val_f1']
+             start_epoch = checkpoint['epoch'] + 1
+             print(f":floppy_disk: Full checkpoint found. Model weights loaded. Resuming setup from Epoch {start_epoch}.")
+         except Exception as e:
+             print(f":warning: Could not load full checkpoint: {e}. Starting from scratch.")
+
+     elif os.path.exists(model_path):
+         # Fallback: load only model weights from the best-F1 model file if there is no full checkpoint
+         print(f":floppy_disk: Found best model weights at {model_path}. Loading weights...")
+         try:
+             model.load_state_dict(torch.load(model_path, map_location=DEVICE))
+         except RuntimeError as e:
+             print(f":warning: Could not load model state: {e}. Starting fresh.")
+
+     else:
+         print(":rocket: Starting training from scratch (no model or checkpoint found)...")
+     # --- END CHECKPOINT SETUP ---
+
+     print(f":triangular_ruler: Model parameters: {sum(p.numel() for p in model.parameters()):,}")
+
+     # Pass paths, initial F1, and start epoch to train_model
+     train_model(model, train_loader, val_loader, epochs=EPOCHS, class_weights=class_weights,
+                 initial_best_f1=initial_best_f1, start_epoch=start_epoch,
+                 model_path=model_path, checkpoint_path=checkpoint_path)
+
+     print("\n:white_check_mark: Training complete.")
+     print(f":package: Best model weights saved to: {model_path}")
+     print(f":package: Vocabularies saved to: {vocab_path}")
+
+
+ # ========== MAIN EXECUTION BLOCK ==========
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(
+         description="Train an enhanced BiLSTM-CRF model with deep layout understanding for MCQ structure extraction.")
+     parser.add_argument(
+         "unified_json_path",
+         type=str,
+         help="Path to the unified JSON file containing token, bbox, and label data."
+     )
+     args = parser.parse_args()
+
+     train_from_json(args.unified_json_path)
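+ # Example invocation (hypothetical path):
+ #   python train_model.py output_data/unified_labeled.json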