deepkansara-123 committed on
Commit 6d7038d · verified · 1 Parent(s): 4821abe

Upload charcnn_bylstm.py

Files changed (1): charcnn_bylstm.py (+730, -0)
charcnn_bylstm.py ADDED
# mcq_extractor_updated.py
import os
import re
import io
import json
import math
import pickle
from collections import Counter, defaultdict
from typing import List, Tuple

import fitz  # PyMuPDF
import pytesseract
from PIL import Image
import numpy as np
from tqdm import tqdm
from sklearn.metrics import precision_recall_fscore_support

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
# same CRF as before; the batch_first/decode API used below is pytorch-crf's
from torchcrf import CRF  # pip install pytorch-crf

# ========== CONFIG ==========
DATA_DIR = "output_data"
IMAGES_DIR = os.path.join(DATA_DIR, "images")
os.makedirs(IMAGES_DIR, exist_ok=True)
PAGE_OCR_CHAR_THRESHOLD = 50  # min embedded chars per page before OCR fallback
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_CHAR_LEN = 16
EMBED_DIM = 100
CHAR_EMBED_DIM = 30
CHAR_CNN_OUT = 30
HIDDEN_SIZE = 256
BATCH_SIZE = 8
EPOCHS = 50
LR = 1e-3

pytesseract.pytesseract.tesseract_cmd = r"D:\prince\New folder\tesseract.exe"

# ========== LABELS (single source of truth) ==========
LABELS = ["O", "B-QUESTION", "I-QUESTION", "B-OPTION", "I-OPTION",
          "B-ANSWER", "I-ANSWER", "B-IMAGE", "I-IMAGE"]
LABEL2IDX = {l: i for i, l in enumerate(LABELS)}
IDX2LABEL = {i: l for l, i in LABEL2IDX.items()}

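# A quick illustration of the BIO scheme (an editorial example, not from the
# training data): "Q1. What is H2O? (A) Water" tokenizes to
#   ["Q1", ".", "What", "is", "H2O", "?", "(", "A", ")", "Water"]
# and would be tagged
#   ["B-QUESTION", "I-QUESTION", "I-QUESTION", "I-QUESTION", "I-QUESTION",
#    "I-QUESTION", "B-OPTION", "I-OPTION", "I-OPTION", "I-OPTION"]
# i.e. each span opens with a B- tag and continues with I- tags of the same type.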

# ---------- small utility classes ----------
class Vocab:
    def __init__(self, min_freq=1, unk_token="<UNK>", pad_token="<PAD>"):
        self.min_freq = min_freq
        self.unk_token = unk_token
        self.pad_token = pad_token
        self.freq = Counter()
        self.itos = []
        self.stoi = {}

    def add_sentence(self, toks):
        self.freq.update(toks)

    def build(self):
        items = [tok for tok, c in self.freq.items() if c >= self.min_freq]
        items = [self.pad_token, self.unk_token] + sorted(items)
        self.itos = items
        self.stoi = {s: i for i, s in enumerate(self.itos)}

    def __len__(self):
        return len(self.itos)

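# Usage sketch: feed every token, then build() to freeze the index. The pad
# token always maps to 0 (which is why the embedding layers use padding_idx=0)
# and the unk token to 1:
#   v = Vocab()
#   v.add_sentence(["What", "is", "What"])
#   v.build()
#   v.stoi  # {"<PAD>": 0, "<UNK>": 1, "What": 2, "is": 3}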

# ========== PDF / tokenization utils (keep yours, slightly cleaned) ==========
def clean_text_token(t):
    """Normalizes special characters in a token."""
    return t.replace("\u2011", "-")  # non-breaking hyphen -> plain hyphen


def extract_pdf_pages(path: str):
    """
    Extracts content from PDF pages.
    Returns a list of pages with:
      - 'width', 'height' -> page dimensions
      - 'blocks' -> text blocks with bbox
      - 'images' -> images with bbox and PIL image
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"The file was not found: {path}")

    doc = fitz.open(path)
    pages = []

    for pno, page in enumerate(doc):
        w, h = page.rect.width, page.rect.height

        # Extract text blocks
        raw_blocks = page.get_text("blocks", sort=True)
        text_blocks = []
        for b in raw_blocks:
            x0, y0, x1, y1, text, block_no, block_type = b
            if block_type != 0:  # 0 = text block
                continue
            text = text.strip().replace("\n", " ")
            if text:
                text_blocks.append({
                    "bbox": (x0, y0, x1, y1),
                    "text": text,
                    "font_size": None  # can optionally extract from spans if needed
                })

        # Extract images
        images = []
        for img_info in page.get_images(full=True):
            xref = img_info[0]
            try:
                base_image = doc.extract_image(xref)
                img_bytes = base_image["image"]
                img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                img_rect = page.get_image_bbox(img_info)
                images.append({"bbox": (img_rect.x0, img_rect.y0, img_rect.x1, img_rect.y1), "image": img})
            except Exception as e:
                print(f"Warning: Could not extract image {xref} on page {pno+1}. Error: {e}")

        # OCR fallback if the page yields too little embedded text
        total_chars = sum(len(b["text"]) for b in text_blocks)
        if total_chars < PAGE_OCR_CHAR_THRESHOLD:
            pix = page.get_pixmap(dpi=300)
            img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
            ocr_text = pytesseract.image_to_string(img)
            if ocr_text.strip():
                text_blocks = [{"bbox": (0, 0, w, h), "text": ocr_text.strip(), "font_size": None}]

        pages.append({"width": w, "height": h, "blocks": text_blocks, "images": images})

    doc.close()
    return pages


def split_blocks_into_tokens(pages):
    """
    Tokenizes text blocks and sorts them based on page layout (single or two-column).
    Returns a list of pages, each containing a list of token dicts.
    """
    token_re = re.compile(r"\w+|[^\w\s]", re.UNICODE)
    all_pages_tokens = []

    for pidx, page in enumerate(pages):
        tokens = []
        page_w, page_h = page["width"], page["height"]
        mid_x = page_w / 2

        # Detect whether the page is two-column
        left_count, right_count, spanning_count = 0, 0, 0
        gutter = 0.1 * page_w
        for b in page["blocks"]:
            x0, y0, x1, y1 = b["bbox"]
            if y0 < 0.05 * page_h or y1 > 0.95 * page_h:  # ignore headers/footers
                continue
            if x1 < mid_x - gutter / 2:
                left_count += 1
            elif x0 > mid_x + gutter / 2:
                right_count += 1
            elif x0 < mid_x and x1 > mid_x:
                spanning_count += 1
        is_two_column = left_count > 3 and right_count > 3 and spanning_count <= 2

        # Tokenize blocks; token bboxes are approximated by splitting the block
        # width proportionally to each token's character count
        for bidx, block in enumerate(page["blocks"]):
            x0, y0, x1, y1 = block["bbox"]
            text = block["text"].replace("\u00ad", "")  # drop soft hyphens
            toks = token_re.findall(text)
            if not toks:
                continue
            total_chars = sum(len(t) for t in toks)
            cur_x = x0
            for tok in toks:
                tok_width = (len(tok) / total_chars) * (x1 - x0) if total_chars > 0 else (x1 - x0) / len(toks)
                tokens.append({
                    "text": clean_text_token(tok),
                    "x0": cur_x, "y0": y0,
                    "x1": cur_x + tok_width, "y1": y1,
                    "font_size": block.get("font_size"),
                    "page_no": pidx + 1,
                    "block_idx": bidx
                })
                cur_x += tok_width

        # Sort tokens based on layout: column first for two-column pages,
        # then top-to-bottom, left-to-right
        if is_two_column:
            tokens.sort(key=lambda t: (0 if t['x0'] < mid_x else 1, t['y0'], t['x0']))
        else:
            tokens.sort(key=lambda t: (t['y0'], t['x0']))

        all_pages_tokens.append(tokens)
    return all_pages_tokens

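# Worked example of the column heuristic above (illustrative numbers): on a
# 600 pt wide page, mid_x = 300 and gutter = 60, so a block ending left of 270
# counts as "left", one starting right of 330 as "right", and one crossing 300
# as "spanning". A page with 10 left, 9 right and 1 spanning block is treated
# as two-column, and reading order becomes left column first, then right.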

def assign_images_to_tokens(pages, all_pages_tokens):
    """
    Inserts image placeholders into the token stream.
    """
    if not os.path.exists(IMAGES_DIR):
        os.makedirs(IMAGES_DIR)

    for pidx, page in enumerate(pages):
        tokens = all_pages_tokens[pidx]
        for img_idx, imrec in enumerate(page["images"]):
            img_name = f"page{pidx+1}_img{img_idx+1}.png"
            imrec["image"].save(os.path.join(IMAGES_DIR, img_name))
            img_center_y = (imrec["bbox"][1] + imrec["bbox"][3]) / 2
            if not tokens:
                insert_idx = 0
            else:
                # place the image marker right after the vertically closest token
                closest_token = min(tokens, key=lambda t: abs((t["y0"] + t["y1"]) / 2 - img_center_y))
                insert_idx = tokens.index(closest_token) + 1
            tokens.insert(insert_idx, {
                "text": f"[IMAGE: {img_name}]",
                "x0": imrec["bbox"][0], "y0": imrec["bbox"][1],
                "x1": imrec["bbox"][2], "y1": imrec["bbox"][3],
                "font_size": None,
                "page_no": pidx + 1,
                "block_idx": -1,
                "is_image": True
            })
        all_pages_tokens[pidx] = tokens
    return all_pages_tokens


# ========== Dataset ==========
def orthographic_features(token_text):
    # [starts-with-capital, all-caps, contains-digit, single-punctuation-char]
    return [
        int(token_text[0].isupper()) if token_text and token_text[0].isalpha() else 0,
        int(token_text.isupper()),
        int(any(ch.isdigit() for ch in token_text)),
        int(len(token_text) == 1 and re.match(r'\W', token_text) is not None)
    ]
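# For example, orthographic_features("Q1") == [1, 1, 1, 0] (capitalized,
# all-caps, contains a digit) and orthographic_features("(") == [0, 0, 0, 1]
# (a single punctuation character).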

class MCQTokenDataset(Dataset):
    def __init__(self, pages_tokens, word_vocab, char_vocab, labels_per_token=None):
        self.samples = []
        self.labels = []

        if labels_per_token:
            for toks, lbls in zip(pages_tokens, labels_per_token):
                if len(toks) == 0:
                    continue  # skip empty pages
                if len(toks) != len(lbls):
                    raise ValueError(f"Token/label length mismatch: {len(toks)} vs {len(lbls)}")
                self.samples.append(toks)
                self.labels.append(lbls)
        else:
            self.samples = [p for p in pages_tokens if len(p) > 0]
        self.word_vocab = word_vocab
        self.char_vocab = char_vocab

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        toks = self.samples[idx]

        # ✅ Make sure every token has text
        words = []
        safe_toks = []
        for t in toks:
            if isinstance(t, dict) and "text" in t:
                txt = t["text"]
                safe_toks.append(t)
            elif isinstance(t, str):
                txt = t
                safe_toks.append({"text": txt, "x0": 0, "x1": 0, "y0": 0, "y1": 0, "font_size": 0.0})
            else:
                txt = str(t)
                safe_toks.append({"text": txt, "x0": 0, "x1": 0, "y0": 0, "y1": 0, "font_size": 0.0})
            words.append(txt)

        toks = safe_toks  # use normalized tokens downstream

        word_ids = [self.word_vocab.stoi.get(w, self.word_vocab.stoi[self.word_vocab.unk_token]) for w in words]

        # char ids, truncated/padded to MAX_CHAR_LEN per word
        char_ids = []
        for w in words:
            chs = [self.char_vocab.stoi.get(ch, self.char_vocab.stoi[self.char_vocab.unk_token])
                   for ch in w[:MAX_CHAR_LEN]]
            if len(chs) < MAX_CHAR_LEN:
                chs += [self.char_vocab.stoi[self.char_vocab.pad_token]] * (MAX_CHAR_LEN - len(chs))
            char_ids.append(chs)

        # normalized token-center coordinates
        x_centers = [(t["x0"] + t["x1"]) / 2.0 for t in toks]
        y_centers = [(t["y0"] + t["y1"]) / 2.0 for t in toks]
        max_x = max([t["x1"] for t in toks]) if toks else 1.0
        max_y = max([t["y1"] for t in toks]) if toks else 1.0

        if max_x == 0:
            max_x = 1.0
        if max_y == 0:
            max_y = 1.0

        x_norm = [xc / max_x for xc in x_centers]
        y_norm = [yc / max_y for yc in y_centers]

        font_sizes = [float(t.get("font_size") or 0.0) for t in toks]
        ortho_feats = [orthographic_features(w) for w in words]

        labels = None
        if self.labels:
            lbls = self.labels[idx]
            labels = [LABEL2IDX[l] for l in lbls]

        return {
            "word_ids": torch.LongTensor(word_ids),
            "char_ids": torch.LongTensor(char_ids),
            "x_norm": torch.FloatTensor(x_norm),
            "y_norm": torch.FloatTensor(y_norm),
            "font_sizes": torch.FloatTensor(font_sizes),
            "ortho": torch.FloatTensor(ortho_feats),
            "labels": torch.LongTensor(labels) if labels is not None else None,
            "tokens": toks
        }


def collate_batch(batch):
    batch = [item for item in batch if item["word_ids"].size(0) > 0]  # drop empty sequences
    if len(batch) == 0:
        return None

    max_len = max(item["word_ids"].size(0) for item in batch)
    batch_size = len(batch)

    word_pad = torch.zeros((batch_size, max_len), dtype=torch.long)
    char_pad = torch.zeros((batch_size, max_len, MAX_CHAR_LEN), dtype=torch.long)
    x_pad = torch.zeros((batch_size, max_len), dtype=torch.float)
    y_pad = torch.zeros((batch_size, max_len), dtype=torch.float)
    font_pad = torch.zeros((batch_size, max_len), dtype=torch.float)
    ortho_pad = torch.zeros((batch_size, max_len, 4), dtype=torch.float)
    mask = torch.zeros((batch_size, max_len), dtype=torch.bool)
    label_pad = torch.full((batch_size, max_len), LABEL2IDX["O"], dtype=torch.long)  # "O" as padding label
    tokens_list = []

    for i, item in enumerate(batch):
        L = item["word_ids"].size(0)
        word_pad[i, :L] = item["word_ids"]
        char_pad[i, :L, :] = item["char_ids"]
        x_pad[i, :L] = item["x_norm"]
        y_pad[i, :L] = item["y_norm"]
        font_pad[i, :L] = item["font_sizes"]
        ortho_pad[i, :L, :] = item["ortho"]
        mask[i, :L] = 1
        if item["labels"] is not None and item["labels"].size(0) == L:
            label_pad[i, :L] = item["labels"]
        tokens_list.append(item["tokens"])

    return {
        "words": word_pad,
        "chars": char_pad,
        "x": x_pad,
        "y": y_pad,
        "font": font_pad,
        "ortho": ortho_pad,
        "mask": mask,
        "labels": label_pad,
        "tokens": tokens_list
    }

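# E.g. collating two pages of 5 and 3 tokens yields (2, 5) tensors with
# mask == [[1,1,1,1,1], [1,1,1,0,0]], so the CRF and metrics skip the padding.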

# ========== MODEL ==========
class CharCNNEncoder(nn.Module):
    def __init__(self, char_vocab_size, char_emb_dim, out_dim, kernel_sizes=(3, 4, 5)):
        super().__init__()
        self.char_emb = nn.Embedding(char_vocab_size, char_emb_dim, padding_idx=0)
        convs = []
        for k in kernel_sizes:
            convs.append(nn.Conv1d(char_emb_dim, out_dim, kernel_size=k))
        self.convs = nn.ModuleList(convs)
        self.out_dim = out_dim * len(convs)

    def forward(self, char_ids):
        # char_ids: (B, L, C) -> one conv per kernel size, max-pooled over chars
        B, L, C = char_ids.size()
        emb = self.char_emb(char_ids.view(B * L, C))  # (B*L, C, emb)
        emb = emb.transpose(1, 2)                     # (B*L, emb, C) for Conv1d
        outs = []
        for conv in self.convs:
            c = conv(emb)
            c = torch.relu(c)
            c = torch.max(c, dim=2)[0]  # max over character positions
            outs.append(c)
        res = torch.cat(outs, dim=1)
        return res.view(B, L, -1)

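# Shape check (a standalone sketch): with char_emb_dim=30, out_dim=30 and
# kernels (3, 4, 5), each conv max-pools to 30 features, so out_dim is 3*30=90:
#   enc = CharCNNEncoder(char_vocab_size=100, char_emb_dim=30, out_dim=30)
#   out = enc(torch.randint(0, 100, (2, 7, MAX_CHAR_LEN)))
#   out.shape  # torch.Size([2, 7, 90])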

class MCQTagger(nn.Module):
    def __init__(self, vocab_size, char_vocab_size, n_labels):
        super().__init__()
        self.word_emb = nn.Embedding(vocab_size, EMBED_DIM, padding_idx=0)
        self.char_enc = CharCNNEncoder(char_vocab_size, CHAR_EMBED_DIM, CHAR_CNN_OUT)
        # word emb + char-CNN features + (x, y) + font size + 4 orthographic flags
        in_dim = EMBED_DIM + self.char_enc.out_dim + 2 + 1 + 4
        self.bilstm = nn.LSTM(in_dim, HIDDEN_SIZE // 2, num_layers=1, batch_first=True, bidirectional=True)
        self.ff = nn.Linear(HIDDEN_SIZE, n_labels)
        self.crf = CRF(n_labels, batch_first=True)

    def forward_emissions(self, words, chars, x, y, font, ortho, mask):
        # return raw emissions (before CRF) so we can obtain per-token probs
        wemb = self.word_emb(words)
        cenc = self.char_enc(chars)
        numeric = torch.cat([x.unsqueeze(-1), y.unsqueeze(-1), font.unsqueeze(-1), ortho], dim=-1)
        enc_in = torch.cat([wemb, cenc, numeric], dim=-1)
        lstm_out, _ = self.bilstm(enc_in)
        emissions = self.ff(lstm_out)
        return emissions

    def forward(self, words, chars, x, y, font, ortho, mask, labels=None, class_weights=None, alpha=0.7):
        emissions = self.forward_emissions(words, chars, x, y, font, ortho, mask)
        if labels is not None:
            # mix the CRF negative log-likelihood with an (optionally
            # class-weighted) token-level cross-entropy term
            crf_loss = -self.crf(emissions, labels, mask=mask, reduction='mean')
            if class_weights is not None:
                ce_loss_fn = nn.CrossEntropyLoss(weight=class_weights.to(emissions.device), ignore_index=-1)
                ce_loss = ce_loss_fn(emissions.view(-1, emissions.size(-1)), labels.view(-1))
                loss = alpha * crf_loss + (1 - alpha) * ce_loss
            else:
                loss = crf_loss
            return loss
        else:
            return self.crf.decode(emissions, mask=mask)


# helper: get softmax probs per token from emissions
def emissions_to_probs(emissions, mask):
    # emissions: (B, L, C)
    probs = F.softmax(emissions, dim=-1).cpu().numpy()  # (B, L, C)
    masks = mask.cpu().numpy()
    # return one (L_i, C) array per example, keeping only active tokens
    out = []
    for i in range(probs.shape[0]):
        L = masks[i].sum()
        out.append(probs[i][:L])
    return out


# ========== training/eval ==========
def compute_class_weights(labels_list, num_labels):
    # inverse-frequency weights: w_i = total / (num_labels * count_i)
    all_labels_flat = [lbl for page in labels_list for lbl in page]
    counts = Counter(all_labels_flat)
    total = sum(counts.values())
    weights = []
    for i in range(num_labels):
        count = counts.get(i, 0)
        if count == 0:
            w = 1.0
        else:
            w = total / (num_labels * count)
        if IDX2LABEL[i] in ["B-QUESTION", "B-OPTION"]:
            w *= 2.0  # extra boost for the span-opening tags we care most about
        weights.append(w)
    return torch.tensor(weights, dtype=torch.float)

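# Quick arithmetic check (made-up counts): with 9 labels and 100 tokens where
# count(O) = 90 and count(B-QUESTION) = 5,
#   w(O)          = 100 / (9 * 90)       ~ 0.12
#   w(B-QUESTION) = 100 / (9 * 5) * 2.0  ~ 4.44
# so rare span-opening tags dominate the weighted CE term.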

def eval_model(model, data_loader):
    model.eval()
    all_true = []
    all_pred = []
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Eval"):
            words = batch["words"].to(DEVICE)
            chars = batch["chars"].to(DEVICE)
            x = batch["x"].to(DEVICE)
            y = batch["y"].to(DEVICE)
            font = batch["font"].to(DEVICE)
            ortho = batch["ortho"].to(DEVICE)
            mask = batch["mask"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)
            preds = model(words, chars, x, y, font, ortho, mask, labels=None)
            for i in range(len(preds)):
                L = mask[i].sum().item()
                pred_seq = preds[i][:L]
                true_seq = labels[i][:L].cpu().numpy().tolist()
                all_pred.extend(pred_seq)
                all_true.extend(true_seq)
    # token-level micro precision/recall/F1 across all labels (including O)
    p, r, f1, _ = precision_recall_fscore_support(all_true, all_pred, average='micro', zero_division=0)
    return p, r, f1

def train_model(model, train_loader, val_loader, epochs=EPOCHS, class_weights=None):
    model.to(DEVICE)
    optim = torch.optim.Adam(model.parameters(), lr=LR)
    best_val_f1 = 0.0
    for ep in range(1, epochs + 1):
        model.train()
        running_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Train E{ep}"):
            optim.zero_grad()
            words = batch["words"].to(DEVICE)
            chars = batch["chars"].to(DEVICE)
            x = batch["x"].to(DEVICE)
            y = batch["y"].to(DEVICE)
            font = batch["font"].to(DEVICE)
            ortho = batch["ortho"].to(DEVICE)
            mask = batch["mask"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)
            loss = model(words, chars, x, y, font, ortho, mask, labels, class_weights=class_weights)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optim.step()
            running_loss += loss.item()
        avg_loss = running_loss / max(1, len(train_loader))
        print(f"Epoch {ep} train loss {avg_loss:.4f}")
        p, r, f1 = eval_model(model, val_loader)
        print(f"VAL p={p:.4f} r={r:.4f} f1={f1:.4f}")
        if f1 > best_val_f1:
            best_val_f1 = f1
            torch.save(model.state_dict(), os.path.join(DATA_DIR, "best_mcq_tagger.pt"))
    print("Training complete. Best val F1:", best_val_f1)
    return model


# ========== helpers to save/load vocabs ==========
def build_vocabs(pages_tokens):
    word_vocab = Vocab(min_freq=1)
    char_vocab = Vocab(min_freq=1, unk_token="<CUNK>", pad_token="<CPAD>")

    for p in pages_tokens:
        for tok in p:
            # ✅ Always convert to string safely
            if isinstance(tok, dict) and "text" in tok:
                text_value = tok["text"]
            elif isinstance(tok, str):
                text_value = tok
            else:
                text_value = str(tok)

            word_vocab.add_sentence([text_value])
            for ch in text_value[:MAX_CHAR_LEN]:
                char_vocab.add_sentence([ch])

    word_vocab.build()
    char_vocab.build()
    return word_vocab, char_vocab

def save_vocabs(path, word_vocab, char_vocab):
    with open(path, "wb") as f:
        pickle.dump((word_vocab, char_vocab), f)

def load_vocabs(path):
    with open(path, "rb") as f:
        return pickle.load(f)


# ========== reconstruction (unchanged) ==========
def reconstruct_mcqs_from_tokens(tokens, preds):
    mcqs = []
    i = 0
    N = len(tokens)
    fragments = []
    while i < N:
        label = IDX2LABEL[preds[i]]
        if label.startswith("B-QUESTION"):
            # a new question starts; archive the previous fragment first
            if fragments and "question" in fragments[-1]:
                mcqs.append(fragments[-1])
            q_toks = [tokens[i]["text"]]
            i += 1
            while i < N and IDX2LABEL[preds[i]].startswith("I-QUESTION"):
                q_toks.append(tokens[i]["text"])
                i += 1
            fragments.append({"question": " ".join(q_toks), "options": [], "answer": None})
        elif fragments:
            lab = IDX2LABEL[preds[i]]
            if lab.startswith("B-OPTION"):
                otoks = [tokens[i]["text"]]
                i += 1
                while i < N and IDX2LABEL[preds[i]].startswith("I-OPTION"):
                    otoks.append(tokens[i]["text"])
                    i += 1
                fragments[-1]["options"].append(" ".join(otoks))
            elif lab.startswith("B-ANSWER"):
                atoks = [tokens[i]["text"]]
                i += 1
                while i < N and IDX2LABEL[preds[i]].startswith("I-ANSWER"):
                    atoks.append(tokens[i]["text"])
                    i += 1
                fragments[-1]["answer"] = " ".join(atoks)
            else:
                i += 1
        else:
            i += 1

    if fragments and "question" in fragments[-1]:
        mcqs.append(fragments[-1])

    # ✅ keep only "complete" MCQs: must have a question and at least one option
    mcqs = [m for m in mcqs if m.get("question") and m.get("options")]

    return mcqs

def convert_labels_to_indices(all_labels):
    return [[LABEL2IDX[l] for l in page] for page in all_labels]
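# Decoding walk-through (hand-made prediction, for illustration): given
#   tokens = [{"text": t} for t in ["What", "is", "2+2", "?", "(", "A", ")", "4"]]
#   preds  = [LABEL2IDX[l] for l in
#             ["B-QUESTION", "I-QUESTION", "I-QUESTION", "I-QUESTION",
#              "B-OPTION", "I-OPTION", "I-OPTION", "I-OPTION"]]
# reconstruct_mcqs_from_tokens(tokens, preds) returns
#   [{"question": "What is 2+2 ?", "options": ["( A ) 4"], "answer": None}]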

def demo_inference(pdf_path, model_path, vocab_path):
    # Load vocabs
    word_vocab, char_vocab = load_vocabs(vocab_path)

    # Load model
    model = MCQTagger(len(word_vocab), len(char_vocab), n_labels=len(LABELS))
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.to(DEVICE)
    model.eval()

    # Extract + tokenize PDF
    pages = extract_pdf_pages(pdf_path)
    pages_tokens = split_blocks_into_tokens(pages)
    pages_tokens = assign_images_to_tokens(pages, pages_tokens)

    # Dataset + loader (one page per batch)
    dataset = MCQTokenDataset(pages_tokens, word_vocab, char_vocab, labels_per_token=None)
    loader = DataLoader(dataset, batch_size=1, collate_fn=collate_batch)

    all_mcqs = []
    all_preds = []
    with torch.no_grad():
        for batch in loader:
            words = batch["words"].to(DEVICE)
            chars = batch["chars"].to(DEVICE)
            x = batch["x"].to(DEVICE)
            y = batch["y"].to(DEVICE)
            font = batch["font"].to(DEVICE)
            ortho = batch["ortho"].to(DEVICE)
            mask = batch["mask"].to(DEVICE)
            tokens = batch["tokens"][0]

            preds = model(words, chars, x, y, font, ortho, mask, labels=None)
            preds = preds[0]  # batch size = 1
            all_preds.append(preds)

            mcqs = reconstruct_mcqs_from_tokens(tokens, preds)
            all_mcqs.extend(mcqs)

    # Save to JSON (optional)
    out_path = os.path.join(DATA_DIR, f"cnn_{os.path.basename(pdf_path)}.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(all_mcqs, f, ensure_ascii=False, indent=2)

    print(f"✅ Results saved to {out_path}")
    return all_mcqs, all_preds

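# Typical call after training ("sample.pdf" is a hypothetical input; the model
# and vocab paths are the ones written by the training run below):
#   mcqs, preds = demo_inference("sample.pdf",
#                                os.path.join(DATA_DIR, "best_mcq_tagger.pt"),
#                                os.path.join(DATA_DIR, "vocabs.pkl"))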

# if run as script, keep legacy demo functions etc. (omitted for brevity)
if __name__ == "__main__":

    # Alternative augmented training sets (currently unused):
    # ("augmented_data/english_21-50_tokens_augmented_1.json", "augmented_data/english_21-50_labels_augmented_1.json"),
    # ("augmented_data/english_21-50_tokens_augmented_2.json", "augmented_data/english_21-50_labels_augmented_2.json"),
    # ("augmented_data/english_21-50_tokens_augmented_3.json", "augmented_data/english_21-50_labels_augmented_3.json"),
    # ("augmented_data/english_21-50_tokens_augmented_4.json", "augmented_data/english_21-50_labels_augmented_4.json"),
    # ("augmented_data/english_21-50_tokens_augmented_5.json", "augmented_data/english_21-50_labels_augmented_5.json")
    with open("merged_tokens_labels.json", "r", encoding="utf-8") as f:
        merged_data = json.load(f)
    all_pages_tokens = []
    all_labels = []

    # group tokens by "page_no" so each page becomes one training sequence
    from itertools import groupby

    merged_data.sort(key=lambda x: x.get("page_no", 0))
    for page_no, group in groupby(merged_data, key=lambda x: x.get("page_no", 0)):
        group = list(group)
        tokens = []
        labels = []
        for item in group:
            tokens.append({
                "text": item.get("text", ""),
                "x0": item.get("x0", 0),
                "y0": item.get("y0", 0),
                "x1": item.get("x1", 0),
                "y1": item.get("y1", 0),
                "font_size": item.get("font_size", 0),
                "page_no": item.get("page_no", 0),
                "block_idx": item.get("block_idx", 0)
            })
            labels.append(item.get("label", "O"))
        all_pages_tokens.append(tokens)
        all_labels.append(labels)

    # 🔀 Split into training and validation
    split_idx = int(len(all_pages_tokens) * 0.8)
    train_pages_tokens = all_pages_tokens[:split_idx]
    train_labels = all_labels[:split_idx]
    val_pages_tokens = all_pages_tokens[split_idx:]
    val_labels = all_labels[split_idx:]

    print(f"Training on {len(train_labels)} pages, validating on {len(val_labels)} pages")

    # 🧮 Compute class weights
    all_labels_indices = convert_labels_to_indices(all_labels)
    class_weights = compute_class_weights(all_labels_indices, len(LABELS)).to(DEVICE)
    print("Class weights:", class_weights)

    # 🏗️ Build vocabularies
    word_vocab, char_vocab = build_vocabs(train_pages_tokens)

    # 📦 Build datasets
    dataset_train = MCQTokenDataset(train_pages_tokens, word_vocab, char_vocab, labels_per_token=train_labels)
    dataset_val = MCQTokenDataset(val_pages_tokens, word_vocab, char_vocab, labels_per_token=val_labels)

    # 🔄 Data loaders
    train_loader = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    val_loader = DataLoader(dataset_val, batch_size=BATCH_SIZE, collate_fn=collate_batch)

    # 🧠 Train model
    model = MCQTagger(len(word_vocab), len(char_vocab), len(LABELS))
    train_model(model, train_loader, val_loader, epochs=EPOCHS, class_weights=class_weights)

    # 💾 Save vocabs for later inference
    os.makedirs(DATA_DIR, exist_ok=True)
    save_vocabs(os.path.join(DATA_DIR, "vocabs.pkl"), word_vocab, char_vocab)

    # Debug: check that rare labels actually appear in a training batch
    debug_loader = DataLoader(dataset_train, batch_size=2, shuffle=True, collate_fn=collate_batch)
    for batch in debug_loader:
        labels_in_batch = batch["labels"][batch["mask"]]  # only real (non-padded) tokens
        print("Labels in batch:", torch.unique(labels_in_batch))
        break
    print("✅ Training finished. Model + vocabs saved.")