heerjtdev committed on
Commit
d5442f4
·
verified ·
1 Parent(s): 8e9da7d

Update train_hybrid.py

Browse files
Files changed (1) hide show
  1. train_hybrid.py +80 -34
train_hybrid.py CHANGED
@@ -80,7 +80,7 @@ class LayoutLMv3BiLSTMCRF(nn.Module):
80
  return self.crf.viterbi_decode(emissions, mask=attention_mask.bool())
81
 
82
  # -------------------------
83
- # 2. Data Processing (FIXED)
84
  # -------------------------
85
  class LayoutDataset(Dataset):
86
  def __init__(self, json_path, tokenizer, max_len=512):
@@ -91,7 +91,13 @@ class LayoutDataset(Dataset):
91
  self.max_len = max_len
92
  self.processed_data = []
93
 
94
- print(f"🔄 Preprocessing {len(data)} documents...")
 
 
 
 
 
 
95
 
96
  for item in data:
97
  if "data" in item:
@@ -105,41 +111,81 @@ class LayoutDataset(Dataset):
105
 
106
  if not words: continue
107
 
108
- # Normalize bboxes
109
- norm_bboxes = []
110
- for b in bboxes:
111
- x0, y0, x1, y1 = b
112
- norm_bboxes.append([
113
- max(0, min(1000, int(x0))),
114
- max(0, min(1000, int(y0))),
115
- max(0, min(1000, int(x1))),
116
- max(0, min(1000, int(y1)))
117
- ])
118
-
119
- # --- KEY FIX IS HERE ---
120
- # using text=words explicitly fixes the positional argument error
121
- encoding = self.tokenizer(
122
- text=words,
123
- boxes=norm_bboxes,
124
- padding="max_length",
125
- truncation=True,
126
- max_length=self.max_len,
127
- is_split_into_words=True,
128
- return_tensors="pt"
129
- )
130
 
131
- word_ids = encoding.word_ids(batch_index=0)
132
- label_ids = []
133
- for word_id in word_ids:
134
- if word_id is None:
135
- label_ids.append(LABEL2ID["O"])
136
- elif word_id < len(labels):
137
- label_ids.append(LABEL2ID.get(labels[word_id], LABEL2ID["O"]))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  else:
139
- label_ids.append(LABEL2ID["O"])
 
140
 
141
- item_dict = {key: val.squeeze(0) for key, val in encoding.items()}
142
- item_dict["labels"] = torch.tensor(label_ids)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  self.processed_data.append(item_dict)
144
 
145
  def __len__(self):
 
80
  return self.crf.viterbi_decode(emissions, mask=attention_mask.bool())
81
 
82
  # -------------------------
83
+ # 2. Data Processing (MANUAL ALIGNMENT FIX)
84
  # -------------------------
85
  class LayoutDataset(Dataset):
86
  def __init__(self, json_path, tokenizer, max_len=512):
 
91
  self.max_len = max_len
92
  self.processed_data = []
93
 
94
+ # Get special token IDs
95
+ self.cls_token_id = tokenizer.cls_token_id
96
+ self.sep_token_id = tokenizer.sep_token_id
97
+ self.pad_token_id = tokenizer.pad_token_id
98
+ self.unk_token_id = tokenizer.unk_token_id
99
+
100
+ print(f"🔄 Preprocessing {len(data)} documents (Manual Alignment Mode)...")
101
 
102
  for item in data:
103
  if "data" in item:
 
111
 
112
  if not words: continue
113
 
114
+ # 1. Initialize with [CLS]
115
+ input_ids = [self.cls_token_id]
116
+ final_bboxes = [[0, 0, 0, 0]]
117
+ label_ids = [LABEL2ID["O"]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
+ # 2. Iterate word by word
120
+ for word, box, label_str in zip(words, bboxes, labels):
121
+ # Clamp bbox 0-1000
122
+ clamped_box = [
123
+ max(0, min(1000, int(box[0]))),
124
+ max(0, min(1000, int(box[1]))),
125
+ max(0, min(1000, int(box[2]))),
126
+ max(0, min(1000, int(box[3])))
127
+ ]
128
+
129
+ # Tokenize current word
130
+ word_tokens = tokenizer.tokenize(word)
131
+ if not word_tokens: continue # Skip empty/weird tokens
132
+
133
+ # Convert to IDs
134
+ word_sub_ids = tokenizer.convert_tokens_to_ids(word_tokens)
135
+
136
+ # Add to lists
137
+ input_ids.extend(word_sub_ids)
138
+
139
+ # Expand bbox to match number of sub-tokens
140
+ final_bboxes.extend([clamped_box] * len(word_sub_ids))
141
+
142
+ # Handle BIO Labels for sub-tokens
143
+ # First sub-token gets the B- tag (if applicable), others get I- tag
144
+ current_label_id = LABEL2ID.get(label_str, LABEL2ID["O"])
145
+
146
+ if label_str.startswith("B-"):
147
+ # Logic: First subtoken is B-X, rest are I-X
148
+ i_tag_str = "I-" + label_str[2:]
149
+ i_tag_id = LABEL2ID.get(i_tag_str, LABEL2ID["O"])
150
+
151
+ # First subtoken = Original B- tag
152
+ label_ids.append(current_label_id)
153
+ # Remaining subtokens = I- tag
154
+ label_ids.extend([i_tag_id] * (len(word_sub_ids) - 1))
155
  else:
156
+ # If it's O or I-X, just copy it to all subtokens
157
+ label_ids.extend([current_label_id] * len(word_sub_ids))
158
 
159
+ # 3. Truncate if too long (account for [SEP])
160
+ if len(input_ids) > self.max_len - 1:
161
+ input_ids = input_ids[:self.max_len - 1]
162
+ final_bboxes = final_bboxes[:self.max_len - 1]
163
+ label_ids = label_ids[:self.max_len - 1]
164
+
165
+ # 4. Add [SEP]
166
+ input_ids.append(self.sep_token_id)
167
+ final_bboxes.append([0, 0, 0, 0])
168
+ label_ids.append(LABEL2ID["O"])
169
+
170
+ # 5. Create Attention Mask
171
+ attention_mask = [1] * len(input_ids)
172
+
173
+ # 6. Pad to max_len
174
+ padding_length = self.max_len - len(input_ids)
175
+ if padding_length > 0:
176
+ input_ids += [self.pad_token_id] * padding_length
177
+ final_bboxes += [[0, 0, 0, 0]] * padding_length
178
+ label_ids += [LABEL2ID["O"]] * padding_length
179
+ attention_mask += [0] * padding_length
180
+
181
+ # 7. Convert to Tensors
182
+ item_dict = {
183
+ "input_ids": torch.tensor(input_ids, dtype=torch.long),
184
+ "bbox": torch.tensor(final_bboxes, dtype=torch.long),
185
+ "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
186
+ "labels": torch.tensor(label_ids, dtype=torch.long)
187
+ }
188
+
189
  self.processed_data.append(item_dict)
190
 
191
  def __len__(self):