heerjtdev commited on
Commit
0133c60
·
verified ·
1 Parent(s): 85859d3

Rename train.py to app.py

Browse files
Files changed (2) hide show
  1. app.py +369 -0
  2. train.py +0 -244
app.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import pdfplumber
5
+ import json
6
+ import os
7
+ import re
8
+ from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Model
9
+ from TorchCRF import CRF
10
+
11
+ # ---------------------------------------------------------
12
+ # 1. CONFIGURATION
13
+ # ---------------------------------------------------------
14
+ # Ensure this filename matches exactly what you uploaded to the Space
15
+ MODEL_FILENAME = "layoutlmv3_bilstm_crf_hybrid.pth"
16
+ BASE_MODEL_ID = "microsoft/layoutlmv3-base"
17
+
18
+ # Define your labels exactly as they were during training
19
+ LABELS = [
20
+ "O",
21
+ "B-QUESTION", "I-QUESTION",
22
+ "B-OPTION", "I-OPTION",
23
+ "B-ANSWER", "I-ANSWER",
24
+ "B-SECTION_HEADING", "I-SECTION_HEADING",
25
+ "B-PASSAGE", "I-PASSAGE"
26
+ ]
27
+ LABEL2ID = {l: i for i, l in enumerate(LABELS)}
28
+ ID2LABEL = {i: l for l, i in LABEL2ID.items()}
29
+
30
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
+ tokenizer = LayoutLMv3TokenizerFast.from_pretrained(BASE_MODEL_ID)
32
+
33
+ # ---------------------------------------------------------
34
+ # 2. MODEL ARCHITECTURE
35
+ # ---------------------------------------------------------
36
+ # ⚠️ ACTION REQUIRED:
37
+ # Replace this class with the exact class definition of your
38
+ # NEW HYBRID MODEL. The class name and structure must match
39
+ # what was used when you saved 'layoutlmv3_nonlinear_scratch.pth'.
40
+ # ---------------------------------------------------------
41
+ # ---------------------------------------------------------
42
+ # 2. MODEL ARCHITECTURE (LayoutLMv3 + BiLSTM + CRF)
43
+ # ---------------------------------------------------------
44
+ class HybridModel(nn.Module):
45
+ def __init__(self, num_labels):
46
+ super().__init__()
47
+ self.layoutlm = LayoutLMv3Model.from_pretrained(BASE_MODEL_ID)
48
+
49
+ # Config for BiLSTM
50
+ hidden_size = self.layoutlm.config.hidden_size # Usually 768
51
+ lstm_hidden_size = hidden_size // 2 # 384, so bidirectional output is 768
52
+
53
+ # BiLSTM Layer
54
+ # input_size=768, hidden=384, bidir=True -> output_dim = 384 * 2 = 768
55
+ self.lstm = nn.LSTM(
56
+ input_size=hidden_size,
57
+ hidden_size=lstm_hidden_size,
58
+ num_layers=1,
59
+ batch_first=True,
60
+ bidirectional=True
61
+ )
62
+
63
+ # Dropout (Optional, check if you used this in training)
64
+ self.dropout = nn.Dropout(0.1)
65
+
66
+ # Classifier: Maps BiLSTM output (768) to Label count
67
+ self.classifier = nn.Linear(lstm_hidden_size * 2, num_labels)
68
+
69
+ # CRF Layer
70
+ self.crf = CRF(num_labels)
71
+
72
+ def forward(self, input_ids, bbox, attention_mask, labels=None):
73
+ # 1. LayoutLMv3 Base
74
+ outputs = self.layoutlm(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
75
+ sequence_output = outputs.last_hidden_state # [Batch, Seq, 768]
76
+
77
+ # 2. BiLSTM
78
+ # LSTM returns (output, (h_n, c_n)). We only need output.
79
+ lstm_output, _ = self.lstm(sequence_output) # [Batch, Seq, 768]
80
+
81
+ # 3. Dropout & Classifier
82
+ lstm_output = self.dropout(lstm_output)
83
+ emissions = self.classifier(lstm_output) # [Batch, Seq, Num_Labels]
84
+
85
+ # 4. CRF
86
+ if labels is not None:
87
+ # Training/Eval (Loss)
88
+ log_likelihood = self.crf(emissions, labels, mask=attention_mask.bool())
89
+ return -log_likelihood.mean()
90
+ else:
91
+ # Inference (Prediction Tags)
92
+ return self.crf.viterbi_decode(emissions, mask=attention_mask.bool())
93
+ # ---------------------------------------------------------
94
+ # 3. MODEL LOADING LOGIC
95
+ # ---------------------------------------------------------
96
+ model = None
97
+
98
+ def load_model():
99
+ global model
100
+ if model is None:
101
+ print(f"🔄 Loading model from {MODEL_FILENAME}...")
102
+ if not os.path.exists(MODEL_FILENAME):
103
+ raise FileNotFoundError(f"❌ Model file '{MODEL_FILENAME}' not found. Please upload it to the Files tab of your Space.")
104
+
105
+ # Initialize the model structure
106
+ model = HybridModel(num_labels=len(LABELS))
107
+
108
+ # Load weights
109
+ try:
110
+ state_dict = torch.load(MODEL_FILENAME, map_location=device)
111
+ model.load_state_dict(state_dict)
112
+ except RuntimeError as e:
113
+ raise RuntimeError(f"❌ State dictionary mismatch. Ensure the 'HybridModel' class structure in app.py matches the model you trained.\nDetails: {e}")
114
+
115
+ model.to(device)
116
+ model.eval()
117
+ print("✅ Model loaded successfully.")
118
+ return model
119
+
120
+ # ---------------------------------------------------------
121
+ # 4. JSON CONVERSION LOGIC (Your Custom Logic)
122
+ # ---------------------------------------------------------
123
+ def convert_bio_to_structured_json(predictions):
124
+ structured_data = []
125
+ current_item = None
126
+ current_option_key = None
127
+ current_passage_buffer = []
128
+ current_text_buffer = []
129
+ first_question_started = False
130
+ last_entity_type = None
131
+ just_finished_i_option = False
132
+ is_in_new_passage = False
133
+
134
+ def finalize_passage_to_item(item, passage_buffer):
135
+ if passage_buffer:
136
+ passage_text = re.sub(r'\s{2,}', ' ', ' '.join(passage_buffer)).strip()
137
+ if item.get('passage'): item['passage'] += ' ' + passage_text
138
+ else: item['passage'] = passage_text
139
+ passage_buffer.clear()
140
+
141
+ # Flatten predictions list if strictly page-separated
142
+ flat_predictions = []
143
+ for page in predictions:
144
+ flat_predictions.extend(page['data'])
145
+
146
+ for idx, item in enumerate(flat_predictions):
147
+ word = item['word']
148
+ label = item['predicted_label']
149
+ entity_type = label[2:].strip() if label.startswith(('B-', 'I-')) else None
150
+ current_text_buffer.append(word)
151
+
152
+ previous_entity_type = last_entity_type
153
+ is_passage_label = (entity_type == 'PASSAGE')
154
+
155
+ if not first_question_started:
156
+ if label != 'B-QUESTION' and not is_passage_label:
157
+ just_finished_i_option = False
158
+ is_in_new_passage = False
159
+ continue
160
+ if is_passage_label:
161
+ current_passage_buffer.append(word)
162
+ last_entity_type = 'PASSAGE'
163
+ just_finished_i_option = False
164
+ is_in_new_passage = False
165
+ continue
166
+
167
+ if label == 'B-QUESTION':
168
+ if not first_question_started:
169
+ header_text = ' '.join(current_text_buffer[:-1]).strip()
170
+ if header_text or current_passage_buffer:
171
+ metadata_item = {'type': 'METADATA', 'passage': ''}
172
+ finalize_passage_to_item(metadata_item, current_passage_buffer)
173
+ if header_text: metadata_item['text'] = header_text
174
+ structured_data.append(metadata_item)
175
+ first_question_started = True
176
+ current_text_buffer = [word]
177
+
178
+ if current_item is not None:
179
+ finalize_passage_to_item(current_item, current_passage_buffer)
180
+ current_item['text'] = ' '.join(current_text_buffer[:-1]).strip()
181
+ structured_data.append(current_item)
182
+ current_text_buffer = [word]
183
+
184
+ current_item = {
185
+ 'question': word, 'options': {}, 'answer': '', 'passage': '', 'text': ''
186
+ }
187
+ current_option_key = None
188
+ last_entity_type = 'QUESTION'
189
+ just_finished_i_option = False
190
+ is_in_new_passage = False
191
+ continue
192
+
193
+ if current_item is not None:
194
+ if is_in_new_passage:
195
+ if 'new_passage' not in current_item: current_item['new_passage'] = word
196
+ else: current_item['new_passage'] += f' {word}'
197
+ if label.startswith('B-') or (label.startswith('I-') and entity_type != 'PASSAGE'):
198
+ is_in_new_passage = False
199
+ if label.startswith(('B-', 'I-')): last_entity_type = entity_type
200
+ continue
201
+
202
+ is_in_new_passage = False
203
+ if label.startswith('B-'):
204
+ if entity_type in ['QUESTION', 'OPTION', 'ANSWER', 'SECTION_HEADING']:
205
+ finalize_passage_to_item(current_item, current_passage_buffer)
206
+ current_passage_buffer = []
207
+ last_entity_type = entity_type
208
+ if entity_type == 'PASSAGE':
209
+ if previous_entity_type == 'OPTION' and just_finished_i_option:
210
+ current_item['new_passage'] = word
211
+ is_in_new_passage = True
212
+ else: current_passage_buffer.append(word)
213
+ elif entity_type == 'OPTION':
214
+ current_option_key = word
215
+ current_item['options'][current_option_key] = word
216
+ just_finished_i_option = False
217
+ elif entity_type == 'ANSWER':
218
+ current_item['answer'] = word
219
+ current_option_key = None
220
+ just_finished_i_option = False
221
+ elif entity_type == 'QUESTION':
222
+ current_item['question'] += f' {word}'
223
+ just_finished_i_option = False
224
+
225
+ elif label.startswith('I-'):
226
+ if entity_type == 'QUESTION': current_item['question'] += f' {word}'
227
+ elif entity_type == 'PASSAGE':
228
+ if previous_entity_type == 'OPTION' and just_finished_i_option:
229
+ current_item['new_passage'] = word
230
+ is_in_new_passage = True
231
+ else:
232
+ if not current_passage_buffer: last_entity_type = 'PASSAGE'
233
+ current_passage_buffer.append(word)
234
+ elif entity_type == 'OPTION' and current_option_key is not None:
235
+ current_item['options'][current_option_key] += f' {word}'
236
+ just_finished_i_option = True
237
+ elif entity_type == 'ANSWER': current_item['answer'] += f' {word}'
238
+ just_finished_i_option = (entity_type == 'OPTION')
239
+
240
+ if current_item is not None:
241
+ finalize_passage_to_item(current_item, current_passage_buffer)
242
+ current_item['text'] = ' '.join(current_text_buffer).strip()
243
+ structured_data.append(current_item)
244
+
245
+ # Final Cleanup
246
+ for item in structured_data:
247
+ if 'text' in item: item['text'] = re.sub(r'\s{2,}', ' ', item['text']).strip()
248
+ if 'new_passage' in item: item['new_passage'] = re.sub(r'\s{2,}', ' ', item['new_passage']).strip()
249
+
250
+ return structured_data
251
+
252
+ # ---------------------------------------------------------
253
+ # 5. INFERENCE PIPELINE
254
+ # ---------------------------------------------------------
255
+ def process_pdf(pdf_file):
256
+ if pdf_file is None:
257
+ return None, "⚠️ Please upload a PDF file."
258
+
259
+ try:
260
+ active_model = load_model()
261
+
262
+ # A. Extract Text and Boxes
263
+ extracted_pages = []
264
+ with pdfplumber.open(pdf_file.name) as pdf:
265
+ for page_idx, page in enumerate(pdf.pages):
266
+ width, height = page.width, page.height
267
+ words_data = page.extract_words()
268
+
269
+ page_tokens = []
270
+ page_bboxes = []
271
+
272
+ for w in words_data:
273
+ text = w['text']
274
+ # Normalize bbox to 0-1000 scale
275
+ x0 = int((w['x0'] / width) * 1000)
276
+ top = int((w['top'] / height) * 1000)
277
+ x1 = int((w['x1'] / width) * 1000)
278
+ bottom = int((w['bottom'] / height) * 1000)
279
+
280
+ # Safety clamp
281
+ box = [max(0, min(x0, 1000)), max(0, min(top, 1000)),
282
+ max(0, min(x1, 1000)), max(0, min(bottom, 1000))]
283
+
284
+ page_tokens.append(text)
285
+ page_bboxes.append(box)
286
+ extracted_pages.append({"page_id": page_idx, "tokens": page_tokens, "bboxes": page_bboxes})
287
+
288
+ # B. Run Inference
289
+ raw_predictions = []
290
+ for page in extracted_pages:
291
+ tokens = page['tokens']
292
+ bboxes = page['bboxes']
293
+ if not tokens: continue
294
+
295
+ # Tokenize
296
+ encoding = tokenizer(
297
+ tokens,
298
+ boxes=bboxes,
299
+ return_tensors="pt",
300
+ padding="max_length",
301
+ truncation=True,
302
+ max_length=512,
303
+ return_offsets_mapping=True
304
+ )
305
+
306
+ input_ids = encoding.input_ids.to(device)
307
+ bbox = encoding.bbox.to(device)
308
+ attention_mask = encoding.attention_mask.to(device)
309
+
310
+ # Predict
311
+ with torch.no_grad():
312
+ # NOTE: If your hybrid model requires 'pixel_values',
313
+ # you will need to add image extraction logic above and pass it here.
314
+ preds = active_model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
315
+
316
+ # Check if preds returns a tuple (loss, tags) or just tags
317
+ # The CRF implementation usually returns a list of lists of tags in viterbi_decode
318
+ pred_tags = preds[0] if isinstance(preds, tuple) else preds[0]
319
+ # Note: Standard CRF.viterbi_decode returns List[List[int]], so [0] gets the first batch item
320
+
321
+ # Alignment
322
+ word_ids = encoding.word_ids()
323
+ aligned_data = []
324
+ prev_word_idx = None
325
+
326
+ for i, word_idx in enumerate(word_ids):
327
+ if word_idx is None: continue
328
+ if word_idx != prev_word_idx:
329
+ # pred_tags is likely a list of ints.
330
+ # If pred_tags[i] fails, your max_length might be cutting off tags,
331
+ # or the model output shape differs from the token length.
332
+ if i < len(pred_tags):
333
+ label_id = pred_tags[i]
334
+ label_str = ID2LABEL.get(label_id, "O")
335
+ aligned_data.append({"word": tokens[word_idx], "predicted_label": label_str})
336
+ prev_word_idx = word_idx
337
+ raw_predictions.append({"data": aligned_data})
338
+
339
+ # C. Convert to Structured JSON
340
+ final_json = convert_bio_to_structured_json(raw_predictions)
341
+
342
+ # Save output
343
+ output_filename = "structured_output.json"
344
+ with open(output_filename, "w", encoding="utf-8") as f:
345
+ json.dump(final_json, f, indent=2, ensure_ascii=False)
346
+
347
+ return output_filename, f"✅ Success! Processed {len(extracted_pages)} pages. Extracted {len(final_json)} items."
348
+
349
+ except Exception as e:
350
+ import traceback
351
+ return None, f"❌ Error:\n{str(e)}\n\nTraceback:\n{traceback.format_exc()}"
352
+
353
+ # ---------------------------------------------------------
354
+ # 6. GRADIO INTERFACE
355
+ # ---------------------------------------------------------
356
+ iface = gr.Interface(
357
+ fn=process_pdf,
358
+ inputs=gr.File(label="Upload PDF", file_types=[".pdf"]),
359
+ outputs=[
360
+ gr.File(label="Download JSON Output"),
361
+ gr.Textbox(label="Status Log", lines=10)
362
+ ],
363
+ title="Hybrid Model Inference: PDF to JSON",
364
+ description="Upload a document to extract structured data using the custom Hybrid LayoutLMv3 model.",
365
+ flagging_mode="never"
366
+ )
367
+
368
+ if __name__ == "__main__":
369
+ iface.launch()
train.py DELETED
@@ -1,244 +0,0 @@
1
- import json
2
- import argparse
3
- import os
4
- import random
5
- import torch
6
- import torch.nn as nn
7
- from torch.utils.data import Dataset, DataLoader, random_split
8
- from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Model
9
- from TorchCRF import CRF
10
- from torch.optim import AdamW
11
- from tqdm import tqdm
12
- from sklearn.metrics import precision_recall_fscore_support
13
-
14
- # --- Configuration ---
15
- MAX_BBOX_DIMENSION = 1000
16
- MAX_SHIFT = 30
17
- AUGMENTATION_FACTOR = 1
18
- BASE_MODEL_ID = "heerjtdev/MLP_LayoutLM"
19
-
20
- # -------------------------
21
- # Step 1: Preprocessing
22
- # -------------------------
23
- def preprocess_labelstudio(input_path, output_path):
24
- with open(input_path, "r", encoding="utf-8") as f:
25
- data = json.load(f)
26
-
27
- processed = []
28
- print(f"🔄 Starting preprocessing of {len(data)} documents...")
29
-
30
- for item in data:
31
- words = item["data"]["original_words"]
32
- bboxes = item["data"]["original_bboxes"]
33
- labels = ["O"] * len(words)
34
-
35
- clamped_bboxes = []
36
- for bbox in bboxes:
37
- x_min, y_min, x_max, y_max = bbox
38
- new_x_min = max(0, min(x_min, 1000))
39
- new_y_min = max(0, min(y_min, 1000))
40
- new_x_max = max(0, min(x_max, 1000))
41
- new_y_max = max(0, min(y_max, 1000))
42
- if new_x_min > new_x_max: new_x_min = new_x_max
43
- if new_y_min > new_y_max: new_y_min = new_y_max
44
- clamped_bboxes.append([new_x_min, new_y_min, new_x_max, new_y_max])
45
-
46
- if "annotations" in item:
47
- for ann in item["annotations"]:
48
- for res in ann["result"]:
49
- if "value" in res and "labels" in res["value"]:
50
- text = res["value"]["text"]
51
- tag = res["value"]["labels"][0]
52
- text_tokens = text.split()
53
- for i in range(len(words) - len(text_tokens) + 1):
54
- if words[i:i + len(text_tokens)] == text_tokens:
55
- labels[i] = f"B-{tag}"
56
- for j in range(1, len(text_tokens)):
57
- labels[i + j] = f"I-{tag}"
58
- break
59
-
60
- processed.append({"tokens": words, "labels": labels, "bboxes": clamped_bboxes})
61
-
62
- with open(output_path, "w", encoding="utf-8") as f:
63
- json.dump(processed, f, indent=2, ensure_ascii=False)
64
- return output_path
65
-
66
- # -------------------------
67
- # Step 1.5: Augmentation
68
- # -------------------------
69
- def translate_bbox(bbox, shift_x, shift_y):
70
- x_min, y_min, x_max, y_max = bbox
71
- new_x_min = max(0, min(x_min + shift_x, 1000))
72
- new_y_min = max(0, min(y_min + shift_y, 1000))
73
- new_x_max = max(0, min(x_max + shift_x, 1000))
74
- new_y_max = max(0, min(y_max + shift_y, 1000))
75
- return [new_x_min, new_y_min, new_x_max, new_y_max]
76
-
77
- def augment_sample(sample):
78
- shift_x = random.randint(-MAX_SHIFT, MAX_SHIFT)
79
- shift_y = random.randint(-MAX_SHIFT, MAX_SHIFT)
80
- new_sample = sample.copy()
81
- new_sample["bboxes"] = [translate_bbox(b, shift_x, shift_y) for b in sample["bboxes"]]
82
- return new_sample
83
-
84
- def augment_and_save_dataset(input_json_path, output_json_path):
85
- with open(input_json_path, 'r', encoding="utf-8") as f:
86
- training_data = json.load(f)
87
- augmented_data = []
88
- for original_sample in training_data:
89
- augmented_data.append(original_sample)
90
- for _ in range(AUGMENTATION_FACTOR):
91
- augmented_data.append(augment_sample(original_sample))
92
- with open(output_json_path, 'w', encoding="utf-8") as f:
93
- json.dump(augmented_data, f, indent=2, ensure_ascii=False)
94
- return output_json_path
95
-
96
- # -------------------------
97
- # Step 2: Dataset Class
98
- # -------------------------
99
- class LayoutDataset(Dataset):
100
- def __init__(self, json_path, tokenizer, label2id, max_len=512):
101
- with open(json_path, "r", encoding="utf-8") as f:
102
- self.data = json.load(f)
103
- self.tokenizer = tokenizer
104
- self.label2id = label2id
105
- self.max_len = max_len
106
-
107
- def __len__(self):
108
- return len(self.data)
109
-
110
- def __getitem__(self, idx):
111
- item = self.data[idx]
112
- words, bboxes, labels = item["tokens"], item["bboxes"], item["labels"]
113
- encodings = self.tokenizer(words, boxes=bboxes, padding="max_length", truncation=True, max_length=self.max_len, return_tensors="pt")
114
- word_ids = encodings.word_ids(batch_index=0)
115
- label_ids = []
116
- for word_id in word_ids:
117
- if word_id is None:
118
- label_ids.append(self.label2id["O"])
119
- else:
120
- label_ids.append(self.label2id.get(labels[word_id], self.label2id["O"]))
121
- encodings["labels"] = torch.tensor(label_ids)
122
- return {key: val.squeeze(0) for key, val in encodings.items()}
123
-
124
- # -------------------------
125
- # Step 3: Model Architecture (Non-Linear Head)
126
- # -------------------------
127
-
128
- class LayoutLMv3CRF(nn.Module):
129
- def __init__(self, num_labels):
130
- super().__init__()
131
- # Initializing from scratch (Base weights only)
132
- print(f"🔄 Initializing backbone from {BASE_MODEL_ID}...")
133
- self.layoutlm = LayoutLMv3Model.from_pretrained(BASE_MODEL_ID)
134
-
135
- hidden_size = self.layoutlm.config.hidden_size
136
-
137
- # NON-LINEAR MLP HEAD
138
- # Replacing the simple Linear layer with a deeper architecture
139
- self.classifier = nn.Sequential(
140
- nn.Linear(hidden_size, hidden_size),
141
- nn.GELU(), # Non-linear activation
142
- nn.LayerNorm(hidden_size), # Stability for training from scratch
143
- nn.Dropout(0.1),
144
- nn.Linear(hidden_size, num_labels)
145
- )
146
-
147
- self.crf = CRF(num_labels)
148
-
149
- def forward(self, input_ids, bbox, attention_mask, labels=None):
150
- outputs = self.layoutlm(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
151
- sequence_output = outputs.last_hidden_state
152
-
153
- # Pass through the new non-linear head
154
- emissions = self.classifier(sequence_output)
155
-
156
- if labels is not None:
157
- log_likelihood = self.crf(emissions, labels, mask=attention_mask.bool())
158
- return -log_likelihood.mean()
159
- else:
160
- return self.crf.viterbi_decode(emissions, mask=attention_mask.bool())
161
-
162
- # -------------------------
163
- # Step 4: Training + Evaluation
164
- # -------------------------
165
- def train_one_epoch(model, dataloader, optimizer, device):
166
- model.train()
167
- total_loss = 0
168
- for batch in tqdm(dataloader, desc="Training"):
169
- batch = {k: v.to(device) for k, v in batch.items()}
170
- labels = batch.pop("labels")
171
- optimizer.zero_grad()
172
- loss = model(**batch, labels=labels)
173
- loss.backward()
174
- optimizer.step()
175
- total_loss += loss.item()
176
- return total_loss / len(dataloader)
177
-
178
- def evaluate(model, dataloader, device, id2label):
179
- model.eval()
180
- all_preds, all_labels = [], []
181
- with torch.no_grad():
182
- for batch in tqdm(dataloader, desc="Evaluating"):
183
- batch = {k: v.to(device) for k, v in batch.items()}
184
- labels = batch.pop("labels").cpu().numpy()
185
- preds = model(**batch)
186
- for p, l, mask in zip(preds, labels, batch["attention_mask"].cpu().numpy()):
187
- valid = mask == 1
188
- l_valid = l[valid].tolist()
189
- all_labels.extend(l_valid)
190
- all_preds.extend(p[:len(l_valid)])
191
- precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average="micro", zero_division=0)
192
- return precision, recall, f1
193
-
194
- # -------------------------
195
- # Step 5: Main Execution
196
- # -------------------------
197
- def main(args):
198
- labels = ["O", "B-QUESTION", "I-QUESTION", "B-OPTION", "I-OPTION", "B-ANSWER", "I-ANSWER", "B-SECTION_HEADING", "I-SECTION_HEADING", "B-PASSAGE", "I-PASSAGE"]
199
- label2id = {l: i for i, l in enumerate(labels)}
200
- id2label = {i: l for l, i in label2id.items()}
201
-
202
- TEMP_DIR = "temp_intermediate_files"
203
- os.makedirs(TEMP_DIR, exist_ok=True)
204
-
205
- # 1. Preprocess & Augment
206
- initial_json = os.path.join(TEMP_DIR, "data_bio.json")
207
- preprocess_labelstudio(args.input, initial_json)
208
- augmented_json = os.path.join(TEMP_DIR, "data_aug.json")
209
- final_data_path = augment_and_save_dataset(initial_json, augmented_json)
210
-
211
- # 2. Setup Data
212
- tokenizer = LayoutLMv3TokenizerFast.from_pretrained(BASE_MODEL_ID)
213
- dataset = LayoutDataset(final_data_path, tokenizer, label2id, max_len=args.max_len)
214
- val_size = int(0.2 * len(dataset))
215
- train_dataset, val_dataset = random_split(dataset, [len(dataset) - val_size, val_size])
216
-
217
- train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
218
- val_loader = DataLoader(val_dataset, batch_size=args.batch_size)
219
-
220
- # 3. Model
221
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
222
- model = LayoutLMv3CRF(num_labels=len(labels)).to(device)
223
- optimizer = AdamW(model.parameters(), lr=args.lr)
224
-
225
- # 4. Loop
226
- for epoch in range(args.epochs):
227
- loss = train_one_epoch(model, train_loader, optimizer, device)
228
- p, r, f1 = evaluate(model, val_loader, device, id2label)
229
- print(f"Epoch {epoch+1} | Loss: {loss:.4f} | F1: {f1:.3f}")
230
-
231
- ckpt_path = "checkpoints/layoutlmv3_nonlinear_scratch.pth"
232
- os.makedirs("checkpoints", exist_ok=True)
233
- torch.save(model.state_dict(), ckpt_path)
234
-
235
- if __name__ == "__main__":
236
- parser = argparse.ArgumentParser()
237
- parser.add_argument("--mode", type=str, default="train")
238
- parser.add_argument("--input", type=str, required=True)
239
- parser.add_argument("--batch_size", type=int, default=4)
240
- parser.add_argument("--epochs", type=int, default=10) # Increased for scratch training
241
- parser.add_argument("--lr", type=float, default=2e-5)
242
- parser.add_argument("--max_len", type=int, default=512)
243
- args = parser.parse_args()
244
- main(args)