heerjtdev commited on
Commit
9e890f3
·
verified ·
1 Parent(s): 62ffe2c

Update train_hybrid.py

Browse files
Files changed (1) hide show
  1. train_hybrid.py +9 -11
train_hybrid.py CHANGED
@@ -91,7 +91,7 @@ class LayoutLMv3BiLSTMCRF(nn.Module):
91
  return self.crf.viterbi_decode(emissions, mask=attention_mask.bool())
92
 
93
  # -------------------------
94
- # 2. Data Processing
95
  # -------------------------
96
  class LayoutDataset(Dataset):
97
  def __init__(self, json_path, tokenizer, max_len=512):
@@ -109,21 +109,18 @@ class LayoutDataset(Dataset):
109
  if "data" in item:
110
  words = item["data"].get("original_words", [])
111
  bboxes = item["data"].get("original_bboxes", [])
112
- # If labels aren't pre-processed, you might need your conversion logic here.
113
- # Assuming the JSON input already has word-aligned labels or we create dummy ones
114
  labels = item.get("labels", ["O"] * len(words))
115
  else:
116
- # Fallback or generic format
117
  words = item.get("tokens", [])
118
  bboxes = item.get("bboxes", [])
119
  labels = item.get("labels", [])
120
 
121
  if not words: continue
122
 
123
- # Normalize bboxes to 0-1000 if not already
124
  norm_bboxes = []
125
  for b in bboxes:
126
- # Simple clamping 0-1000
127
  x0, y0, x1, y1 = b
128
  norm_bboxes.append([
129
  max(0, min(1000, int(x0))),
@@ -132,14 +129,16 @@ class LayoutDataset(Dataset):
132
  max(0, min(1000, int(y1)))
133
  ])
134
 
135
- # Tokenize
 
 
136
  encoding = self.tokenizer(
137
- words,
138
  boxes=norm_bboxes,
139
  padding="max_length",
140
  truncation=True,
141
  max_length=self.max_len,
142
- is_split_into_words=True,
143
  return_tensors="pt"
144
  )
145
 
@@ -148,7 +147,7 @@ class LayoutDataset(Dataset):
148
  label_ids = []
149
  for word_id in word_ids:
150
  if word_id is None:
151
- label_ids.append(LABEL2ID["O"]) # Pad/Special tokens are O
152
  elif word_id < len(labels):
153
  label_ids.append(LABEL2ID.get(labels[word_id], LABEL2ID["O"]))
154
  else:
@@ -163,7 +162,6 @@ class LayoutDataset(Dataset):
163
 
164
  def __getitem__(self, idx):
165
  return self.processed_data[idx]
166
-
167
  # -------------------------
168
  # 3. Training Function
169
  # -------------------------
 
91
  return self.crf.viterbi_decode(emissions, mask=attention_mask.bool())
92
 
93
  # -------------------------
94
+ # 2. Data Processing (FIXED)
95
  # -------------------------
96
  class LayoutDataset(Dataset):
97
  def __init__(self, json_path, tokenizer, max_len=512):
 
109
  if "data" in item:
110
  words = item["data"].get("original_words", [])
111
  bboxes = item["data"].get("original_bboxes", [])
112
+ # Handle missing labels gracefully
 
113
  labels = item.get("labels", ["O"] * len(words))
114
  else:
 
115
  words = item.get("tokens", [])
116
  bboxes = item.get("bboxes", [])
117
  labels = item.get("labels", [])
118
 
119
  if not words: continue
120
 
121
+ # Normalize bboxes to 0-1000
122
  norm_bboxes = []
123
  for b in bboxes:
 
124
  x0, y0, x1, y1 = b
125
  norm_bboxes.append([
126
  max(0, min(1000, int(x0))),
 
129
  max(0, min(1000, int(y1)))
130
  ])
131
 
132
+ # --- THE FIX IS HERE ---
133
+ # 1. Use 'text=' keyword argument
134
+ # 2. Ensure 'is_split_into_words=True' is passed explicitly
135
  encoding = self.tokenizer(
136
+ text=words, # <--- Changed from positional to keyword
137
  boxes=norm_bboxes,
138
  padding="max_length",
139
  truncation=True,
140
  max_length=self.max_len,
141
+ is_split_into_words=True, # This tells it 'words' is a list of strings
142
  return_tensors="pt"
143
  )
144
 
 
147
  label_ids = []
148
  for word_id in word_ids:
149
  if word_id is None:
150
+ label_ids.append(LABEL2ID["O"])
151
  elif word_id < len(labels):
152
  label_ids.append(LABEL2ID.get(labels[word_id], LABEL2ID["O"]))
153
  else:
 
162
 
163
  def __getitem__(self, idx):
164
  return self.processed_data[idx]
 
165
  # -------------------------
166
  # 3. Training Function
167
  # -------------------------