heerjtdev commited on
Commit
f1c0953
·
verified ·
1 Parent(s): 02f7b52

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +417 -66
app.py CHANGED
@@ -1,3 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
3
  import torch.nn as nn
@@ -11,33 +383,28 @@ from TorchCRF import CRF
11
  # ---------------------------------------------------------
12
  # 1. CONFIGURATION
13
  # ---------------------------------------------------------
14
- # Ensure this filename matches exactly what you uploaded to the Space
15
  MODEL_FILENAME = "layoutlmv3_bilstm_crf_hybrid.pth"
16
  BASE_MODEL_ID = "microsoft/layoutlmv3-base"
17
 
18
- # Define your labels exactly as they were during training
 
 
19
  LABELS = [
20
  "O",
21
  "B-QUESTION", "I-QUESTION",
22
  "B-OPTION", "I-OPTION",
23
  "B-ANSWER", "I-ANSWER",
24
  "B-SECTION_HEADING", "I-SECTION_HEADING",
25
- "B-PASSAGE", "I-PASSAGE"
 
26
  ]
 
27
  LABEL2ID = {l: i for i, l in enumerate(LABELS)}
28
  ID2LABEL = {i: l for l, i in LABEL2ID.items()}
29
 
30
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
  tokenizer = LayoutLMv3TokenizerFast.from_pretrained(BASE_MODEL_ID)
32
 
33
- # ---------------------------------------------------------
34
- # 2. MODEL ARCHITECTURE
35
- # ---------------------------------------------------------
36
- # ⚠️ ACTION REQUIRED:
37
- # Replace this class with the exact class definition of your
38
- # NEW HYBRID MODEL. The class name and structure must match
39
- # what was used when you saved 'layoutlmv3_nonlinear_scratch.pth'.
40
- # ---------------------------------------------------------
41
  # ---------------------------------------------------------
42
  # 2. MODEL ARCHITECTURE (LayoutLMv3 + BiLSTM + CRF)
43
  # ---------------------------------------------------------
@@ -46,52 +413,47 @@ class HybridModel(nn.Module):
46
  super().__init__()
47
  self.layoutlm = LayoutLMv3Model.from_pretrained(BASE_MODEL_ID)
48
 
49
- # Config for BiLSTM
50
- hidden_size = self.layoutlm.config.hidden_size # Usually 768
51
- lstm_hidden_size = hidden_size // 2 # 384, so bidirectional output is 768
52
 
53
- # BiLSTM Layer
54
- # input_size=768, hidden=384, bidir=True -> output_dim = 384 * 2 = 768
55
  self.lstm = nn.LSTM(
56
- input_size=hidden_size,
57
- hidden_size=lstm_hidden_size,
58
- num_layers=1,
59
  batch_first=True,
60
  bidirectional=True
61
  )
62
 
63
- # Dropout (Optional, check if you used this in training)
64
  self.dropout = nn.Dropout(0.1)
65
 
66
- # Classifier: Maps BiLSTM output (768) to Label count
 
67
  self.classifier = nn.Linear(lstm_hidden_size * 2, num_labels)
68
 
69
- # CRF Layer
70
  self.crf = CRF(num_labels)
71
 
72
  def forward(self, input_ids, bbox, attention_mask, labels=None):
