heerjtdev committed on
Commit
8e9da7d
·
verified ·
1 Parent(s): 9e890f3

Update train_hybrid.py

Browse files
Files changed (1) hide show
  1. train_hybrid.py +17 -47
train_hybrid.py CHANGED
@@ -14,15 +14,13 @@ import numpy as np
14
  try:
15
  from TorchCRF import CRF
16
  except ImportError:
17
- print("❌ Error: 'TorchCRF' not found")
18
  exit()
19
 
20
  # --- Configuration ---
21
- # We use the base model for the backbone
22
  BASE_MODEL_ID = "microsoft/layoutlmv3-base"
23
  MAX_LEN = 512
24
 
25
- # Labels from your BiLSTM script
26
  LABELS = [
27
  "O",
28
  "B-QUESTION", "I-QUESTION",
@@ -43,48 +41,39 @@ class LayoutLMv3BiLSTMCRF(nn.Module):
43
  super().__init__()
44
  print(f"πŸ—οΈ Initializing Hybrid Model: LayoutLMv3 + BiLSTM + CRF...")
45
 
46
- # 1. Backbone: LayoutLMv3 (Replaces Word Emb + CharCNN + Spatial Features)
47
  self.layoutlm = LayoutLMv3Model.from_pretrained(BASE_MODEL_ID)
48
-
49
- # LayoutLMv3-base hidden size is 768
50
  transformer_output_size = self.layoutlm.config.hidden_size
51
 
52
  # 2. Middle: Bi-LSTM
53
- # Takes the 768 vectors from Transformer and models sequence
54
  self.lstm = nn.LSTM(
55
  input_size=transformer_output_size,
56
  hidden_size=hidden_dim,
57
- num_layers=2, # Stacked LSTM for depth
58
  bidirectional=True,
59
  batch_first=True,
60
  dropout=0.1
61
  )
62
 
63
  # 3. Head: Linear Projection
64
- # Input is hidden_dim * 2 (because bidirectional)
65
  self.classifier = nn.Linear(hidden_dim * 2, num_labels)
66
 
67
  # 4. Decoder: CRF
68
  self.crf = CRF(num_labels)
69
 
70
  def forward(self, input_ids, bbox, attention_mask, labels=None):
71
- # Step A: Get Contextual Embeddings from LayoutLM
72
- # outputs.last_hidden_state shape: (Batch, Seq_Len, 768)
73
  outputs = self.layoutlm(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
74
  sequence_output = outputs.last_hidden_state
75
 
76
- # Step B: Pass through Bi-LSTM
77
- # lstm_output shape: (Batch, Seq_Len, hidden_dim * 2)
78
  lstm_output, _ = self.lstm(sequence_output)
79
 
80
- # Step C: Project to Tag Space
81
- # emissions shape: (Batch, Seq_Len, num_labels)
82
  emissions = self.classifier(lstm_output)
83
 
84
- # Step D: CRF Loss or Decoding
85
  if labels is not None:
86
- # We must use the attention_mask so CRF doesn't train on padding tokens
87
- # Returns negative log likelihood
88
  log_likelihood = self.crf(emissions, labels, mask=attention_mask.bool())
89
  return -log_likelihood.mean()
90
  else:
@@ -105,11 +94,9 @@ class LayoutDataset(Dataset):
105
  print(f"πŸ”„ Preprocessing {len(data)} documents...")
106
 
107
  for item in data:
108
- # Handle Label Studio JSON format
109
  if "data" in item:
110
  words = item["data"].get("original_words", [])
111
  bboxes = item["data"].get("original_bboxes", [])
112
- # Handle missing labels gracefully
113
  labels = item.get("labels", ["O"] * len(words))
114
  else:
115
  words = item.get("tokens", [])
@@ -118,7 +105,7 @@ class LayoutDataset(Dataset):
118
 
119
  if not words: continue
120
 
121
- # Normalize bboxes to 0-1000
122
  norm_bboxes = []
123
  for b in bboxes:
124
  x0, y0, x1, y1 = b
@@ -129,20 +116,18 @@ class LayoutDataset(Dataset):
129
  max(0, min(1000, int(y1)))
130
  ])
131
 
132
- # --- THE FIX IS HERE ---
133
- # 1. Use 'text=' keyword argument
134
- # 2. Ensure 'is_split_into_words=True' is passed explicitly
135
  encoding = self.tokenizer(
136
- text=words, # <--- Changed from positional to keyword
137
  boxes=norm_bboxes,
138
  padding="max_length",
139
  truncation=True,
140
  max_length=self.max_len,
141
- is_split_into_words=True, # This tells it 'words' is a list of strings
142
  return_tensors="pt"
143
  )
144
 
145
- # Align labels with subtokens
146
  word_ids = encoding.word_ids(batch_index=0)
147
  label_ids = []
148
  for word_id in word_ids:
@@ -162,6 +147,7 @@ class LayoutDataset(Dataset):
162
 
