aagamjtdev commited on
Commit
5b8a1d1
Β·
1 Parent(s): 4acd43e

Fix: Full Script, structured data

Browse files
Files changed (1) hide show
  1. app.py +1195 -107
app.py CHANGED
@@ -1,3 +1,977 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import json
3
  import pickle
@@ -6,29 +980,29 @@ from collections import Counter
6
  import torch
7
  import torch.nn as nn
8
  import torch.nn.functional as F
 
9
  from tqdm import tqdm
10
 
11
  # === GRADIO AND DEPENDENCIES ===
12
  import gradio as gr
13
  import fitz # PyMuPDF
14
- import re
15
  from PIL import Image, ImageEnhance
16
  import pytesseract
17
 
18
  try:
 
19
  from TorchCRF import CRF
20
  except ImportError:
21
- # This should be handled in requirements.txt for the Space
22
- print("CRF module not found. Assuming deployment environment will install it.")
23
-
24
-
25
  class CRF:
26
- def __init__(self, *args, **kwargs): pass
 
 
 
 
27
 
28
- def viterbi_decode(self, emissions, mask): return [list(torch.argmax(emissions[0], dim=-1).cpu().numpy())]
29
 
30
  # ========== CONFIG (Must match Training Script) ==========
31
- # NOTE: In a Space, we typically don't use DATA_DIR paths if the files are alongside app.py
32
  MODEL_FILE = "model_CAT.pt"
33
  VOCAB_FILE = "vocabs_CAT.pkl"
34
 
@@ -43,24 +1017,24 @@ BBOX_NORM_CONSTANT = 1000.0
43
  INFERENCE_CHUNK_SIZE = 256
44
 
45
  # ========== LABELS (Must match Training Script) ==========
46
- LABELS = ["O", "B-QUESTION", "I-QUESTION", "B-OPTION", "I-OPTION", "B-ANSWER", "I-ANSWER", "B-IMAGE", "I-IMAGE"]
 
47
  LABEL2IDX = {l: i for i, l in enumerate(LABELS)}
48
  IDX2LABEL = {i: l for i, l in enumerate(LABELS)}
49
 
50
 
51
  # =========================================================
52
- # 1. Vocab, CharCNNEncoder, and MCQTagger Classes (Copied from your script)
53
  # =========================================================
54
 
55
  class Vocab:
56
- # ... (Your Vocab class implementation)
57
  def __init__(self, min_freq=1, unk_token="<UNK>", pad_token="<PAD>"):
58
  self.min_freq = min_freq
59
  self.unk_token = unk_token
60
  self.pad_token = pad_token
61
  self.freq = Counter()
62
- self.itos = [] # Index to String
63
- self.stoi = {} # String to Index
64
 
65
  def add_sentence(self, toks):
66
  self.freq.update(toks)
@@ -75,7 +1049,6 @@ class Vocab:
75
  return len(self.itos)
76
 
77
  def __getitem__(self, token: str) -> int:
78
- """Allows lookup using word_vocab[token]. Returns UNK index if token is not found."""
79
  return self.stoi.get(token, self.stoi[self.unk_token])
80
 
81
  def __getstate__(self):
@@ -97,18 +1070,14 @@ class Vocab:
97
 
98
 
99
  def load_vocabs(path: str) -> Tuple[Vocab, Vocab]:
100
- """Loads word and character vocabularies from a pickle file and verifies size."""
101
  try:
102
  absolute_path = os.path.abspath(path)
103
- if not os.path.exists(absolute_path):
104
- raise FileNotFoundError(f"Vocab file NOT FOUND at: {absolute_path}")
105
  with open(absolute_path, "rb") as f:
106
  word_vocab, char_vocab = pickle.load(f)
107
  if len(word_vocab) <= 2:
108
- raise IndexError("CRITICAL: Word vocabulary size is too small. Vocab file is invalid.")
109
  return word_vocab, char_vocab
110
- except FileNotFoundError:
111
- raise FileNotFoundError(f"Vocab file not found at {path}. Please run the training script first.")
112
  except Exception as e:
113
  raise RuntimeError(f"Error loading vocabs from {path}: {e}")
114
 
@@ -152,6 +1121,7 @@ class MCQTagger(nn.Module):
152
 
153
  if lengths.max().item() == 0:
154
  B, L = enc_in.size(0), enc_in.size(1)
 
155
  return torch.zeros((B, L, len(LABELS)), device=enc_in.device)
156
 
157
  packed_in = nn.utils.rnn.pack_padded_sequence(enc_in, lengths, batch_first=True, enforce_sorted=False)
@@ -162,60 +1132,42 @@ class MCQTagger(nn.Module):
162
 
163
  def forward(self, words, chars, bboxes, mask, labels=None, class_weights=None, alpha=0.7):
164
  emissions = self.forward_emissions(words, chars, bboxes, mask)
165
- # We only decode for inference, not calculate loss
166
  return self.crf.viterbi_decode(emissions, mask=mask)
167
 
168
 
169
  # =========================================================
170
- # 2. PDF Processing Functions (Copied from your script)
171
  # =========================================================
172
 
173
  def ocr_fallback_page(page: fitz.Page, page_width: float, page_height: float) -> List[Dict[str, Any]]:
174
- # ... (Your ocr_fallback_page implementation)
175
- """
176
- Renders a PyMuPDF page, runs Tesseract OCR, and tokenizes the result.
177
- """
178
  try:
179
- # Render page at high resolution (300 DPI equivalent)
180
  pix = page.get_pixmap(matrix=fitz.Matrix(3, 3))
181
- if pix.n - pix.alpha > 3: # Handle CMYK
182
  pix = fitz.Pixmap(fitz.csRGB, pix)
183
 
184
  img_pil = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
185
 
186
- # Preprocessing for Tesseract (as was in the original code)
187
  img_pil = img_pil.convert('L')
188
  img_pil = ImageEnhance.Contrast(img_pil).enhance(2.0)
189
  img_pil = ImageEnhance.Sharpness(img_pil).enhance(2.0)
190
 
191
- # Run Tesseract
192
  ocr_data = pytesseract.image_to_data(img_pil, output_type=pytesseract.Output.DICT)
193
 
194
  ocr_tokens = []
195
  for i in range(len(ocr_data['text'])):
196
  word = ocr_data['text'][i]
197
  conf = ocr_data['conf'][i]
198
- conf = ocr_data['conf'][i]
199
 
200
- # Use only words with reasonable confidence
201
  if word.strip() and int(conf) > 50:
202
- # Get Tesseract's raw pixel bounding box
203
- left = ocr_data['left'][i]
204
- top = ocr_data['top'][i]
205
- width = ocr_data['width'][i]
206
- height = ocr_data['height'][i]
207
-
208
- # Convert pixel bbox back to original PDF coordinate system
209
  scale = page_width / pix.width
210
 
211
  raw_bbox = [
212
- left * scale,
213
- top * scale,
214
- (left + width) * scale,
215
- (top + height) * scale
216
  ]
217
 
218
- # Normalize bbox
219
  normalized_bbox = [
220
  (raw_bbox[0] / page_width) * BBOX_NORM_CONSTANT,
221
  (raw_bbox[1] / page_height) * BBOX_NORM_CONSTANT,
@@ -232,17 +1184,12 @@ def ocr_fallback_page(page: fitz.Page, page_width: float, page_height: float) ->
232
  return ocr_tokens
233
 
234
  except Exception as e:
235
- # Note: 'page.number' might not be available if not running in a loop context
236
  print(f"OCR fallback failed: {e}")
237
  return []
238
 
239
 
240
  def extract_tokens_from_pdf_fitz_with_ocr(pdf_path: str) -> List[Dict[str, Any]]:
241
- # ... (Your extract_tokens_from_pdf_fitz_with_ocr implementation)
242
- """
243
- Extracts words and their raw bounding boxes using PyMuPDF (fitz) text layer
244
- and falls back to OCR if no text is found.
245
- """
246
  all_tokens = []
247
  try:
248
  doc = fitz.open(pdf_path)
@@ -251,8 +1198,7 @@ def extract_tokens_from_pdf_fitz_with_ocr(pdf_path: str) -> List[Dict[str, Any]]
251
  page_width, page_height = page.rect.width, page.rect.height
252
  page_tokens = []
253
 
254
- # 1. Primary Extraction: Use PyMuPDF's word structure (fitz.Page.get_text("words"))
255
- # word_list format: (x0, y0, x1, y1, word, ...)
256
  word_list = page.get_text("words", sort=True)
257
 
258
  if word_list:
@@ -260,7 +1206,6 @@ def extract_tokens_from_pdf_fitz_with_ocr(pdf_path: str) -> List[Dict[str, Any]]
260
  word = word_data[4]
261
  raw_bbox = word_data[:4]
262
 
263
- # Normalize bboxes
264
  normalized_bbox = [
265
  (raw_bbox[0] / page_width) * BBOX_NORM_CONSTANT,
266
  (raw_bbox[1] / page_height) * BBOX_NORM_CONSTANT,
@@ -293,10 +1238,7 @@ extract_tokens_from_pdf = extract_tokens_from_pdf_fitz_with_ocr
293
 
294
  def preprocess_and_collate_tokens(all_tokens: List[Dict[str, Any]], word_vocab: Vocab, char_vocab: Vocab,
295
  chunk_size: int) -> List[Dict[str, Any]]:
296
- # ... (Your preprocess_and_collate_tokens implementation)
297
- """
298
- Chunks the token list, converts to IDs, and prepares batches for inference. (Unchanged)
299
- """
300
  all_batches = []
301
 
302
  for i in range(0, len(all_tokens), chunk_size):
@@ -330,18 +1272,21 @@ def preprocess_and_collate_tokens(all_tokens: List[Dict[str, Any]], word_vocab:
330
  "chars": char_pad,
331
  "bboxes": bbox_pad,
332
  "mask": mask,
333
- "original_tokens": chunk # Keep the original data for output formatting
334
  })
335
 
336
  return all_batches
337
 
338
 
339
  # =========================================================
340
- # 3. Model Loading and Caching (Crucial for Gradio performance)
341
  # =========================================================
342
 
343
- # Cache the model and vocabs globally so they are loaded only ONCE when the app starts.
344
- # This avoids reloading the model on every user request, which is vital for speed.
 
 
 
345
  try:
346
  WORD_VOCAB, CHAR_VOCAB = load_vocabs(VOCAB_FILE)
347
  MODEL = MCQTagger(len(WORD_VOCAB), len(CHAR_VOCAB), len(LABELS)).to(DEVICE)
@@ -349,89 +1294,232 @@ try:
349
  MODEL.eval()
350
  print("βœ… Model and Vocabs loaded successfully (Cached).")
351
  except Exception as e:
352
- MODEL = None
353
  print(f"❌ Initial Model/Vocab Load Failure: {e}")
354
- print("The Gradio demo will not function until model_CAT.pt and vocabs_CAT.pkl are in the root directory.")
355
 
356
 
357
  # =========================================================
358
- # 4. The Gradio Inference Wrapper Function
359
  # =========================================================
360
 
361
- def gradio_inference_wrapper(pdf_file: str) -> Tuple[str, List[Dict[str, Any]]]:
 
 
 
 
 
 
 
 
 
 
 
362
  """
363
- Wraps the entire inference pipeline for the Gradio Interface.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
 
365
- Args:
366
- pdf_file: The path to the temporary PDF file uploaded by the user (a string).
 
367
 
368
- Returns:
369
- A tuple of (str, List[Dict[str, Any]]): A status message and the raw predictions.
 
370
  """
 
371
  if MODEL is None:
372
  return "❌ ERROR: Model failed to load on startup. Check 'model_CAT.pt' and 'vocabs_CAT.pkl'.", []
373
 
374
  pdf_path = pdf_file
 
375
 
376
  try:
377
- # 1. Extract Tokens
378
  all_tokens = extract_tokens_from_pdf(pdf_path)
379
- except RuntimeError as e:
380
- return f"❌ PDF Processing Error: {e}", []
381
 
382
- if not all_tokens:
383
- return "❌ ERROR: No tokens were extracted from the PDF, even after OCR fallback.", []
384
-
385
- # 2. Preprocess and Batch
386
- batches = preprocess_and_collate_tokens(all_tokens, WORD_VOCAB, CHAR_VOCAB, chunk_size=INFERENCE_CHUNK_SIZE)
387
-
388
- # 3. Run Inference
389
- all_predictions = []
390
- with torch.no_grad():
391
- for batch in batches:
392
- words, chars, bboxes, mask = (batch[k] for k in ["words", "chars", "bboxes", "mask"])
393
-
394
- preds_batch = MODEL(words, chars, bboxes, mask)
395
- predictions = preds_batch[0]
 
 
 
 
 
 
396
 
397
- original_tokens = batch["original_tokens"]
 
398
 
399
- for token_data, pred_idx in zip(original_tokens, predictions):
400
- all_predictions.append({
401
- "word": token_data["word"],
402
- "bbox": token_data["raw_bbox"],
403
- "predicted_label": IDX2LABEL[pred_idx]
404
- })
405
 
406
- status_message = f"βœ… Inference complete. Total tokens predicted: {len(all_predictions)}"
407
 
408
- # Gradio will display the JSON output prettified
409
- return status_message, all_predictions
 
 
410
 
411
 
412
  # =========================================================
413
- # 5. Define and Launch the Gradio Interface
414
  # =========================================================
415
 
416
  if __name__ == "__main__":
417
- title = "MCQ Document Structure Tagger (Bi-LSTM-CRF)"
418
- description = "Upload a PDF document (e.g., an MCQ paper). The model will tokenize the text, run inference to predict BIO-tags (B-QUESTION, I-OPTION, B-ANSWER, etc.) for each word, and return the raw JSON predictions."
419
 
420
- # Define the Gradio Interface
421
  demo = gr.Interface(
422
  fn=gradio_inference_wrapper,
423
- # inputs=gr.File(label="Upload PDF Document", file_types=['.pdf'], type='filepath'),
424
- inputs=gr.File(label="Upload PDF Document"),
425
  outputs=[
426
  gr.Textbox(label="Status Message", interactive=False),
427
- gr.JSON(label="Raw BIO Tagging Predictions (JSON)", show_label=True)
428
  ],
429
  title=title,
430
  description=description,
431
  allow_flagging="never",
432
- # Set a reasonable concurrency limit (number of simultaneous users) for a CPU/small GPU Space
433
  concurrency_limit=2
434
  )
435
 
436
- # Launch the demo (Hugging Face Spaces automatically calls launch() internally)
437
- demo.launch()
 
1
+ # import os
2
+ # import json
3
+ # import pickle
4
+ # from typing import List, Dict, Any, Tuple
5
+ # from collections import Counter
6
+ # import torch
7
+ # import torch.nn as nn
8
+ # import torch.nn.functional as F
9
+ # from tqdm import tqdm
10
+ #
11
+ # # === GRADIO AND DEPENDENCIES ===
12
+ # import gradio as gr
13
+ # import fitz # PyMuPDF
14
+ # import re
15
+ # from PIL import Image, ImageEnhance
16
+ # import pytesseract
17
+ #
18
+ # try:
19
+ # from TorchCRF import CRF
20
+ # except ImportError:
21
+ # # This should be handled in requirements.txt for the Space
22
+ # print("CRF module not found. Assuming deployment environment will install it.")
23
+ #
24
+ #
25
+ # class CRF:
26
+ # def __init__(self, *args, **kwargs): pass
27
+ #
28
+ # def viterbi_decode(self, emissions, mask): return [list(torch.argmax(emissions[0], dim=-1).cpu().numpy())]
29
+ #
30
+ # # ========== CONFIG (Must match Training Script) ==========
31
+ # # NOTE: In a Space, we typically don't use DATA_DIR paths if the files are alongside app.py
32
+ # MODEL_FILE = "model_CAT.pt"
33
+ # VOCAB_FILE = "vocabs_CAT.pkl"
34
+ #
35
+ # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
+ # MAX_CHAR_LEN = 16
37
+ # EMBED_DIM = 100
38
+ # CHAR_EMBED_DIM = 30
39
+ # CHAR_CNN_OUT = 30
40
+ # BBOX_DIM = 100
41
+ # HIDDEN_SIZE = 512
42
+ # BBOX_NORM_CONSTANT = 1000.0
43
+ # INFERENCE_CHUNK_SIZE = 256
44
+ #
45
+ # # ========== LABELS (Must match Training Script) ==========
46
+ # LABELS = ["O", "B-QUESTION", "I-QUESTION", "B-OPTION", "I-OPTION", "B-ANSWER", "I-ANSWER", "B-IMAGE", "I-IMAGE"]
47
+ # LABEL2IDX = {l: i for i, l in enumerate(LABELS)}
48
+ # IDX2LABEL = {i: l for i, l in enumerate(LABELS)}
49
+ #
50
+ #
51
+ # # =========================================================
52
+ # # 1. Vocab, CharCNNEncoder, and MCQTagger Classes (Copied from your script)
53
+ # # =========================================================
54
+ #
55
+ # class Vocab:
56
+ # # ... (Your Vocab class implementation)
57
+ # def __init__(self, min_freq=1, unk_token="<UNK>", pad_token="<PAD>"):
58
+ # self.min_freq = min_freq
59
+ # self.unk_token = unk_token
60
+ # self.pad_token = pad_token
61
+ # self.freq = Counter()
62
+ # self.itos = [] # Index to String
63
+ # self.stoi = {} # String to Index
64
+ #
65
+ # def add_sentence(self, toks):
66
+ # self.freq.update(toks)
67
+ #
68
+ # def build(self):
69
+ # items = [tok for tok, c in self.freq.items() if c >= self.min_freq]
70
+ # items = [self.pad_token, self.unk_token] + sorted(items)
71
+ # self.itos = items
72
+ # self.stoi = {s: i for i, s in enumerate(self.itos)}
73
+ #
74
+ # def __len__(self):
75
+ # return len(self.itos)
76
+ #
77
+ # def __getitem__(self, token: str) -> int:
78
+ # """Allows lookup using word_vocab[token]. Returns UNK index if token is not found."""
79
+ # return self.stoi.get(token, self.stoi[self.unk_token])
80
+ #
81
+ # def __getstate__(self):
82
+ # return {
83
+ # 'min_freq': self.min_freq,
84
+ # 'unk_token': self.unk_token,
85
+ # 'pad_token': self.pad_token,
86
+ # 'itos': self.itos,
87
+ # 'stoi': self.stoi,
88
+ # }
89
+ #
90
+ # def __setstate__(self, state):
91
+ # self.min_freq = state['min_freq']
92
+ # self.unk_token = state['unk_token']
93
+ # self.pad_token = state['pad_token']
94
+ # self.itos = state['itos']
95
+ # self.stoi = state['stoi']
96
+ # self.freq = Counter()
97
+ #
98
+ #
99
+ # def load_vocabs(path: str) -> Tuple[Vocab, Vocab]:
100
+ # """Loads word and character vocabularies from a pickle file and verifies size."""
101
+ # try:
102
+ # absolute_path = os.path.abspath(path)
103
+ # if not os.path.exists(absolute_path):
104
+ # raise FileNotFoundError(f"Vocab file NOT FOUND at: {absolute_path}")
105
+ # with open(absolute_path, "rb") as f:
106
+ # word_vocab, char_vocab = pickle.load(f)
107
+ # if len(word_vocab) <= 2:
108
+ # raise IndexError("CRITICAL: Word vocabulary size is too small. Vocab file is invalid.")
109
+ # return word_vocab, char_vocab
110
+ # except FileNotFoundError:
111
+ # raise FileNotFoundError(f"Vocab file not found at {path}. Please run the training script first.")
112
+ # except Exception as e:
113
+ # raise RuntimeError(f"Error loading vocabs from {path}: {e}")
114
+ #
115
+ #
116
+ # class CharCNNEncoder(nn.Module):
117
+ # def __init__(self, char_vocab_size, char_emb_dim, out_dim, kernel_sizes=(3, 4, 5)):
118
+ # super().__init__()
119
+ # self.char_emb = nn.Embedding(char_vocab_size, char_emb_dim, padding_idx=0)
120
+ # convs = [nn.Conv1d(char_emb_dim, out_dim, kernel_size=k) for k in kernel_sizes]
121
+ # self.convs = nn.ModuleList(convs)
122
+ # self.out_dim = out_dim * len(convs)
123
+ #
124
+ # def forward(self, char_ids):
125
+ # B, L, C = char_ids.size()
126
+ # emb = self.char_emb(char_ids.view(B * L, C)).transpose(1, 2)
127
+ # outs = [torch.max(torch.relu(conv(emb)), dim=2)[0] for conv in self.convs]
128
+ # res = torch.cat(outs, dim=1)
129
+ # return res.view(B, L, -1)
130
+ #
131
+ #
132
+ # class MCQTagger(nn.Module):
133
+ # def __init__(self, vocab_size, char_vocab_size, n_labels, bbox_dim=BBOX_DIM):
134
+ # super().__init__()
135
+ # self.word_emb = nn.Embedding(vocab_size, EMBED_DIM, padding_idx=0)
136
+ # self.char_enc = CharCNNEncoder(char_vocab_size, CHAR_EMBED_DIM, CHAR_CNN_OUT)
137
+ # self.bbox_proj = nn.Linear(4, bbox_dim)
138
+ # in_dim = EMBED_DIM + self.char_enc.out_dim + bbox_dim
139
+ #
140
+ # self.bilstm = nn.LSTM(in_dim, HIDDEN_SIZE // 2, num_layers=2, batch_first=True, bidirectional=True, dropout=0.3)
141
+ # self.ff = nn.Linear(HIDDEN_SIZE, n_labels)
142
+ # self.crf = CRF(n_labels)
143
+ # self.dropout = nn.Dropout(p=0.5)
144
+ #
145
+ # def forward_emissions(self, words, chars, bboxes, mask):
146
+ # wemb = self.word_emb(words)
147
+ # cenc = self.char_enc(chars)
148
+ # benc = self.bbox_proj(bboxes)
149
+ # enc_in = torch.cat([wemb, cenc, benc], dim=-1)
150
+ # enc_in = self.dropout(enc_in)
151
+ # lengths = mask.sum(dim=1).cpu()
152
+ #
153
+ # if lengths.max().item() == 0:
154
+ # B, L = enc_in.size(0), enc_in.size(1)
155
+ # return torch.zeros((B, L, len(LABELS)), device=enc_in.device)
156
+ #
157
+ # packed_in = nn.utils.rnn.pack_padded_sequence(enc_in, lengths, batch_first=True, enforce_sorted=False)
158
+ # packed_out, _ = self.bilstm(packed_in)
159
+ # padded_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
160
+ #
161
+ # return self.ff(padded_out)
162
+ #
163
+ # def forward(self, words, chars, bboxes, mask, labels=None, class_weights=None, alpha=0.7):
164
+ # emissions = self.forward_emissions(words, chars, bboxes, mask)
165
+ # # We only decode for inference, not calculate loss
166
+ # return self.crf.viterbi_decode(emissions, mask=mask)
167
+ #
168
+ #
169
+ # # =========================================================
170
+ # # 2. PDF Processing Functions (Copied from your script)
171
+ # # =========================================================
172
+ #
173
+ # def ocr_fallback_page(page: fitz.Page, page_width: float, page_height: float) -> List[Dict[str, Any]]:
174
+ # # ... (Your ocr_fallback_page implementation)
175
+ # """
176
+ # Renders a PyMuPDF page, runs Tesseract OCR, and tokenizes the result.
177
+ # """
178
+ # try:
179
+ # # Render page at high resolution (300 DPI equivalent)
180
+ # pix = page.get_pixmap(matrix=fitz.Matrix(3, 3))
181
+ # if pix.n - pix.alpha > 3: # Handle CMYK
182
+ # pix = fitz.Pixmap(fitz.csRGB, pix)
183
+ #
184
+ # img_pil = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
185
+ #
186
+ # # Preprocessing for Tesseract (as was in the original code)
187
+ # img_pil = img_pil.convert('L')
188
+ # img_pil = ImageEnhance.Contrast(img_pil).enhance(2.0)
189
+ # img_pil = ImageEnhance.Sharpness(img_pil).enhance(2.0)
190
+ #
191
+ # # Run Tesseract
192
+ # ocr_data = pytesseract.image_to_data(img_pil, output_type=pytesseract.Output.DICT)
193
+ #
194
+ # ocr_tokens = []
195
+ # for i in range(len(ocr_data['text'])):
196
+ # word = ocr_data['text'][i]
197
+ # conf = ocr_data['conf'][i]
198
+ # conf = ocr_data['conf'][i]
199
+ #
200
+ # # Use only words with reasonable confidence
201
+ # if word.strip() and int(conf) > 50:
202
+ # # Get Tesseract's raw pixel bounding box
203
+ # left = ocr_data['left'][i]
204
+ # top = ocr_data['top'][i]
205
+ # width = ocr_data['width'][i]
206
+ # height = ocr_data['height'][i]
207
+ #
208
+ # # Convert pixel bbox back to original PDF coordinate system
209
+ # scale = page_width / pix.width
210
+ #
211
+ # raw_bbox = [
212
+ # left * scale,
213
+ # top * scale,
214
+ # (left + width) * scale,
215
+ # (top + height) * scale
216
+ # ]
217
+ #
218
+ # # Normalize bbox
219
+ # normalized_bbox = [
220
+ # (raw_bbox[0] / page_width) * BBOX_NORM_CONSTANT,
221
+ # (raw_bbox[1] / page_height) * BBOX_NORM_CONSTANT,
222
+ # (raw_bbox[2] / page_width) * BBOX_NORM_CONSTANT,
223
+ # (raw_bbox[3] / page_height) * BBOX_NORM_CONSTANT
224
+ # ]
225
+ #
226
+ # ocr_tokens.append({
227
+ # "word": word,
228
+ # "raw_bbox": [int(b) for b in raw_bbox],
229
+ # "normalized_bbox": [int(b) for b in normalized_bbox]
230
+ # })
231
+ #
232
+ # return ocr_tokens
233
+ #
234
+ # except Exception as e:
235
+ # # Note: 'page.number' might not be available if not running in a loop context
236
+ # print(f"OCR fallback failed: {e}")
237
+ # return []
238
+ #
239
+ #
240
+ # def extract_tokens_from_pdf_fitz_with_ocr(pdf_path: str) -> List[Dict[str, Any]]:
241
+ # # ... (Your extract_tokens_from_pdf_fitz_with_ocr implementation)
242
+ # """
243
+ # Extracts words and their raw bounding boxes using PyMuPDF (fitz) text layer
244
+ # and falls back to OCR if no text is found.
245
+ # """
246
+ # all_tokens = []
247
+ # try:
248
+ # doc = fitz.open(pdf_path)
249
+ # for page_num in tqdm(range(len(doc)), desc="PDF Page Processing"):
250
+ # page = doc.load_page(page_num)
251
+ # page_width, page_height = page.rect.width, page.rect.height
252
+ # page_tokens = []
253
+ #
254
+ # # 1. Primary Extraction: Use PyMuPDF's word structure (fitz.Page.get_text("words"))
255
+ # # word_list format: (x0, y0, x1, y1, word, ...)
256
+ # word_list = page.get_text("words", sort=True)
257
+ #
258
+ # if word_list:
259
+ # for word_data in word_list:
260
+ # word = word_data[4]
261
+ # raw_bbox = word_data[:4]
262
+ #
263
+ # # Normalize bboxes
264
+ # normalized_bbox = [
265
+ # (raw_bbox[0] / page_width) * BBOX_NORM_CONSTANT,
266
+ # (raw_bbox[1] / page_height) * BBOX_NORM_CONSTANT,
267
+ # (raw_bbox[2] / page_width) * BBOX_NORM_CONSTANT,
268
+ # (raw_bbox[3] / page_height) * BBOX_NORM_CONSTANT
269
+ # ]
270
+ #
271
+ # page_tokens.append({
272
+ # "word": word,
273
+ # "raw_bbox": [int(b) for b in raw_bbox],
274
+ # "normalized_bbox": [int(b) for b in normalized_bbox]
275
+ # })
276
+ #
277
+ # # 2. OCR Fallback
278
+ # if not page_tokens:
279
+ # print(f" (Page {page_num + 1}) No text layer found. Running OCR...")
280
+ # page_tokens = ocr_fallback_page(page, page_width, page_height)
281
+ #
282
+ # all_tokens.extend(page_tokens)
283
+ #
284
+ # doc.close()
285
+ # except Exception as e:
286
+ # raise RuntimeError(f"Error opening or processing PDF with fitz/OCR: {e}")
287
+ #
288
+ # return all_tokens
289
+ #
290
+ #
291
+ # extract_tokens_from_pdf = extract_tokens_from_pdf_fitz_with_ocr
292
+ #
293
+ #
294
+ # def preprocess_and_collate_tokens(all_tokens: List[Dict[str, Any]], word_vocab: Vocab, char_vocab: Vocab,
295
+ # chunk_size: int) -> List[Dict[str, Any]]:
296
+ # # ... (Your preprocess_and_collate_tokens implementation)
297
+ # """
298
+ # Chunks the token list, converts to IDs, and prepares batches for inference. (Unchanged)
299
+ # """
300
+ # all_batches = []
301
+ #
302
+ # for i in range(0, len(all_tokens), chunk_size):
303
+ # chunk = all_tokens[i:i + chunk_size]
304
+ # if not chunk: continue
305
+ #
306
+ # words = [t["word"] for t in chunk]
307
+ # bboxes_norm = [t["normalized_bbox"] for t in chunk]
308
+ #
309
+ # # Convert to IDs
310
+ # word_ids = [word_vocab[w] for w in words]
311
+ #
312
+ # char_ids = []
313
+ # for w in words:
314
+ # chs = [char_vocab[ch] for ch in w[:MAX_CHAR_LEN]]
315
+ # if len(chs) < MAX_CHAR_LEN:
316
+ # pad_index = char_vocab.stoi.get(char_vocab.pad_token, 0)
317
+ # chs += [pad_index] * (MAX_CHAR_LEN - len(chs))
318
+ # char_ids.append(chs)
319
+ #
320
+ # # Create padded tensors (using single-sample batches)
321
+ # word_pad = torch.LongTensor([word_ids]).to(DEVICE)
322
+ # char_pad = torch.LongTensor([char_ids]).to(DEVICE)
323
+ #
324
+ # # Final normalization to [0, 1] range before feeding to the model
325
+ # bbox_pad = torch.FloatTensor([bboxes_norm]).to(DEVICE) / BBOX_NORM_CONSTANT
326
+ # mask = torch.ones(word_pad.size(), dtype=torch.bool).to(DEVICE)
327
+ #
328
+ # all_batches.append({
329
+ # "words": word_pad,
330
+ # "chars": char_pad,
331
+ # "bboxes": bbox_pad,
332
+ # "mask": mask,
333
+ # "original_tokens": chunk # Keep the original data for output formatting
334
+ # })
335
+ #
336
+ # return all_batches
337
+ #
338
+ #
339
+ # # =========================================================
340
+ # # 3. Model Loading and Caching (Crucial for Gradio performance)
341
+ # # =========================================================
342
+ #
343
+ # # Cache the model and vocabs globally so they are loaded only ONCE when the app starts.
344
+ # # This avoids reloading the model on every user request, which is vital for speed.
345
+ # try:
346
+ # WORD_VOCAB, CHAR_VOCAB = load_vocabs(VOCAB_FILE)
347
+ # MODEL = MCQTagger(len(WORD_VOCAB), len(CHAR_VOCAB), len(LABELS)).to(DEVICE)
348
+ # MODEL.load_state_dict(torch.load(MODEL_FILE, map_location=DEVICE))
349
+ # MODEL.eval()
350
+ # print("βœ… Model and Vocabs loaded successfully (Cached).")
351
+ # except Exception as e:
352
+ # MODEL = None
353
+ # print(f"❌ Initial Model/Vocab Load Failure: {e}")
354
+ # print("The Gradio demo will not function until model_CAT.pt and vocabs_CAT.pkl are in the root directory.")
355
+ #
356
+ #
357
+ # # =========================================================
358
+ # # 4. The Gradio Inference Wrapper Function
359
+ # # =========================================================
360
+ #
361
+ # def gradio_inference_wrapper(pdf_file: str) -> Tuple[str, List[Dict[str, Any]]]:
362
+ # """
363
+ # Wraps the entire inference pipeline for the Gradio Interface.
364
+ #
365
+ # Args:
366
+ # pdf_file: The path to the temporary PDF file uploaded by the user (a string).
367
+ #
368
+ # Returns:
369
+ # A tuple of (str, List[Dict[str, Any]]): A status message and the raw predictions.
370
+ # """
371
+ # if MODEL is None:
372
+ # return "❌ ERROR: Model failed to load on startup. Check 'model_CAT.pt' and 'vocabs_CAT.pkl'.", []
373
+ #
374
+ # pdf_path = pdf_file
375
+ #
376
+ # try:
377
+ # # 1. Extract Tokens
378
+ # all_tokens = extract_tokens_from_pdf(pdf_path)
379
+ # except RuntimeError as e:
380
+ # return f"❌ PDF Processing Error: {e}", []
381
+ #
382
+ # if not all_tokens:
383
+ # return "❌ ERROR: No tokens were extracted from the PDF, even after OCR fallback.", []
384
+ #
385
+ # # 2. Preprocess and Batch
386
+ # batches = preprocess_and_collate_tokens(all_tokens, WORD_VOCAB, CHAR_VOCAB, chunk_size=INFERENCE_CHUNK_SIZE)
387
+ #
388
+ # # 3. Run Inference
389
+ # all_predictions = []
390
+ # with torch.no_grad():
391
+ # for batch in batches:
392
+ # words, chars, bboxes, mask = (batch[k] for k in ["words", "chars", "bboxes", "mask"])
393
+ #
394
+ # preds_batch = MODEL(words, chars, bboxes, mask)
395
+ # predictions = preds_batch[0]
396
+ #
397
+ # original_tokens = batch["original_tokens"]
398
+ #
399
+ # for token_data, pred_idx in zip(original_tokens, predictions):
400
+ # all_predictions.append({
401
+ # "word": token_data["word"],
402
+ # "bbox": token_data["raw_bbox"],
403
+ # "predicted_label": IDX2LABEL[pred_idx]
404
+ # })
405
+ #
406
+ # status_message = f"βœ… Inference complete. Total tokens predicted: {len(all_predictions)}"
407
+ #
408
+ # # Gradio will display the JSON output prettified
409
+ # return status_message, all_predictions
410
+ #
411
+ #
412
+ # # =========================================================
413
+ # # 5. Define and Launch the Gradio Interface
414
+ # # =========================================================
415
+ #
416
+ # if __name__ == "__main__":
417
+ # title = "MCQ Document Structure Tagger (Bi-LSTM-CRF)"
418
+ # description = "Upload a PDF document (e.g., an MCQ paper). The model will tokenize the text, run inference to predict BIO-tags (B-QUESTION, I-OPTION, B-ANSWER, etc.) for each word, and return the raw JSON predictions."
419
+ #
420
+ # # Define the Gradio Interface
421
+ # demo = gr.Interface(
422
+ # fn=gradio_inference_wrapper,
423
+ # # inputs=gr.File(label="Upload PDF Document", file_types=['.pdf'], type='filepath'),
424
+ # inputs=gr.File(label="Upload PDF Document"),
425
+ # outputs=[
426
+ # gr.Textbox(label="Status Message", interactive=False),
427
+ # gr.JSON(label="Raw BIO Tagging Predictions (JSON)", show_label=True)
428
+ # ],
429
+ # title=title,
430
+ # description=description,
431
+ # allow_flagging="never",
432
+ # # Set a reasonable concurrency limit (number of simultaneous users) for a CPU/small GPU Space
433
+ # concurrency_limit=2
434
+ # )
435
+ #
436
+ # # Launch the demo (Hugging Face Spaces automatically calls launch() internally)
437
+ # demo.launch()
438
+
439
+ #
440
+ # import os
441
+ # import json
442
+ # import pickle
443
+ # from typing import List, Dict, Any, Tuple
444
+ # from collections import Counter
445
+ # import torch
446
+ # import torch.nn as nn
447
+ # import torch.nn.functional as F
448
+ # import re
449
+ # from tqdm import tqdm
450
+ #
451
+ # # === GRADIO AND DEPENDENCIES ===
452
+ # import gradio as gr
453
+ # import fitz # PyMuPDF
454
+ # from PIL import Image, ImageEnhance
455
+ # import pytesseract
456
+ #
457
+ # try:
458
+ # from TorchCRF import CRF
459
+ # except ImportError:
460
+ # # Placeholder for environments where it's not yet installed
461
+ # class CRF:
462
+ # def __init__(self, *args, **kwargs): pass
463
+ #
464
+ # def viterbi_decode(self, emissions, mask): return [list(torch.argmax(emissions[0], dim=-1).cpu().numpy())]
465
+ #
466
+ # # ========== CONFIG (Must match Training Script) ==========
467
+ # MODEL_FILE = "model_CAT.pt"
468
+ # VOCAB_FILE = "vocabs_CAT.pkl"
469
+ #
470
+ # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
471
+ # MAX_CHAR_LEN = 16
472
+ # EMBED_DIM = 100
473
+ # CHAR_EMBED_DIM = 30
474
+ # CHAR_CNN_OUT = 30
475
+ # BBOX_DIM = 100
476
+ # HIDDEN_SIZE = 512
477
+ # BBOX_NORM_CONSTANT = 1000.0
478
+ # INFERENCE_CHUNK_SIZE = 256
479
+ #
480
+ # # ========== LABELS (Must match Training Script) ==========
481
+ # # NOTE: Added B/I-PASSAGE for the new structuring function
482
+ # LABELS = ["O", "B-QUESTION", "I-QUESTION", "B-OPTION", "I-OPTION", "B-ANSWER", "I-ANSWER", "B-IMAGE", "I-IMAGE",
483
+ # "B-PASSAGE", "I-PASSAGE"]
484
+ # LABEL2IDX = {l: i for i, l in enumerate(LABELS)}
485
+ # IDX2LABEL = {i: l for i, l in enumerate(LABELS)}
486
+ #
487
+ #
488
+ # # =========================================================
489
+ # # 1. Core Classes (Vocab, CharCNNEncoder, MCQTagger)
490
+ # # (Your classes are retained here)
491
+ # # =========================================================
492
+ #
493
+ # class Vocab:
494
+ # def __init__(self, min_freq=1, unk_token="<UNK>", pad_token="<PAD>"):
495
+ # self.min_freq = min_freq
496
+ # self.unk_token = unk_token
497
+ # self.pad_token = pad_token
498
+ # self.freq = Counter()
499
+ # self.itos = []
500
+ # self.stoi = {}
501
+ #
502
+ # def add_sentence(self, toks):
503
+ # self.freq.update(toks)
504
+ #
505
+ # def build(self):
506
+ # items = [tok for tok, c in self.freq.items() if c >= self.min_freq]
507
+ # items = [self.pad_token, self.unk_token] + sorted(items)
508
+ # self.itos = items
509
+ # self.stoi = {s: i for i, s in enumerate(self.itos)}
510
+ #
511
+ # def __len__(self):
512
+ # return len(self.itos)
513
+ #
514
+ # def __getitem__(self, token: str) -> int:
515
+ # return self.stoi.get(token, self.stoi[self.unk_token])
516
+ #
517
+ # def __getstate__(self):
518
+ # return {
519
+ # 'min_freq': self.min_freq,
520
+ # 'unk_token': self.unk_token,
521
+ # 'pad_token': self.pad_token,
522
+ # 'itos': self.itos,
523
+ # 'stoi': self.stoi,
524
+ # }
525
+ #
526
+ # def __setstate__(self, state):
527
+ # self.min_freq = state['min_freq']
528
+ # self.unk_token = state['unk_token']
529
+ # self.pad_token = state['pad_token']
530
+ # self.itos = state['itos']
531
+ # self.stoi = state['stoi']
532
+ # self.freq = Counter()
533
+ #
534
+ #
535
+ # def load_vocabs(path: str) -> Tuple[Vocab, Vocab]:
536
+ # """Loads word and character vocabularies."""
537
+ # try:
538
+ # absolute_path = os.path.abspath(path)
539
+ # with open(absolute_path, "rb") as f:
540
+ # word_vocab, char_vocab = pickle.load(f)
541
+ # if len(word_vocab) <= 2:
542
+ # raise IndexError("CRITICAL: Word vocabulary size is too small.")
543
+ # return word_vocab, char_vocab
544
+ # except Exception as e:
545
+ # raise RuntimeError(f"Error loading vocabs from {path}: {e}")
546
+ #
547
+ #
548
+ # class CharCNNEncoder(nn.Module):
549
+ # def __init__(self, char_vocab_size, char_emb_dim, out_dim, kernel_sizes=(3, 4, 5)):
550
+ # super().__init__()
551
+ # self.char_emb = nn.Embedding(char_vocab_size, char_emb_dim, padding_idx=0)
552
+ # convs = [nn.Conv1d(char_emb_dim, out_dim, kernel_size=k) for k in kernel_sizes]
553
+ # self.convs = nn.ModuleList(convs)
554
+ # self.out_dim = out_dim * len(convs)
555
+ #
556
+ # def forward(self, char_ids):
557
+ # B, L, C = char_ids.size()
558
+ # emb = self.char_emb(char_ids.view(B * L, C)).transpose(1, 2)
559
+ # outs = [torch.max(torch.relu(conv(emb)), dim=2)[0] for conv in self.convs]
560
+ # res = torch.cat(outs, dim=1)
561
+ # return res.view(B, L, -1)
562
+ #
563
+ #
564
+ # class MCQTagger(nn.Module):
565
+ # def __init__(self, vocab_size, char_vocab_size, n_labels, bbox_dim=BBOX_DIM):
566
+ # super().__init__()
567
+ # self.word_emb = nn.Embedding(vocab_size, EMBED_DIM, padding_idx=0)
568
+ # self.char_enc = CharCNNEncoder(char_vocab_size, CHAR_EMBED_DIM, CHAR_CNN_OUT)
569
+ # self.bbox_proj = nn.Linear(4, bbox_dim)
570
+ # in_dim = EMBED_DIM + self.char_enc.out_dim + bbox_dim
571
+ #
572
+ # self.bilstm = nn.LSTM(in_dim, HIDDEN_SIZE // 2, num_layers=2, batch_first=True, bidirectional=True, dropout=0.3)
573
+ # self.ff = nn.Linear(HIDDEN_SIZE, n_labels)
574
+ # self.crf = CRF(n_labels)
575
+ # self.dropout = nn.Dropout(p=0.5)
576
+ #
577
+ # def forward_emissions(self, words, chars, bboxes, mask):
578
+ # wemb = self.word_emb(words)
579
+ # cenc = self.char_enc(chars)
580
+ # benc = self.bbox_proj(bboxes)
581
+ # enc_in = torch.cat([wemb, cenc, benc], dim=-1)
582
+ # enc_in = self.dropout(enc_in)
583
+ # lengths = mask.sum(dim=1).cpu()
584
+ #
585
+ # if lengths.max().item() == 0:
586
+ # B, L = enc_in.size(0), enc_in.size(1)
587
+ # return torch.zeros((B, L, len(LABELS)), device=enc_in.device)
588
+ #
589
+ # packed_in = nn.utils.rnn.pack_padded_sequence(enc_in, lengths, batch_first=True, enforce_sorted=False)
590
+ # packed_out, _ = self.bilstm(packed_in)
591
+ # padded_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
592
+ #
593
+ # return self.ff(padded_out)
594
+ #
595
+ # def forward(self, words, chars, bboxes, mask, labels=None, class_weights=None, alpha=0.7):
596
+ # emissions = self.forward_emissions(words, chars, bboxes, mask)
597
+ # return self.crf.viterbi_decode(emissions, mask=mask)
598
+ #
599
+ #
600
+ # # =========================================================
601
+ # # 2. PDF Processing Functions
602
+ # # (Your PDF functions are retained here)
603
+ # # =========================================================
604
+ #
605
+ # def ocr_fallback_page(page: fitz.Page, page_width: float, page_height: float) -> List[Dict[str, Any]]:
606
+ # """Renders a PyMuPDF page, runs Tesseract OCR, and tokenizes the result."""
607
+ # try:
608
+ # pix = page.get_pixmap(matrix=fitz.Matrix(3, 3))
609
+ # if pix.n - pix.alpha > 3:
610
+ # pix = fitz.Pixmap(fitz.csRGB, pix)
611
+ #
612
+ # img_pil = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
613
+ #
614
+ # # Preprocessing
615
+ # img_pil = img_pil.convert('L')
616
+ # img_pil = ImageEnhance.Contrast(img_pil).enhance(2.0)
617
+ # img_pil = ImageEnhance.Sharpness(img_pil).enhance(2.0)
618
+ #
619
+ # ocr_data = pytesseract.image_to_data(img_pil, output_type=pytesseract.Output.DICT)
620
+ #
621
+ # ocr_tokens = []
622
+ # for i in range(len(ocr_data['text'])):
623
+ # word = ocr_data['text'][i]
624
+ # conf = ocr_data['conf'][i]
625
+ #
626
+ # if word.strip() and int(conf) > 50:
627
+ # left, top, width, height = (ocr_data[k][i] for k in ['left', 'top', 'width', 'height'])
628
+ # scale = page_width / pix.width
629
+ #
630
+ # raw_bbox = [
631
+ # left * scale, top * scale, (left + width) * scale, (top + height) * scale
632
+ # ]
633
+ #
634
+ # normalized_bbox = [
635
+ # (raw_bbox[0] / page_width) * BBOX_NORM_CONSTANT,
636
+ # (raw_bbox[1] / page_height) * BBOX_NORM_CONSTANT,
637
+ # (raw_bbox[2] / page_width) * BBOX_NORM_CONSTANT,
638
+ # (raw_bbox[3] / page_height) * BBOX_NORM_CONSTANT
639
+ # ]
640
+ #
641
+ # ocr_tokens.append({
642
+ # "word": word,
643
+ # "raw_bbox": [int(b) for b in raw_bbox],
644
+ # "normalized_bbox": [int(b) for b in normalized_bbox]
645
+ # })
646
+ #
647
+ # return ocr_tokens
648
+ #
649
+ # except Exception as e:
650
+ # print(f"OCR fallback failed: {e}")
651
+ # return []
652
+ #
653
+ #
654
+ # def extract_tokens_from_pdf_fitz_with_ocr(pdf_path: str) -> List[Dict[str, Any]]:
655
+ # """Extracts words and bboxes using PyMuPDF text layer and falls back to OCR."""
656
+ # all_tokens = []
657
+ # try:
658
+ # doc = fitz.open(pdf_path)
659
+ # for page_num in tqdm(range(len(doc)), desc="PDF Page Processing"):
660
+ # page = doc.load_page(page_num)
661
+ # page_width, page_height = page.rect.width, page.rect.height
662
+ # page_tokens = []
663
+ #
664
+ # # 1. Primary Extraction: PyMuPDF's word structure
665
+ # word_list = page.get_text("words", sort=True)
666
+ #
667
+ # if word_list:
668
+ # for word_data in word_list:
669
+ # word = word_data[4]
670
+ # raw_bbox = word_data[:4]
671
+ #
672
+ # normalized_bbox = [
673
+ # (raw_bbox[0] / page_width) * BBOX_NORM_CONSTANT,
674
+ # (raw_bbox[1] / page_height) * BBOX_NORM_CONSTANT,
675
+ # (raw_bbox[2] / page_width) * BBOX_NORM_CONSTANT,
676
+ # (raw_bbox[3] / page_height) * BBOX_NORM_CONSTANT
677
+ # ]
678
+ #
679
+ # page_tokens.append({
680
+ # "word": word,
681
+ # "raw_bbox": [int(b) for b in raw_bbox],
682
+ # "normalized_bbox": [int(b) for b in normalized_bbox]
683
+ # })
684
+ #
685
+ # # 2. OCR Fallback
686
+ # if not page_tokens:
687
+ # print(f" (Page {page_num + 1}) No text layer found. Running OCR...")
688
+ # page_tokens = ocr_fallback_page(page, page_width, page_height)
689
+ #
690
+ # all_tokens.extend(page_tokens)
691
+ #
692
+ # doc.close()
693
+ # except Exception as e:
694
+ # raise RuntimeError(f"Error opening or processing PDF with fitz/OCR: {e}")
695
+ #
696
+ # return all_tokens
697
+ #
698
+ #
699
+ # extract_tokens_from_pdf = extract_tokens_from_pdf_fitz_with_ocr
700
+ #
701
+ #
702
+ # def preprocess_and_collate_tokens(all_tokens: List[Dict[str, Any]], word_vocab: Vocab, char_vocab: Vocab,
703
+ # chunk_size: int) -> List[Dict[str, Any]]:
704
+ # """Chunks the token list, converts to IDs, and prepares batches for inference."""
705
+ # all_batches = []
706
+ #
707
+ # for i in range(0, len(all_tokens), chunk_size):
708
+ # chunk = all_tokens[i:i + chunk_size]
709
+ # if not chunk: continue
710
+ #
711
+ # words = [t["word"] for t in chunk]
712
+ # bboxes_norm = [t["normalized_bbox"] for t in chunk]
713
+ #
714
+ # # Convert to IDs
715
+ # word_ids = [word_vocab[w] for w in words]
716
+ #
717
+ # char_ids = []
718
+ # for w in words:
719
+ # chs = [char_vocab[ch] for ch in w[:MAX_CHAR_LEN]]
720
+ # if len(chs) < MAX_CHAR_LEN:
721
+ # pad_index = char_vocab.stoi.get(char_vocab.pad_token, 0)
722
+ # chs += [pad_index] * (MAX_CHAR_LEN - len(chs))
723
+ # char_ids.append(chs)
724
+ #
725
+ # # Create padded tensors (using single-sample batches)
726
+ # word_pad = torch.LongTensor([word_ids]).to(DEVICE)
727
+ # char_pad = torch.LongTensor([char_ids]).to(DEVICE)
728
+ #
729
+ # # Final normalization to [0, 1] range before feeding to the model
730
+ # bbox_pad = torch.FloatTensor([bboxes_norm]).to(DEVICE) / BBOX_NORM_CONSTANT
731
+ # mask = torch.ones(word_pad.size(), dtype=torch.bool).to(DEVICE)
732
+ #
733
+ # all_batches.append({
734
+ # "words": word_pad,
735
+ # "chars": char_pad,
736
+ # "bboxes": bbox_pad,
737
+ # "mask": mask,
738
+ # "original_tokens": chunk
739
+ # })
740
+ #
741
+ # return all_batches
742
+ #
743
+ #
744
+ # # =========================================================
745
+ # # 3. Structuring Logic (Adapted from your second script)
746
+ # # =========================================================
747
+ #
748
+ # def finalize_passage_to_item(item, passage_buffer):
749
+ # """Adds passage text to the current item and clears the buffer."""
750
+ # if passage_buffer:
751
+ # # Use a more careful cleaning, focusing on space reduction
752
+ # passage_text = re.sub(r'\s{2,}', ' ', ' '.join(passage_buffer)).strip()
753
+ # if item.get('passage'):
754
+ # item['passage'] += ' ' + passage_text
755
+ # else:
756
+ # item['passage'] = passage_text
757
+ # passage_buffer.clear()
758
+ # return item
759
+ #
760
+ #
761
+ # def convert_bio_to_structured_json_strict(predictions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
762
+ # """
763
+ # Converts a list of {word, predicted_label} tokens into structured MCQ JSON format.
764
+ # This function is adapted to work directly with the list of predictions (in-memory).
765
+ # """
766
+ # structured_data = []
767
+ # current_item = None
768
+ # current_option_key = None
769
+ # current_passage_buffer = []
770
+ # current_text_buffer = []
771
+ #
772
+ # first_question_started = False
773
+ # last_entity_type = None
774
+ #
775
+ # for item in predictions:
776
+ # word = item['word']
777
+ # label = item['predicted_label']
778
+ # entity_type = label[2:].strip() if label.startswith(('B-', 'I-')) else None
779
+ #
780
+ # # Always append word to the total text buffer
781
+ # current_text_buffer.append(word)
782
+ #
783
+ # is_passage_label = (label == 'B-PASSAGE' or label == 'I-PASSAGE')
784
+ #
785
+ # # --- BEFORE FIRST QUESTION/METADATA HANDLING ---
786
+ # if not first_question_started and label != 'B-QUESTION' and not is_passage_label:
787
+ # continue
788
+ #
789
+ # # --- PASSAGE HANDLING (Before question start) ---
790
+ # if not first_question_started and is_passage_label:
791
+ # if label == 'B-PASSAGE' or (label == 'I-PASSAGE' and last_entity_type == 'PASSAGE'):
792
+ # current_passage_buffer.append(word)
793
+ # last_entity_type = 'PASSAGE'
794
+ # continue
795
+ #
796
+ # # --- NEW QUESTION START (B-QUESTION) ---
797
+ # if label == 'B-QUESTION':
798
+ #
799
+ # # 1. Capture leading text/passage as METADATA (for the very first block)
800
+ # if not first_question_started:
801
+ # header_text = ' '.join(current_text_buffer[:-1]).strip()
802
+ # if header_text or current_passage_buffer:
803
+ # metadata_item = {'type': 'METADATA'}
804
+ # metadata_item = finalize_passage_to_item(metadata_item, current_passage_buffer)
805
+ # if header_text:
806
+ # metadata_item['text'] = header_text
807
+ # structured_data.append(metadata_item)
808
+ #
809
+ # first_question_started = True
810
+ # current_text_buffer = [word]
811
+ #
812
+ # # 2. Save previous question block (for subsequent questions)
813
+ # elif current_item is not None:
814
+ # current_item = finalize_passage_to_item(current_item, current_passage_buffer)
815
+ # current_item['text'] = ' '.join(current_text_buffer[:-1]).strip()
816
+ # structured_data.append(current_item)
817
+ # current_text_buffer = [word]
818
+ #
819
+ # # 3. Initialize new question
820
+ # current_item = {
821
+ # 'type': 'MCQ', # Explicitly define the type for the final output
822
+ # 'question': word,
823
+ # 'options_text': {},
824
+ # 'answer': '',
825
+ # 'text': '' # The raw text span of the item
826
+ # }
827
+ # current_option_key = None
828
+ # last_entity_type = 'QUESTION'
829
+ # continue
830
+ #
831
+ # # --- IF INSIDE A QUESTION BLOCK ---
832
+ # if current_item is not None:
833
+ #
834
+ # if label.startswith('B-'):
835
+ # last_entity_type = entity_type
836
+ #
837
+ # if entity_type == 'PASSAGE':
838
+ # finalize_passage_to_item(current_item, current_passage_buffer)
839
+ # current_passage_buffer.append(word)
840
+ # elif entity_type == 'OPTION':
841
+ # current_option_key = word
842
+ # current_item['options_text'][current_option_key] = word
843
+ # current_passage_buffer = []
844
+ # elif entity_type == 'ANSWER':
845
+ # current_item['answer'] = word
846
+ # current_option_key = None
847
+ # current_passage_buffer = []
848
+ # elif entity_type == 'QUESTION':
849
+ # current_item['question'] += f' {word}'
850
+ # current_passage_buffer = []
851
+ #
852
+ # elif label.startswith('I-'):
853
+ # if entity_type == 'QUESTION' and last_entity_type == 'QUESTION':
854
+ # current_item['question'] += f' {word}'
855
+ # elif entity_type == 'OPTION' and last_entity_type == 'OPTION' and current_option_key is not None:
856
+ # current_item['options_text'][current_option_key] += f' {word}'
857
+ # elif entity_type == 'ANSWER' and last_entity_type == 'ANSWER':
858
+ # current_item['answer'] += f' {word}'
859
+ # elif entity_type == 'PASSAGE' and last_entity_type == 'PASSAGE':
860
+ # current_passage_buffer.append(word)
861
+ #
862
+ # # O-tokens are ignored for entity building but collected in current_text_buffer.
863
+ # elif label == 'O':
864
+ # pass
865
+ #
866
+ # # --- Finalize last item ---
867
+ # if current_item is not None:
868
+ # current_item = finalize_passage_to_item(current_item, current_passage_buffer)
869
+ # current_item['text'] = re.sub(r'\s{2,}', ' ', ' '.join(current_text_buffer)).strip()
870
+ # structured_data.append(current_item)
871
+ # elif not structured_data and current_passage_buffer:
872
+ # # Case: Only passage/metadata was present in the whole document
873
+ # metadata_item = {'type': 'METADATA'}
874
+ # metadata_item = finalize_passage_to_item(metadata_item, current_passage_buffer)
875
+ # metadata_item['text'] = re.sub(r'\s{2,}', ' ', ' '.join(current_text_buffer)).strip()
876
+ # structured_data.append(metadata_item)
877
+ #
878
+ # # --- FINAL CLEANUP ---
879
+ # for item in structured_data:
880
+ # # Final cleanup for all text fields
881
+ # item['text'] = re.sub(r'\s{2,}', ' ', item['text']).strip()
882
+ # if 'passage' in item:
883
+ # item['passage'] = re.sub(r'\s{2,}', ' ', item['passage']).strip()
884
+ # if not item['passage']:
885
+ # del item['passage']
886
+ # if 'question' in item:
887
+ # item['question'] = re.sub(r'\s{2,}', ' ', item['question']).strip()
888
+ # if 'answer' in item:
889
+ # item['answer'] = re.sub(r'\s{2,}', ' ', item['answer']).strip()
890
+ # if 'options_text' in item:
891
+ # for k, v in item['options_text'].items():
892
+ # item['options_text'][k] = re.sub(r'\s{2,}', ' ', v).strip()
893
+ #
894
+ # return structured_data
895
+ #
896
+ #
897
+ # # =========================================================
898
+ # # 4. Updated Gradio Inference Wrapper Function
899
+ # # =========================================================
900
+ #
901
+ # def gradio_inference_wrapper(pdf_file: str) -> Tuple[str, List[Dict[str, Any]]]:
902
+ # """
903
+ # Wraps the entire two-stage pipeline: (1) Tagging -> (2) Structuring.
904
+ # """
905
+ # if MODEL is None:
906
+ # return "❌ ERROR: Model failed to load on startup.", []
907
+ #
908
+ # pdf_path = pdf_file
909
+ # raw_predictions = []
910
+ #
911
+ # try:
912
+ # # 1. Stage 1: PDF Processing and BIO Tagging (Unchanged from before)
913
+ # all_tokens = extract_tokens_from_pdf(pdf_path)
914
+ #
915
+ # if not all_tokens:
916
+ # return "❌ ERROR: No tokens were extracted from the PDF, even after OCR fallback.", []
917
+ #
918
+ # batches = preprocess_and_collate_tokens(all_tokens, WORD_VOCAB, CHAR_VOCAB, chunk_size=INFERENCE_CHUNK_SIZE)
919
+ #
920
+ # with torch.no_grad():
921
+ # for batch in batches:
922
+ # words, chars, bboxes, mask = (batch[k] for k in ["words", "chars", "bboxes", "mask"])
923
+ # preds_batch = MODEL(words, chars, bboxes, mask)
924
+ # predictions = preds_batch[0]
925
+ # original_tokens = batch["original_tokens"]
926
+ #
927
+ # for token_data, pred_idx in zip(original_tokens, predictions):
928
+ # raw_predictions.append({
929
+ # "word": token_data["word"],
930
+ # "bbox": token_data["raw_bbox"],
931
+ # "predicted_label": IDX2LABEL[pred_idx]
932
+ # })
933
+ #
934
+ # # 2. Stage 2: Structured JSON Conversion (The NEW step)
935
+ # structured_output = convert_bio_to_structured_json_strict(raw_predictions)
936
+ #
937
+ # status_message = f"βœ… Conversion complete. Found {len([i for i in structured_output if i.get('type') == 'MCQ'])} MCQ items."
938
+ #
939
+ # # Return the final structured output
940
+ # return status_message, structured_output
941
+ #
942
+ # except RuntimeError as e:
943
+ # return f"❌ PDF Processing Error: {e}", []
944
+ # except Exception as e:
945
+ # # Catch any unexpected errors during inference or structuring
946
+ # return f"❌ An unexpected processing error occurred: {e}", []
947
+ #
948
+ #
949
+ # # =========================================================
950
+ # # 5. Define and Launch the Gradio Interface
951
+ # # (Output changed to only show the final structured JSON)
952
+ # # =========================================================
953
+ #
954
+ # if __name__ == "__main__":
955
+ # title = "MCQ Document Structure Tagger (Bi-LSTM-CRF) - Structured Output"
956
+ # description = "Upload a PDF document. The system processes it in two stages: 1) BIO-Tagging for structural elements (Question, Option, Answer, Passage) and 2) Converting those tags into a clean, structured JSON list of MCQ items."
957
+ #
958
+ # demo = gr.Interface(
959
+ # fn=gradio_inference_wrapper,
960
+ # inputs=gr.File(label="Upload PDF Document", file_types=['pdf']),
961
+ # outputs=[
962
+ # gr.Textbox(label="Status Message", interactive=False),
963
+ # gr.JSON(label="Structured MCQ JSON Output", show_label=True)
964
+ # ],
965
+ # title=title,
966
+ # description=description,
967
+ # allow_flagging="never",
968
+ # concurrency_limit=2
969
+ # )
970
+ #
971
+ # demo.launch()
972
+
973
+
974
+
975
  import os
976
  import json
977
  import pickle
 
980
  import torch
981
  import torch.nn as nn
982
  import torch.nn.functional as F
983
+ import re
984
  from tqdm import tqdm
985
 
986
  # === GRADIO AND DEPENDENCIES ===
987
  import gradio as gr
988
  import fitz # PyMuPDF
 
989
  from PIL import Image, ImageEnhance
990
  import pytesseract
991
 
992
  try:
993
+ # Attempt to import the actual CRF layer for correct Viterbi decoding
994
  from TorchCRF import CRF
995
  except ImportError:
996
+ # Placeholder for environments where it's not yet installed, enabling model definition
 
 
 
997
  class CRF:
998
+ def __init__(self, *args, **kwargs):
999
+ pass
1000
+ # Fallback to simple argmax decoding if the CRF module is missing
1001
+ def viterbi_decode(self, emissions, mask):
1002
+ return [list(torch.argmax(emissions[0], dim=-1).cpu().numpy())]
1003
 
 
1004
 
1005
  # ========== CONFIG (Must match Training Script) ==========
 
1006
  MODEL_FILE = "model_CAT.pt"
1007
  VOCAB_FILE = "vocabs_CAT.pkl"
1008
 
 
1017
  INFERENCE_CHUNK_SIZE = 256
1018
 
1019
  # ========== LABELS (Must match Training Script) ==========
1020
+ # Including PASSAGE for the new structuring logic
1021
+ LABELS = ["O", "B-QUESTION", "I-QUESTION", "B-OPTION", "I-OPTION", "B-ANSWER", "I-ANSWER", "B-IMAGE", "I-IMAGE", "B-PASSAGE", "I-PASSAGE"]
1022
  LABEL2IDX = {l: i for i, l in enumerate(LABELS)}
1023
  IDX2LABEL = {i: l for i, l in enumerate(LABELS)}
1024
 
1025
 
1026
  # =========================================================
1027
+ # 1. Core Classes (Vocab, CharCNNEncoder, MCQTagger)
1028
  # =========================================================
1029
 
1030
  class Vocab:
 
1031
  def __init__(self, min_freq=1, unk_token="<UNK>", pad_token="<PAD>"):
1032
  self.min_freq = min_freq
1033
  self.unk_token = unk_token
1034
  self.pad_token = pad_token
1035
  self.freq = Counter()
1036
+ self.itos = []
1037
+ self.stoi = {}
1038
 
1039
  def add_sentence(self, toks):
1040
  self.freq.update(toks)
 
1049
  return len(self.itos)
1050
 
1051
  def __getitem__(self, token: str) -> int:
 
1052
  return self.stoi.get(token, self.stoi[self.unk_token])
1053
 
1054
  def __getstate__(self):
 
1070
 
1071
 
1072
  def load_vocabs(path: str) -> Tuple[Vocab, Vocab]:
1073
+ """Loads word and character vocabularies."""
1074
  try:
1075
  absolute_path = os.path.abspath(path)
 
 
1076
  with open(absolute_path, "rb") as f:
1077
  word_vocab, char_vocab = pickle.load(f)
1078
  if len(word_vocab) <= 2:
1079
+ raise IndexError("CRITICAL: Word vocabulary size is too small.")
1080
  return word_vocab, char_vocab
 
 
1081
  except Exception as e:
1082
  raise RuntimeError(f"Error loading vocabs from {path}: {e}")
1083
 
 
1121
 
1122
  if lengths.max().item() == 0:
1123
  B, L = enc_in.size(0), enc_in.size(1)
1124
+ # Return zero tensor if batch is empty
1125
  return torch.zeros((B, L, len(LABELS)), device=enc_in.device)
1126
 
1127
  packed_in = nn.utils.rnn.pack_padded_sequence(enc_in, lengths, batch_first=True, enforce_sorted=False)
 
1132
 
1133
  def forward(self, words, chars, bboxes, mask, labels=None, class_weights=None, alpha=0.7):
1134
  emissions = self.forward_emissions(words, chars, bboxes, mask)
 
1135
  return self.crf.viterbi_decode(emissions, mask=mask)
1136
 
1137
 
  # =========================================================
+ # 2. PDF Processing Functions
  # =========================================================

  def ocr_fallback_page(page: fitz.Page, page_width: float, page_height: float) -> List[Dict[str, Any]]:
+     """Renders a PyMuPDF page, runs Tesseract OCR, and tokenizes the result."""
      try:
          pix = page.get_pixmap(matrix=fitz.Matrix(3, 3))
+         # Convert CMYK (or other >3-component) pixmaps to RGB first
+         if pix.n - pix.alpha > 3:
              pix = fitz.Pixmap(fitz.csRGB, pix)

          img_pil = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)

+         # Preprocessing for Tesseract: greyscale, then boost contrast and sharpness
          img_pil = img_pil.convert('L')
          img_pil = ImageEnhance.Contrast(img_pil).enhance(2.0)
          img_pil = ImageEnhance.Sharpness(img_pil).enhance(2.0)

          ocr_data = pytesseract.image_to_data(img_pil, output_type=pytesseract.Output.DICT)

          ocr_tokens = []
          for i in range(len(ocr_data['text'])):
              word = ocr_data['text'][i]
              conf = ocr_data['conf'][i]

+             # Keep only non-empty words recognised with reasonable confidence
              if word.strip() and int(conf) > 50:
+                 left, top, width, height = (ocr_data[k][i] for k in ['left', 'top', 'width', 'height'])
+                 # Map pixel coordinates back to PDF points (the 3x render
+                 # matrix is uniform, so one scale factor covers both axes)
                  scale = page_width / pix.width
                  raw_bbox = [
+                     left * scale, top * scale, (left + width) * scale, (top + height) * scale
                  ]
                  normalized_bbox = [
                      (raw_bbox[0] / page_width) * BBOX_NORM_CONSTANT,
                      (raw_bbox[1] / page_height) * BBOX_NORM_CONSTANT,

          return ocr_tokens

      except Exception as e:
          print(f"OCR fallback failed: {e}")
          return []
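+ # Worked example of the coordinate mapping above (values assumed, not from a
+ # real document: a US-Letter page is 612 pt wide, so its 3x render is 1836 px
+ # wide). A Tesseract box at left=300 px maps to 300 * (612 / 1836) = 100 pt,
+ # and then to (100 / 612) * 1000 ~= 163.4 on the normalised 0-1000 grid:
+ def _demo_bbox_normalisation():
+     page_width, pix_width = 612.0, 1836.0
+     left_pt = 300.0 * (page_width / pix_width)          # -> 100.0
+     return (left_pt / page_width) * BBOX_NORM_CONSTANT  # -> ~163.4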

  def extract_tokens_from_pdf_fitz_with_ocr(pdf_path: str) -> List[Dict[str, Any]]:
+     """Extracts words and bboxes from the PDF text layer, falling back to OCR."""
      all_tokens = []
      try:
          doc = fitz.open(pdf_path)

          page_width, page_height = page.rect.width, page.rect.height
          page_tokens = []

+         # 1. Primary extraction: PyMuPDF's word structure, in reading order
          word_list = page.get_text("words", sort=True)

          if word_list:

+             # Each entry is (x0, y0, x1, y1, word, block_no, line_no, word_no)
              word = word_data[4]
              raw_bbox = word_data[:4]
              normalized_bbox = [
                  (raw_bbox[0] / page_width) * BBOX_NORM_CONSTANT,
                  (raw_bbox[1] / page_height) * BBOX_NORM_CONSTANT,

  def preprocess_and_collate_tokens(all_tokens: List[Dict[str, Any]], word_vocab: Vocab, char_vocab: Vocab,
                                    chunk_size: int) -> List[Dict[str, Any]]:
+     """Chunks the token list, converts tokens to IDs, and pads each chunk into an inference batch."""
      all_batches = []

      for i in range(0, len(all_tokens), chunk_size):

              "chars": char_pad,
              "bboxes": bbox_pad,
              "mask": mask,
+             "original_tokens": chunk
          })

      return all_batches
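+ # Sketch of the chunking arithmetic (illustrative: 600 extracted tokens with
+ # the configured chunk size of 256 split into chunks of 256, 256 and 88):
+ def _demo_chunk_sizes(n_tokens: int = 600, chunk_size: int = INFERENCE_CHUNK_SIZE):
+     return [min(chunk_size, n_tokens - i) for i in range(0, n_tokens, chunk_size)]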

  # =========================================================
+ # 3. Model Loading and Caching (globals defined here)
  # =========================================================

+ # Globals used by the Gradio wrapper; loaded once at startup and cached
+ WORD_VOCAB = None
+ CHAR_VOCAB = None
+ MODEL = None
+
  try:
      WORD_VOCAB, CHAR_VOCAB = load_vocabs(VOCAB_FILE)
      MODEL = MCQTagger(len(WORD_VOCAB), len(CHAR_VOCAB), len(LABELS)).to(DEVICE)

      MODEL.eval()
      print("βœ… Model and Vocabs loaded successfully (Cached).")
  except Exception as e:
+     # Keep the app alive even if the model files are missing at startup
      print(f"❌ Initial Model/Vocab Load Failure: {e}")
+     print("The Gradio demo will not function until model_CAT.pt and vocabs_CAT.pkl are present.")

  # =========================================================
+ # 4. Structuring Logic (converts BIO tags into clean JSON)
  # =========================================================

+ def finalize_passage_to_item(item, passage_buffer):
+     """Adds buffered passage text to the current item and clears the buffer."""
+     if passage_buffer:
+         passage_text = re.sub(r'\s{2,}', ' ', ' '.join(passage_buffer)).strip()
+         if item.get('passage'):
+             item['passage'] += ' ' + passage_text
+         else:
+             item['passage'] = passage_text
+         passage_buffer.clear()
+     return item
+
+
+ def convert_bio_to_structured_json_strict(predictions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+     """
+     Converts a list of {word, predicted_label} tokens into a structured MCQ JSON list.
      """
+     structured_data = []
+     current_item = None
+     current_option_key = None
+     current_passage_buffer = []
+     current_text_buffer = []
+
+     first_question_started = False
+     last_entity_type = None
+
+     for item in predictions:
+         word = item['word']
+         label = item['predicted_label']
+         entity_type = label[2:].strip() if label.startswith(('B-', 'I-')) else None
+
+         current_text_buffer.append(word)
+
+         is_passage_label = label in ('B-PASSAGE', 'I-PASSAGE')
+
+         # --- Skip anything before the first question unless it is passage text ---
+         if not first_question_started and label != 'B-QUESTION' and not is_passage_label:
+             continue
+
+         # --- Passage handling before the first question starts ---
+         if not first_question_started and is_passage_label:
+             if label == 'B-PASSAGE' or (label == 'I-PASSAGE' and last_entity_type == 'PASSAGE'):
+                 current_passage_buffer.append(word)
+                 last_entity_type = 'PASSAGE'
+             continue
+
+         # --- New question start (B-QUESTION) ---
+         if label == 'B-QUESTION':
+             # 1. Capture any leading text/passage as METADATA
+             if not first_question_started:
+                 header_text = ' '.join(current_text_buffer[:-1]).strip()
+                 if header_text or current_passage_buffer:
+                     metadata_item = {'type': 'METADATA'}
+                     metadata_item = finalize_passage_to_item(metadata_item, current_passage_buffer)
+                     if header_text:
+                         metadata_item['text'] = header_text
+                     structured_data.append(metadata_item)
+
+                 first_question_started = True
+                 current_text_buffer = [word]
+
+             # 2. Save the previous question block
+             elif current_item is not None:
+                 current_item = finalize_passage_to_item(current_item, current_passage_buffer)
+                 current_item['text'] = ' '.join(current_text_buffer[:-1]).strip()
+                 structured_data.append(current_item)
+                 current_text_buffer = [word]
+
+             # 3. Initialize the new question
+             current_item = {
+                 'type': 'MCQ',
+                 'question': word,
+                 'options_text': {},
+                 'answer': '',
+                 'text': ''
+             }
+             current_option_key = None
+             last_entity_type = 'QUESTION'
+             continue
+
+         # --- Inside a question block ---
+         if current_item is not None:
+
+             if label.startswith('B-'):
+                 last_entity_type = entity_type
+
+                 if entity_type == 'PASSAGE':
+                     finalize_passage_to_item(current_item, current_passage_buffer)
+                     current_passage_buffer.append(word)
+                 elif entity_type == 'OPTION':
+                     current_option_key = word
+                     current_item['options_text'][current_option_key] = word
+                     current_passage_buffer = []
+                 elif entity_type == 'ANSWER':
+                     current_item['answer'] = word
+                     current_option_key = None
+                     current_passage_buffer = []
+                 elif entity_type == 'QUESTION':
+                     current_item['question'] += f' {word}'
+                     current_passage_buffer = []
+
+             elif label.startswith('I-'):
+                 if entity_type == 'QUESTION' and last_entity_type == 'QUESTION':
+                     current_item['question'] += f' {word}'
+                 elif entity_type == 'OPTION' and last_entity_type == 'OPTION' and current_option_key is not None:
+                     current_item['options_text'][current_option_key] += f' {word}'
+                 elif entity_type == 'ANSWER' and last_entity_type == 'ANSWER':
+                     current_item['answer'] += f' {word}'
+                 elif entity_type == 'PASSAGE' and last_entity_type == 'PASSAGE':
+                     current_passage_buffer.append(word)
+
+             elif label == 'O':
+                 pass
+
+     # --- Finalize the last item ---
+     if current_item is not None:
+         current_item = finalize_passage_to_item(current_item, current_passage_buffer)
+         current_item['text'] = re.sub(r'\s{2,}', ' ', ' '.join(current_text_buffer)).strip()
+         structured_data.append(current_item)
+     elif not structured_data and current_passage_buffer:
+         # Case: only passage/metadata was present in the whole document
+         metadata_item = {'type': 'METADATA'}
+         metadata_item = finalize_passage_to_item(metadata_item, current_passage_buffer)
+         metadata_item['text'] = re.sub(r'\s{2,}', ' ', ' '.join(current_text_buffer)).strip()
+         structured_data.append(metadata_item)
+
+     # --- Final cleanup: collapse excess whitespace in every text field ---
+     for item in structured_data:
+         if 'text' in item:  # METADATA items may carry only a passage
+             item['text'] = re.sub(r'\s{2,}', ' ', item['text']).strip()
+         if 'passage' in item:
+             item['passage'] = re.sub(r'\s{2,}', ' ', item['passage']).strip()
+             if not item['passage']:
+                 del item['passage']
+         for field in ['question', 'answer']:
+             if field in item:
+                 item[field] = re.sub(r'\s{2,}', ' ', item[field]).strip()
+         if 'options_text' in item:
+             for k, v in item['options_text'].items():
+                 item['options_text'][k] = re.sub(r'\s{2,}', ' ', v).strip()
+
+     return structured_data
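+ # Illustrative input/output for the structuring pass (tokens and labels are
+ # made up; in the app they come from the tagger):
+ def _demo_structuring():
+     preds = [
+         {'word': 'Q1.', 'predicted_label': 'B-QUESTION'},
+         {'word': 'What', 'predicted_label': 'I-QUESTION'},
+         {'word': 'is', 'predicted_label': 'I-QUESTION'},
+         {'word': '2+2?', 'predicted_label': 'I-QUESTION'},
+         {'word': '(a)', 'predicted_label': 'B-OPTION'},
+         {'word': '3', 'predicted_label': 'I-OPTION'},
+         {'word': '(b)', 'predicted_label': 'B-OPTION'},
+         {'word': '4', 'predicted_label': 'I-OPTION'},
+         {'word': 'Ans:', 'predicted_label': 'B-ANSWER'},
+         {'word': '(b)', 'predicted_label': 'I-ANSWER'},
+     ]
+     # -> [{'type': 'MCQ', 'question': 'Q1. What is 2+2?',
+     #      'options_text': {'(a)': '(a) 3', '(b)': '(b) 4'},
+     #      'answer': 'Ans: (b)',
+     #      'text': 'Q1. What is 2+2? (a) 3 (b) 4 Ans: (b)'}]
+     return convert_bio_to_structured_json_strict(preds)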

  # =========================================================
+ # 5. Gradio Inference Wrapper (main entry point)
  # =========================================================

+ def gradio_inference_wrapper(pdf_file: str) -> Tuple[str, List[Dict[str, Any]]]:
+     """
+     Wraps the two-stage pipeline: (1) BIO tagging -> (2) JSON structuring.
      """
+     # Uses the globals defined in Section 3
      if MODEL is None:
          return "❌ ERROR: Model failed to load on startup. Check 'model_CAT.pt' and 'vocabs_CAT.pkl'.", []

      pdf_path = pdf_file
+     raw_predictions = []

      try:
+         # Stage 1: PDF processing and BIO tagging
          all_tokens = extract_tokens_from_pdf_fitz_with_ocr(pdf_path)

+         if not all_tokens:
+             return "❌ ERROR: No tokens were extracted from the PDF, even after OCR fallback.", []
+
+         batches = preprocess_and_collate_tokens(all_tokens, WORD_VOCAB, CHAR_VOCAB, chunk_size=INFERENCE_CHUNK_SIZE)
+
+         with torch.no_grad():
+             for batch in batches:
+                 words, chars, bboxes, mask = (batch[k] for k in ["words", "chars", "bboxes", "mask"])
+                 preds_batch = MODEL(words, chars, bboxes, mask)
+                 predictions = preds_batch[0]  # each batch holds a single chunk
+                 original_tokens = batch["original_tokens"]
+
+                 for token_data, pred_idx in zip(original_tokens, predictions):
+                     raw_predictions.append({
+                         "word": token_data["word"],
+                         "bbox": token_data["raw_bbox"],
+                         "predicted_label": IDX2LABEL[pred_idx]
+                     })
+
+         # Stage 2: structured JSON conversion
+         structured_output = convert_bio_to_structured_json_strict(raw_predictions)
+
+         mcq_count = len([i for i in structured_output if i.get('type') == 'MCQ'])
+         status_message = (f"βœ… Conversion complete. Found {mcq_count} MCQ items and "
+                           f"{len(structured_output) - mcq_count} metadata blocks.")
+
+         return status_message, structured_output
+
+     except RuntimeError as e:
+         return f"❌ PDF Processing Error: {e}", []
+     except Exception as e:
+         return f"❌ An unexpected processing error occurred: {e}", []
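+ # Convenience sketch for exercising the pipeline outside Gradio (the file
+ # name is a placeholder, not an asset shipped with the Space):
+ def _demo_run_pipeline(path: str = "sample_mcq.pdf"):
+     status, items = gradio_inference_wrapper(path)
+     print(status)
+     print(json.dumps(items, indent=2))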

  # =========================================================
+ # 6. Define and Launch the Gradio Interface
  # =========================================================

  if __name__ == "__main__":
+     title = "MCQ Document Structure Tagger (Bi-LSTM-CRF) - Structured Output"
+     description = (
+         "Upload a PDF document. The system processes it in two stages: "
+         "1) BIO tagging of structural elements (Question, Option, Answer, Passage), and "
+         "2) conversion of those tags into a clean, structured JSON list of MCQ items."
+     )

      demo = gr.Interface(
          fn=gradio_inference_wrapper,
+         # Accept PDF files only
+         inputs=gr.File(label="Upload PDF Document", file_types=['.pdf']),
          outputs=[
              gr.Textbox(label="Status Message", interactive=False),
+             gr.JSON(label="Structured MCQ JSON Output", show_label=True)
          ],
          title=title,
          description=description,
          allow_flagging="never",
          concurrency_limit=2
      )

+     demo.launch()