73
- # 1. LayoutLMv3 Base
74
  outputs = self.layoutlm(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
75
- sequence_output = outputs.last_hidden_state # [Batch, Seq, 768]
76
 
77
- # 2. BiLSTM
78
- # LSTM returns (output, (h_n, c_n)). We only need output.
79
- lstm_output, _ = self.lstm(sequence_output) # [Batch, Seq, 768]
80
 
81
- # 3. Dropout & Classifier
82
  lstm_output = self.dropout(lstm_output)
83
- emissions = self.classifier(lstm_output) # [Batch, Seq, Num_Labels]
84
 
85
- # 4. CRF
86
  if labels is not None:
87
- # Training/Eval (Loss)
88
  log_likelihood = self.crf(emissions, labels, mask=attention_mask.bool())
89
  return -log_likelihood.mean()
90
  else:
91
- # Inference (Prediction Tags)
92
  return self.crf.viterbi_decode(emissions, mask=attention_mask.bool())
 
93
  # ---------------------------------------------------------
94
- # 3. MODEL LOADING LOGIC
95
  # ---------------------------------------------------------
96
  model = None
97
 
@@ -100,17 +462,18 @@ def load_model():
100
  if model is None:
101
  print(f"🔄 Loading model from {MODEL_FILENAME}...")
102
  if not os.path.exists(MODEL_FILENAME):
103
- raise FileNotFoundError(f"❌ Model file '{MODEL_FILENAME}' not found. Please upload it to the Files tab of your Space.")
104
 
105
- # Initialize the model structure
106
  model = HybridModel(num_labels=len(LABELS))
107
 
108
- # Load weights
 
 
 
109
  try:
110
- state_dict = torch.load(MODEL_FILENAME, map_location=device)
111
  model.load_state_dict(state_dict)
112
  except RuntimeError as e:
113
- raise RuntimeError(f"❌ State dictionary mismatch. Ensure the 'HybridModel' class structure in app.py matches the model you trained.\nDetails: {e}")
114
 
115
  model.to(device)
116
  model.eval()
@@ -118,7 +481,7 @@ def load_model():
118
  return model
119
 
120
  # ---------------------------------------------------------
121
- # 4. JSON CONVERSION LOGIC (Your Custom Logic)
122
  # ---------------------------------------------------------
123
  def convert_bio_to_structured_json(predictions):
124
  structured_data = []
@@ -138,7 +501,6 @@ def convert_bio_to_structured_json(predictions):
138
  else: item['passage'] = passage_text
139
  passage_buffer.clear()
140
 
141
- # Flatten predictions list if strictly page-separated
142
  flat_predictions = []
143
  for page in predictions:
144
  flat_predictions.extend(page['data'])
@@ -146,9 +508,16 @@ def convert_bio_to_structured_json(predictions):
146
  for idx, item in enumerate(flat_predictions):
147
  word = item['word']
148
  label = item['predicted_label']
 
 
149
  entity_type = label[2:].strip() if label.startswith(('B-', 'I-')) else None
150
- current_text_buffer.append(word)
151
 
 
 
 
 
 
 
152
  previous_entity_type = last_entity_type
153
  is_passage_label = (entity_type == 'PASSAGE')
154
 
@@ -242,7 +611,6 @@ def convert_bio_to_structured_json(predictions):
242
  current_item['text'] = ' '.join(current_text_buffer).strip()
243
  structured_data.append(current_item)
244
 
245
- # Final Cleanup
246
  for item in structured_data:
247
  if 'text' in item: item['text'] = re.sub(r'\s{2,}', ' ', item['text']).strip()
248
  if 'new_passage' in item: item['new_passage'] = re.sub(r'\s{2,}', ' ', item['new_passage']).strip()
@@ -250,7 +618,7 @@ def convert_bio_to_structured_json(predictions):
250
  return structured_data
251
 
252
  # ---------------------------------------------------------
253
- # 5. INFERENCE PIPELINE
254
  # ---------------------------------------------------------
255
  def process_pdf(pdf_file):
256
  if pdf_file is None:
@@ -259,7 +627,6 @@ def process_pdf(pdf_file):
259
  try:
260
  active_model = load_model()
261
 
262
- # A. Extract Text and Boxes
263
  extracted_pages = []
264
  with pdfplumber.open(pdf_file.name) as pdf:
265
  for page_idx, page in enumerate(pdf.pages):
@@ -271,28 +638,22 @@ def process_pdf(pdf_file):
271
 
272
  for w in words_data:
273
  text = w['text']
274
- # Normalize bbox to 0-1000 scale
275
  x0 = int((w['x0'] / width) * 1000)
276
  top = int((w['top'] / height) * 1000)
277
  x1 = int((w['x1'] / width) * 1000)
278
  bottom = int((w['bottom'] / height) * 1000)
279
-
280
- # Safety clamp
281
  box = [max(0, min(x0, 1000)), max(0, min(top, 1000)),
282
  max(0, min(x1, 1000)), max(0, min(bottom, 1000))]
283
-
284
  page_tokens.append(text)
285
  page_bboxes.append(box)
286
  extracted_pages.append({"page_id": page_idx, "tokens": page_tokens, "bboxes": page_bboxes})
287
 
288
- # B. Run Inference
289
  raw_predictions = []
290
  for page in extracted_pages:
291
  tokens = page['tokens']
292
  bboxes = page['bboxes']
293
  if not tokens: continue
294
 
295
- # Tokenize
296
  encoding = tokenizer(
297
  tokens,
298
  boxes=bboxes,
@@ -307,18 +668,12 @@ def process_pdf(pdf_file):
307
  bbox = encoding.bbox.to(device)
308
  attention_mask = encoding.attention_mask.to(device)
309
 
310
- # Predict
311
  with torch.no_grad():
312
- # NOTE: If your hybrid model requires 'pixel_values',
313
- # you will need to add image extraction logic above and pass it here.
314
- preds = active_model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
315
-
316
- # Check if preds returns a tuple (loss, tags) or just tags
317
- # The CRF implementation usually returns a list of lists of tags in viterbi_decode
318
- pred_tags = preds[0] if isinstance(preds, tuple) else preds[0]
319
- # Note: Standard CRF.viterbi_decode returns List[List[int]], so [0] gets the first batch item
320
 
321
- # Alignment
322
  word_ids = encoding.word_ids()
323
  aligned_data = []
324
  prev_word_idx = None
@@ -326,20 +681,16 @@ def process_pdf(pdf_file):
326
  for i, word_idx in enumerate(word_ids):
327
  if word_idx is None: continue
328
  if word_idx != prev_word_idx:
329
- # pred_tags is likely a list of ints.
330
- # If pred_tags[i] fails, your max_length might be cutting off tags,
331
- # or the model output shape differs from the token length.
332
  if i < len(pred_tags):
333
  label_id = pred_tags[i]
 
334
  label_str = ID2LABEL.get(label_id, "O")
335
  aligned_data.append({"word": tokens[word_idx], "predicted_label": label_str})
336
  prev_word_idx = word_idx
337
  raw_predictions.append({"data": aligned_data})
338
 
339
- # C. Convert to Structured JSON
340
  final_json = convert_bio_to_structured_json(raw_predictions)
341
 
342
- # Save output
343
  output_filename = "structured_output.json"
344
  with open(output_filename, "w", encoding="utf-8") as f:
345
  json.dump(final_json, f, indent=2, ensure_ascii=False)
@@ -360,7 +711,7 @@ iface = gr.Interface(
360
  gr.File(label="Download JSON Output"),
361
  gr.Textbox(label="Status Log", lines=10)
362
  ],
363
- title="Hybrid Model Inference: PDF to JSON",
364
  description="Upload a document to extract structured data using the custom Hybrid LayoutLMv3 model.",
365
  flagging_mode="never"
366
  )
 
1
+ # import gradio as gr
2
+ # import torch
3
+ # import torch.nn as nn
4
+ # import pdfplumber
5
+ # import json
6
+ # import os
7
+ # import re
8
+ # from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Model
9
+ # from TorchCRF import CRF
10
+
11
+ # # ---------------------------------------------------------
12
+ # # 1. CONFIGURATION
13
+ # # ---------------------------------------------------------
14
+ # # Ensure this filename matches exactly what you uploaded to the Space
15
+ # MODEL_FILENAME = "layoutlmv3_bilstm_crf_hybrid.pth"
16
+ # BASE_MODEL_ID = "microsoft/layoutlmv3-base"
17
+
18
+ # # Define your labels exactly as they were during training
19
+ # LABELS = [
20
+ # "O",
21
+ # "B-QUESTION", "I-QUESTION",
22
+ # "B-OPTION", "I-OPTION",
23
+ # "B-ANSWER", "I-ANSWER",
24
+ # "B-SECTION_HEADING", "I-SECTION_HEADING",
25
+ # "B-PASSAGE", "I-PASSAGE"
26
+ # ]
27
+ # LABEL2ID = {l: i for i, l in enumerate(LABELS)}
28
+ # ID2LABEL = {i: l for l, i in LABEL2ID.items()}
29
+
30
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
+ # tokenizer = LayoutLMv3TokenizerFast.from_pretrained(BASE_MODEL_ID)
32
+
33
+ # # ---------------------------------------------------------
34
+ # # 2. MODEL ARCHITECTURE
35
+ # # ---------------------------------------------------------
36
+ # # ⚠️ ACTION REQUIRED:
37
+ # # Replace this class with the exact class definition of your
38
+ # # NEW HYBRID MODEL. The class name and structure must match
39
+ # # what was used when you saved 'layoutlmv3_nonlinear_scratch.pth'.
40
+ # # ---------------------------------------------------------
41
+ # # ---------------------------------------------------------
42
+ # # 2. MODEL ARCHITECTURE (LayoutLMv3 + BiLSTM + CRF)
43
+ # # ---------------------------------------------------------
44
+ # class HybridModel(nn.Module):
45
+ # def __init__(self, num_labels):
46
+ # super().__init__()
47
+ # self.layoutlm = LayoutLMv3Model.from_pretrained(BASE_MODEL_ID)
48
+
49
+ # # Config for BiLSTM
50
+ # hidden_size = self.layoutlm.config.hidden_size # Usually 768
51
+ # lstm_hidden_size = hidden_size // 2 # 384, so bidirectional output is 768
52
+
53
+ # # BiLSTM Layer
54
+ # # input_size=768, hidden=384, bidir=True -> output_dim = 384 * 2 = 768
55
+ # self.lstm = nn.LSTM(
56
+ # input_size=hidden_size,
57
+ # hidden_size=lstm_hidden_size,
58
+ # num_layers=1,
59
+ # batch_first=True,
60
+ # bidirectional=True
61
+ # )
62
+
63
+ # # Dropout (Optional, check if you used this in training)
64
+ # self.dropout = nn.Dropout(0.1)
65
+
66
+ # # Classifier: Maps BiLSTM output (768) to Label count
67
+ # self.classifier = nn.Linear(lstm_hidden_size * 2, num_labels)
68
+
69
+ # # CRF Layer
70
+ # self.crf = CRF(num_labels)
71
+
72
+ # def forward(self, input_ids, bbox, attention_mask, labels=None):
73
+ # # 1. LayoutLMv3 Base
74
+ # outputs = self.layoutlm(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
75
+ # sequence_output = outputs.last_hidden_state # [Batch, Seq, 768]
76
+
77
+ # # 2. BiLSTM
78
+ # # LSTM returns (output, (h_n, c_n)). We only need output.
79
+ # lstm_output, _ = self.lstm(sequence_output) # [Batch, Seq, 768]
80
+
81
+ # # 3. Dropout & Classifier
82
+ # lstm_output = self.dropout(lstm_output)
83
+ # emissions = self.classifier(lstm_output) # [Batch, Seq, Num_Labels]
84
+
85
+ # # 4. CRF
86
+ # if labels is not None:
87
+ # # Training/Eval (Loss)
88
+ # log_likelihood = self.crf(emissions, labels, mask=attention_mask.bool())
89
+ # return -log_likelihood.mean()
90
+ # else:
91
+ # # Inference (Prediction Tags)
92
+ # return self.crf.viterbi_decode(emissions, mask=attention_mask.bool())
93
+ # # ---------------------------------------------------------
94
+ # # 3. MODEL LOADING LOGIC
95
+ # # ---------------------------------------------------------
96
+ # model = None
97
+
98
+ # def load_model():
99
+ # global model
100
+ # if model is None:
101
+ # print(f"🔄 Loading model from {MODEL_FILENAME}...")
102
+ # if not os.path.exists(MODEL_FILENAME):
103
+ # raise FileNotFoundError(f"❌ Model file '{MODEL_FILENAME}' not found. Please upload it to the Files tab of your Space.")
104
+
105
+ # # Initialize the model structure
106
+ # model = HybridModel(num_labels=len(LABELS))
107
+
108
+ # # Load weights
109
+ # try:
110
+ # state_dict = torch.load(MODEL_FILENAME, map_location=device)
111
+ # model.load_state_dict(state_dict)
112
+ # except RuntimeError as e:
113
+ # raise RuntimeError(f"❌ State dictionary mismatch. Ensure the 'HybridModel' class structure in app.py matches the model you trained.\nDetails: {e}")
114
+
115
+ # model.to(device)
116
+ # model.eval()
117
+ # print("✅ Model loaded successfully.")
118
+ # return model
119
+
120
+ # # ---------------------------------------------------------
121
+ # # 4. JSON CONVERSION LOGIC (Your Custom Logic)
122
+ # # ---------------------------------------------------------
123
+ # def convert_bio_to_structured_json(predictions):
124
+ # structured_data = []
125
+ # current_item = None
126
+ # current_option_key = None
127
+ # current_passage_buffer = []
128
+ # current_text_buffer = []
129
+ # first_question_started = False
130
+ # last_entity_type = None
131
+ # just_finished_i_option = False
132
+ # is_in_new_passage = False
133
+
134
+ # def finalize_passage_to_item(item, passage_buffer):
135
+ # if passage_buffer:
136
+ # passage_text = re.sub(r'\s{2,}', ' ', ' '.join(passage_buffer)).strip()
137
+ # if item.get('passage'): item['passage'] += ' ' + passage_text
138
+ # else: item['passage'] = passage_text
139
+ # passage_buffer.clear()
140
+
141
+ # # Flatten predictions list if strictly page-separated
142
+ # flat_predictions = []
143
+ # for page in predictions:
144
+ # flat_predictions.extend(page['data'])
145
+
146
+ # for idx, item in enumerate(flat_predictions):
147
+ # word = item['word']
148
+ # label = item['predicted_label']
149
+ # entity_type = label[2:].strip() if label.startswith(('B-', 'I-')) else None
150
+ # current_text_buffer.append(word)
151
+
152
+ # previous_entity_type = last_entity_type
153
+ # is_passage_label = (entity_type == 'PASSAGE')
154
+
155
+ # if not first_question_started:
156
+ # if label != 'B-QUESTION' and not is_passage_label:
157
+ # just_finished_i_option = False
158
+ # is_in_new_passage = False
159
+ # continue
160
+ # if is_passage_label:
161
+ # current_passage_buffer.append(word)
162
+ # last_entity_type = 'PASSAGE'
163
+ # just_finished_i_option = False
164
+ # is_in_new_passage = False
165
+ # continue
166
+
167
+ # if label == 'B-QUESTION':
168
+ # if not first_question_started:
169
+ # header_text = ' '.join(current_text_buffer[:-1]).strip()
170
+ # if header_text or current_passage_buffer:
171
+ # metadata_item = {'type': 'METADATA', 'passage': ''}
172
+ # finalize_passage_to_item(metadata_item, current_passage_buffer)
173
+ # if header_text: metadata_item['text'] = header_text
174
+ # structured_data.append(metadata_item)
175
+ # first_question_started = True
176
+ # current_text_buffer = [word]
177
+
178
+ # if current_item is not None:
179
+ # finalize_passage_to_item(current_item, current_passage_buffer)
180
+ # current_item['text'] = ' '.join(current_text_buffer[:-1]).strip()
181
+ # structured_data.append(current_item)
182
+ # current_text_buffer = [word]
183
+
184
+ # current_item = {
185
+ # 'question': word, 'options': {}, 'answer': '', 'passage': '', 'text': ''
186
+ # }
187
+ # current_option_key = None
188
+ # last_entity_type = 'QUESTION'
189
+ # just_finished_i_option = False
190
+ # is_in_new_passage = False
191
+ # continue
192
+
193
+ # if current_item is not None:
194
+ # if is_in_new_passage:
195
+ # if 'new_passage' not in current_item: current_item['new_passage'] = word
196
+ # else: current_item['new_passage'] += f' {word}'
197
+ # if label.startswith('B-') or (label.startswith('I-') and entity_type != 'PASSAGE'):
198
+ # is_in_new_passage = False
199
+ # if label.startswith(('B-', 'I-')): last_entity_type = entity_type
200
+ # continue
201
+
202
+ # is_in_new_passage = False
203
+ # if label.startswith('B-'):
204
+ # if entity_type in ['QUESTION', 'OPTION', 'ANSWER', 'SECTION_HEADING']:
205
+ # finalize_passage_to_item(current_item, current_passage_buffer)
206
+ # current_passage_buffer = []
207
+ # last_entity_type = entity_type
208
+ # if entity_type == 'PASSAGE':
209
+ # if previous_entity_type == 'OPTION' and just_finished_i_option:
210
+ # current_item['new_passage'] = word
211
+ # is_in_new_passage = True
212
+ # else: current_passage_buffer.append(word)
213
+ # elif entity_type == 'OPTION':
214
+ # current_option_key = word
215
+ # current_item['options'][current_option_key] = word
216
+ # just_finished_i_option = False
217
+ # elif entity_type == 'ANSWER':
218
+ # current_item['answer'] = word
219
+ # current_option_key = None
220
+ # just_finished_i_option = False
221
+ # elif entity_type == 'QUESTION':
222
+ # current_item['question'] += f' {word}'
223
+ # just_finished_i_option = False
224
+
225
+ # elif label.startswith('I-'):
226
+ # if entity_type == 'QUESTION': current_item['question'] += f' {word}'
227
+ # elif entity_type == 'PASSAGE':
228
+ # if previous_entity_type == 'OPTION' and just_finished_i_option:
229
+ # current_item['new_passage'] = word
230
+ # is_in_new_passage = True
231
+ # else:
232
+ # if not current_passage_buffer: last_entity_type = 'PASSAGE'
233
+ # current_passage_buffer.append(word)
234
+ # elif entity_type == 'OPTION' and current_option_key is not None:
235
+ # current_item['options'][current_option_key] += f' {word}'
236
+ # just_finished_i_option = True
237
+ # elif entity_type == 'ANSWER': current_item['answer'] += f' {word}'
238
+ # just_finished_i_option = (entity_type == 'OPTION')
239
+
240
+ # if current_item is not None:
241
+ # finalize_passage_to_item(current_item, current_passage_buffer)
242
+ # current_item['text'] = ' '.join(current_text_buffer).strip()
243
+ # structured_data.append(current_item)
244
+
245
+ # # Final Cleanup
246
+ # for item in structured_data:
247
+ # if 'text' in item: item['text'] = re.sub(r'\s{2,}', ' ', item['text']).strip()
248
+ # if 'new_passage' in item: item['new_passage'] = re.sub(r'\s{2,}', ' ', item['new_passage']).strip()
249
+
250
+ # return structured_data
251
+
252
+ # # ---------------------------------------------------------
253
+ # # 5. INFERENCE PIPELINE
254
+ # # ---------------------------------------------------------
255
+ # def process_pdf(pdf_file):
256
+ # if pdf_file is None:
257
+ # return None, "⚠️ Please upload a PDF file."
258
+
259
+ # try:
260
+ # active_model = load_model()
261
+
262
+ # # A. Extract Text and Boxes
263
+ # extracted_pages = []
264
+ # with pdfplumber.open(pdf_file.name) as pdf:
265
+ # for page_idx, page in enumerate(pdf.pages):
266
+ # width, height = page.width, page.height
267
+ # words_data = page.extract_words()
268
+
269
+ # page_tokens = []
270
+ # page_bboxes = []
271
+
272
+ # for w in words_data:
273
+ # text = w['text']
274
+ # # Normalize bbox to 0-1000 scale
275
+ # x0 = int((w['x0'] / width) * 1000)
276
+ # top = int((w['top'] / height) * 1000)
277
+ # x1 = int((w['x1'] / width) * 1000)
278
+ # bottom = int((w['bottom'] / height) * 1000)
279
+
280
+ # # Safety clamp
281
+ # box = [max(0, min(x0, 1000)), max(0, min(top, 1000)),
282
+ # max(0, min(x1, 1000)), max(0, min(bottom, 1000))]
283
+
284
+ # page_tokens.append(text)
285
+ # page_bboxes.append(box)
286
+ # extracted_pages.append({"page_id": page_idx, "tokens": page_tokens, "bboxes": page_bboxes})
287
+
288
+ # # B. Run Inference
289
+ # raw_predictions = []
290
+ # for page in extracted_pages:
291
+ # tokens = page['tokens']
292
+ # bboxes = page['bboxes']
293
+ # if not tokens: continue
294
+
295
+ # # Tokenize
296
+ # encoding = tokenizer(
297
+ # tokens,
298
+ # boxes=bboxes,
299
+ # return_tensors="pt",
300
+ # padding="max_length",
301
+ # truncation=True,
302
+ # max_length=512,
303
+ # return_offsets_mapping=True
304
+ # )
305
+
306
+ # input_ids = encoding.input_ids.to(device)
307
+ # bbox = encoding.bbox.to(device)
308
+ # attention_mask = encoding.attention_mask.to(device)
309
+
310
+ # # Predict
311
+ # with torch.no_grad():
312
+ # # NOTE: If your hybrid model requires 'pixel_values',
313
+ # # you will need to add image extraction logic above and pass it here.
314
+ # preds = active_model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
315
+
316
+ # # Check if preds returns a tuple (loss, tags) or just tags
317
+ # # The CRF implementation usually returns a list of lists of tags in viterbi_decode
318
+ # pred_tags = preds[0] if isinstance(preds, tuple) else preds[0]
319
+ # # Note: Standard CRF.viterbi_decode returns List[List[int]], so [0] gets the first batch item
320
+
321
+ # # Alignment
322
+ # word_ids = encoding.word_ids()
323
+ # aligned_data = []
324
+ # prev_word_idx = None
325
+
326
+ # for i, word_idx in enumerate(word_ids):
327
+ # if word_idx is None: continue
328
+ # if word_idx != prev_word_idx:
329
+ # # pred_tags is likely a list of ints.
330
+ # # If pred_tags[i] fails, your max_length might be cutting off tags,
331
+ # # or the model output shape differs from the token length.
332
+ # if i < len(pred_tags):
333
+ # label_id = pred_tags[i]
334
+ # label_str = ID2LABEL.get(label_id, "O")
335
+ # aligned_data.append({"word": tokens[word_idx], "predicted_label": label_str})
336
+ # prev_word_idx = word_idx
337
+ # raw_predictions.append({"data": aligned_data})
338
+
339
+ # # C. Convert to Structured JSON
340
+ # final_json = convert_bio_to_structured_json(raw_predictions)
341
+
342
+ # # Save output
343
+ # output_filename = "structured_output.json"
344
+ # with open(output_filename, "w", encoding="utf-8") as f:
345
+ # json.dump(final_json, f, indent=2, ensure_ascii=False)
346
+
347
+ # return output_filename, f"✅ Success! Processed {len(extracted_pages)} pages. Extracted {len(final_json)} items."
348
+
349
+ # except Exception as e:
350
+ # import traceback
351
+ # return None, f"❌ Error:\n{str(e)}\n\nTraceback:\n{traceback.format_exc()}"
352
+
353
+ # # ---------------------------------------------------------
354
+ # # 6. GRADIO INTERFACE
355
+ # # ---------------------------------------------------------
356
+ # iface = gr.Interface(
357
+ # fn=process_pdf,
358
+ # inputs=gr.File(label="Upload PDF", file_types=[".pdf"]),
359
+ # outputs=[
360
+ # gr.File(label="Download JSON Output"),
361
+ # gr.Textbox(label="Status Log", lines=10)
362
+ # ],
363
+ # title="Hybrid Model Inference: PDF to JSON",
364
+ # description="Upload a document to extract structured data using the custom Hybrid LayoutLMv3 model.",
365
+ # flagging_mode="never"
366
+ # )
367
+
368
+ # if __name__ == "__main__":
369
+ # iface.launch()
370
+
371
+
372
+
373
  import gradio as gr
374
  import torch
375
  import torch.nn as nn
 
383
  # ---------------------------------------------------------
384
  # 1. CONFIGURATION
385
  # ---------------------------------------------------------
 
386
  MODEL_FILENAME = "layoutlmv3_bilstm_crf_hybrid.pth"
387
  BASE_MODEL_ID = "microsoft/layoutlmv3-base"
388
 
389
+ # Labels: 11 Standard BIO tags + 2 Special tokens = 13 Total
390
+ # NOTE: If your output labels look "scrambled" (e.g., Questions detected as Options),
391
+ # try moving "UNK" and "PAD" to the BEGINNING of this list (indices 0 and 1).
392
  LABELS = [
393
  "O",
394
  "B-QUESTION", "I-QUESTION",
395
  "B-OPTION", "I-OPTION",
396
  "B-ANSWER", "I-ANSWER",
397
  "B-SECTION_HEADING", "I-SECTION_HEADING",
398
+ "B-PASSAGE", "I-PASSAGE",
399
+ "UNK", "PAD" # Added to match the 13-label count in your weights
400
  ]
401
+
402
  LABEL2ID = {l: i for i, l in enumerate(LABELS)}
403
  ID2LABEL = {i: l for l, i in LABEL2ID.items()}
404
 
405
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
406
  tokenizer = LayoutLMv3TokenizerFast.from_pretrained(BASE_MODEL_ID)
407
 
 
 
 
 
 
 
 
 
408
  # ---------------------------------------------------------
409
  # 2. MODEL ARCHITECTURE (LayoutLMv3 + BiLSTM + CRF)
410
  # ---------------------------------------------------------
 
413
  super().__init__()
414
  self.layoutlm = LayoutLMv3Model.from_pretrained(BASE_MODEL_ID)
415
 
416
+ # Structure derived from your error log:
417
+ # Weight shape [1024, 768] implies hidden_size = 256 (1024/4)
418
+ lstm_hidden_size = 256
419
 
 
 
420
  self.lstm = nn.LSTM(
421
+ input_size=768, # LayoutLMv3 output size
422
+ hidden_size=lstm_hidden_size,
423
+ num_layers=2, # Error log showed 'l1' weights, meaning 2 layers
424
  batch_first=True,
425
  bidirectional=True
426
  )
427
 
 
428
  self.dropout = nn.Dropout(0.1)
429
 
430
+ # Classifier input = lstm_hidden * 2 (bidirectional) = 256 * 2 = 512
431
+ # This matches your error log shape [13, 512]
432
  self.classifier = nn.Linear(lstm_hidden_size * 2, num_labels)
433
 
 
434
  self.crf = CRF(num_labels)
435
 
436
  def forward(self, input_ids, bbox, attention_mask, labels=None):
 
437
  outputs = self.layoutlm(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
438
+ sequence_output = outputs.last_hidden_state
439
 
440
+ # BiLSTM
441
+ lstm_output, _ = self.lstm(sequence_output)
 
442
 
443
+ # Classifier
444
  lstm_output = self.dropout(lstm_output)
445
+ emissions = self.classifier(lstm_output)
446
 
 
447
  if labels is not None:
448
+ # Training/Eval loss
449
  log_likelihood = self.crf(emissions, labels, mask=attention_mask.bool())
450
  return -log_likelihood.mean()
451
  else:
452
+ # Inference prediction
453
  return self.crf.viterbi_decode(emissions, mask=attention_mask.bool())
454
+
455
  # ---------------------------------------------------------
456
+ # 3. MODEL LOADING
457
  # ---------------------------------------------------------
458
  model = None
459
 
 
462
  if model is None:
463
  print(f"🔄 Loading model from {MODEL_FILENAME}...")
464
  if not os.path.exists(MODEL_FILENAME):
465
+ raise FileNotFoundError(f"❌ Model file '{MODEL_FILENAME}' not found.")
466
 
 
467
  model = HybridModel(num_labels=len(LABELS))
468
 
469
+ # Load state dictionary
470
+ state_dict = torch.load(MODEL_FILENAME, map_location=device)
471
+
472
+ # Try loading. If labels are wrong, this will still throw a shape error.
473
  try:
 
474
  model.load_state_dict(state_dict)
475
  except RuntimeError as e:
476
+ raise RuntimeError(f"❌ Weight mismatch! \nYour model has {len(LABELS)} labels defined in script.\nCheck if 'LABELS' list needs reordering or resizing.\nDetailed Error: {e}")
477
 
478
  model.to(device)
479
  model.eval()
 
481
  return model
482
 
483
  # ---------------------------------------------------------
484
+ # 4. JSON CONVERSION LOGIC
485
  # ---------------------------------------------------------
486
  def convert_bio_to_structured_json(predictions):
487
  structured_data = []
 
501
  else: item['passage'] = passage_text
502
  passage_buffer.clear()
503
 
 
504
  flat_predictions = []
505
  for page in predictions:
506
  flat_predictions.extend(page['data'])
 
508
  for idx, item in enumerate(flat_predictions):
509
  word = item['word']
510
  label = item['predicted_label']
511
+
512
+ # Clean label (remove B- / I-)
513
  entity_type = label[2:].strip() if label.startswith(('B-', 'I-')) else None
 
514
 
515
+ # Skip special tokens if they appear in prediction
516
+ if label in ["UNK", "PAD", "O"]:
517
+ current_text_buffer.append(word)
518
+ continue
519
+
520
+ current_text_buffer.append(word)
521
  previous_entity_type = last_entity_type
522
  is_passage_label = (entity_type == 'PASSAGE')
523
 
 
611
  current_item['text'] = ' '.join(current_text_buffer).strip()
612
  structured_data.append(current_item)
613
 
 
614
  for item in structured_data:
615
  if 'text' in item: item['text'] = re.sub(r'\s{2,}', ' ', item['text']).strip()
616
  if 'new_passage' in item: item['new_passage'] = re.sub(r'\s{2,}', ' ', item['new_passage']).strip()
 
618
  return structured_data
619
 
620
  # ---------------------------------------------------------
621
+ # 5. PROCESSING PIPELINE
622
  # ---------------------------------------------------------
623
  def process_pdf(pdf_file):
624
  if pdf_file is None:
 
627
  try:
628
  active_model = load_model()
629
 
 
630
  extracted_pages = []
631
  with pdfplumber.open(pdf_file.name) as pdf:
632
  for page_idx, page in enumerate(pdf.pages):
 
638
 
639
  for w in words_data:
640
  text = w['text']
 
641
  x0 = int((w['x0'] / width) * 1000)
642
  top = int((w['top'] / height) * 1000)
643
  x1 = int((w['x1'] / width) * 1000)
644
  bottom = int((w['bottom'] / height) * 1000)
 
 
645
  box = [max(0, min(x0, 1000)), max(0, min(top, 1000)),
646
  max(0, min(x1, 1000)), max(0, min(bottom, 1000))]
 
647
  page_tokens.append(text)
648
  page_bboxes.append(box)
649
  extracted_pages.append({"page_id": page_idx, "tokens": page_tokens, "bboxes": page_bboxes})
650
 
 
651
  raw_predictions = []
652
  for page in extracted_pages:
653
  tokens = page['tokens']
654
  bboxes = page['bboxes']
655
  if not tokens: continue
656
 
 
657
  encoding = tokenizer(
658
  tokens,
659
  boxes=bboxes,
 
668
  bbox = encoding.bbox.to(device)
669
  attention_mask = encoding.attention_mask.to(device)
670
 
 
671
  with torch.no_grad():
672
+ # Get the tag indices from the CRF layer
673
+ pred_tags = active_model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
674
+ # If batch size is 1, pred_tags is a list of lists: [[tags...]]
675
+ pred_tags = pred_tags[0]
 
 
 
 
676
 
 
677
  word_ids = encoding.word_ids()
678
  aligned_data = []
679
  prev_word_idx = None
 
681
  for i, word_idx in enumerate(word_ids):
682
  if word_idx is None: continue
683
  if word_idx != prev_word_idx:
 
 
 
684
  if i < len(pred_tags):
685
  label_id = pred_tags[i]
686
+ # Safe retrieval of label string
687
  label_str = ID2LABEL.get(label_id, "O")
688
  aligned_data.append({"word": tokens[word_idx], "predicted_label": label_str})
689
  prev_word_idx = word_idx
690
  raw_predictions.append({"data": aligned_data})
691
 
 
692
  final_json = convert_bio_to_structured_json(raw_predictions)
693
 
 
694
  output_filename = "structured_output.json"
695
  with open(output_filename, "w", encoding="utf-8") as f:
696
  json.dump(final_json, f, indent=2, ensure_ascii=False)
 
711
  gr.File(label="Download JSON Output"),
712
  gr.Textbox(label="Status Log", lines=10)
713
  ],
714
+ title="LayoutLMv3 + BiLSTM Hybrid Model Inference",
715
  description="Upload a document to extract structured data using the custom Hybrid LayoutLMv3 model.",
716
  flagging_mode="never"
717
  )