163
  def __getitem__(self, idx):
164
  # Return the preprocessed sample (tokenized inputs + aligned label ids) at position idx.
  return self.processed_data[idx]
 
165
  # -------------------------
166
  # 3. Training Function
167
  # -------------------------
@@ -169,22 +155,15 @@ def train_one_epoch(model, dataloader, optimizer, device):
169
  model.train()
170
  total_loss = 0
171
  for batch in tqdm(dataloader, desc="Training"):
172
- # Move batch to device
173
  input_ids = batch["input_ids"].to(device)
174
  bbox = batch["bbox"].to(device)
175
  attention_mask = batch["attention_mask"].to(device)
176
  labels = batch["labels"].to(device)
177
 
178
  optimizer.zero_grad()
179
-
180
- # Forward pass (Auto-calculates CRF loss inside model)
181
  loss = model(input_ids, bbox, attention_mask, labels=labels)
182
-
183
  loss.backward()
184
-
185
- # Gradient clipping (Important for LSTM stability)
186
  torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
187
-
188
  optimizer.step()
189
  total_loss += loss.item()
190
 
@@ -194,17 +173,12 @@ def train_one_epoch(model, dataloader, optimizer, device):
194
  # 4. Main Execution
195
  # -------------------------
196
  def main(args):
197
- # Setup Device
198
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
199
  print(f"βš™οΈ Using device: {device}")
200
 
201
- # Initialize Tokenizer
202
  tokenizer = LayoutLMv3TokenizerFast.from_pretrained(BASE_MODEL_ID)
203
-
204
- # Load Dataset
205
  dataset = LayoutDataset(args.input, tokenizer, max_len=args.max_len)
206
 
207
- # Train/Val Split
208
  train_size = int(0.9 * len(dataset))
209
  val_size = len(dataset) - train_size
210
  train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
@@ -212,14 +186,11 @@ def main(args):
212
  train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
213
  val_loader = DataLoader(val_dataset, batch_size=args.batch_size)
214
 
215
- # Initialize Hybrid Model
216
  model = LayoutLMv3BiLSTMCRF(num_labels=len(LABELS)).to(device)
217
 
218
- # Optimization
219
- # We use different learning rates: lower for transformer, higher for LSTM/CRF head
220
  optimizer = AdamW([
221
- {'params': model.layoutlm.parameters(), 'lr': 2e-5}, # Low LR for backbone
222
- {'params': model.lstm.parameters(), 'lr': 1e-4}, # Higher LR for LSTM
223
  {'params': model.classifier.parameters(), 'lr': 1e-4},
224
  {'params': model.crf.parameters(), 'lr': 1e-4}
225
  ])
@@ -230,7 +201,6 @@ def main(args):
230
  loss = train_one_epoch(model, train_loader, optimizer, device)
231
  print(f"Epoch {epoch+1}/{args.epochs} | Loss: {loss:.4f}")
232
 
233
- # Save Checkpoint
234
  os.makedirs("checkpoints", exist_ok=True)
235
  save_path = "checkpoints/layoutlmv3_bilstm_crf_hybrid.pth"
236
  torch.save(model.state_dict(), save_path)
@@ -241,8 +211,8 @@ if __name__ == "__main__":
241
  parser.add_argument("--input", type=str, required=True, help="Path to unified JSON data")
242
  parser.add_argument("--batch_size", type=int, default=4)
243
  parser.add_argument("--epochs", type=int, default=5)
244
- parser.add_argument("--lr", type=float, default=2e-5) # Base LR
245
  parser.add_argument("--max_len", type=int, default=512)
246
- parser.add_argument("--mode", type=str, default="train") # Kept for compatibility with Gradio
247
  args = parser.parse_args()
248
  main(args)
 
14
  try:
15
  from TorchCRF import CRF
16
  except ImportError:
17
+ print("❌ Error: 'TorchCRF' not found. Install via: pip install pytorch-crf")
18
  exit()
19
 
20
  # --- Configuration ---
 
21
  BASE_MODEL_ID = "microsoft/layoutlmv3-base"
22
  MAX_LEN = 512
23
 
 
24
  LABELS = [
25
  "O",
26
  "B-QUESTION", "I-QUESTION",
 
41
  super().__init__()
42
  print(f"πŸ—οΈ Initializing Hybrid Model: LayoutLMv3 + BiLSTM + CRF...")
43
 
44
+ # 1. Backbone: LayoutLMv3
45
  self.layoutlm = LayoutLMv3Model.from_pretrained(BASE_MODEL_ID)
 
 
46
  transformer_output_size = self.layoutlm.config.hidden_size
47
 
48
  # 2. Middle: Bi-LSTM
 
49
  self.lstm = nn.LSTM(
50
  input_size=transformer_output_size,
51
  hidden_size=hidden_dim,
52
+ num_layers=2,
53
  bidirectional=True,
54
  batch_first=True,
55
  dropout=0.1
56
  )
57
 
58
  # 3. Head: Linear Projection
 
59
  self.classifier = nn.Linear(hidden_dim * 2, num_labels)
60
 
61
  # 4. Decoder: CRF
62
  self.crf = CRF(num_labels)
63
 
64
  def forward(self, input_ids, bbox, attention_mask, labels=None):
65
+ # Step A: LayoutLMv3
 
66
  outputs = self.layoutlm(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
67
  sequence_output = outputs.last_hidden_state
68
 
69
+ # Step B: Bi-LSTM
 
70
  lstm_output, _ = self.lstm(sequence_output)
71
 
72
+ # Step C: Projection
 
73
  emissions = self.classifier(lstm_output)
74
 
75
+ # Step D: CRF
76
  if labels is not None:
 
 
77
  log_likelihood = self.crf(emissions, labels, mask=attention_mask.bool())
78
  return -log_likelihood.mean()
79
  else:
 
94
  print(f"πŸ”„ Preprocessing {len(data)} documents...")
95
 
96
  for item in data:
 
97
  if "data" in item:
98
  words = item["data"].get("original_words", [])
99
  bboxes = item["data"].get("original_bboxes", [])
 
100
  labels = item.get("labels", ["O"] * len(words))
101
  else:
102
  words = item.get("tokens", [])
 
105
 
106
  if not words: continue
107
 
108
+ # Normalize bboxes
109
  norm_bboxes = []
110
  for b in bboxes:
111
  x0, y0, x1, y1 = b
 
116
  max(0, min(1000, int(y1)))
117
  ])
118
 
119
+ # --- KEY FIX IS HERE ---
120
+ # using text=words explicitly fixes the positional argument error
 
121
  encoding = self.tokenizer(
122
+ text=words,
123
  boxes=norm_bboxes,
124
  padding="max_length",
125
  truncation=True,
126
  max_length=self.max_len,
127
+ is_split_into_words=True,
128
  return_tensors="pt"
129
  )
130
 
 
131
  word_ids = encoding.word_ids(batch_index=0)
132
  label_ids = []
133
  for word_id in word_ids:
 
147
 
148
  def __getitem__(self, idx):
149
  return self.processed_data[idx]
150
+
151
  # -------------------------
152
  # 3. Training Function
153
  # -------------------------
 
155
  model.train()
156
  total_loss = 0
157
  for batch in tqdm(dataloader, desc="Training"):
 
158
  input_ids = batch["input_ids"].to(device)
159
  bbox = batch["bbox"].to(device)
160
  attention_mask = batch["attention_mask"].to(device)
161
  labels = batch["labels"].to(device)
162
 
163
  optimizer.zero_grad()
 
 
164
  loss = model(input_ids, bbox, attention_mask, labels=labels)
 
165
  loss.backward()
 
 
166
  torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
 
167
  optimizer.step()
168
  total_loss += loss.item()
169
 
 
173
  # 4. Main Execution
174
  # -------------------------
175
  def main(args):
 
176
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
177
  print(f"βš™οΈ Using device: {device}")
178
 
 
179
  tokenizer = LayoutLMv3TokenizerFast.from_pretrained(BASE_MODEL_ID)
 
 
180
  dataset = LayoutDataset(args.input, tokenizer, max_len=args.max_len)
181
 
 
182
  train_size = int(0.9 * len(dataset))
183
  val_size = len(dataset) - train_size
184
  train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
 
186
  train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
187
  val_loader = DataLoader(val_dataset, batch_size=args.batch_size)
188
 
 
189
  model = LayoutLMv3BiLSTMCRF(num_labels=len(LABELS)).to(device)
190
 
 
 
191
  optimizer = AdamW([
192
+ {'params': model.layoutlm.parameters(), 'lr': 2e-5},
193
+ {'params': model.lstm.parameters(), 'lr': 1e-4},
194
  {'params': model.classifier.parameters(), 'lr': 1e-4},
195
  {'params': model.crf.parameters(), 'lr': 1e-4}
196
  ])
 
201
  loss = train_one_epoch(model, train_loader, optimizer, device)
202
  print(f"Epoch {epoch+1}/{args.epochs} | Loss: {loss:.4f}")
203
 
 
204
  os.makedirs("checkpoints", exist_ok=True)
205
  save_path = "checkpoints/layoutlmv3_bilstm_crf_hybrid.pth"
206
  torch.save(model.state_dict(), save_path)
 
211
  parser.add_argument("--input", type=str, required=True, help="Path to unified JSON data")
212
  parser.add_argument("--batch_size", type=int, default=4)
213
  parser.add_argument("--epochs", type=int, default=5)
214
+ parser.add_argument("--lr", type=float, default=2e-5)
215
  parser.add_argument("--max_len", type=int, default=512)
216
+ parser.add_argument("--mode", type=str, default="train")
217
  args = parser.parse_args()
218
  main(args)