aagamjtdev commited on
Commit
ee11b08
·
1 Parent(s): caca294

correction

Browse files
Files changed (1) hide show
  1. HF_LayoutLM_with_Passage.py +7 -7
HF_LayoutLM_with_Passage.py CHANGED
@@ -203,8 +203,8 @@ class LayoutDataset(Dataset):
203
  class LayoutLMv3CRF(nn.Module):
204
  def __init__(self, model_name, num_labels):
205
  super().__init__()
206
- # self.layoutlm = LayoutLMv3Model.from_pretrained(model_name)
207
- self.layoutlm = LayoutLMv3Model.from_pretrained("heerjtdev/edugenius")
208
  self.dropout = nn.Dropout(0.1)
209
  self.classifier = nn.Linear(self.layoutlm.config.hidden_size, num_labels)
210
  self.crf = CRF(num_labels)
@@ -302,9 +302,9 @@ def main(args):
302
 
303
  # 3. Load and split augmented dataset
304
  print("\n--- START PHASE: MODEL/DATASET SETUP ---")
305
- MODEL_ID = "heerjtdev/edugenius"
306
- # tokenizer = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base")
307
- tokenizer = LayoutLMv3TokenizerFast.from_pretrained(MODEL_ID)
308
 
309
  dataset = LayoutDataset(final_data_path, tokenizer, label2id, max_len=args.max_len)
310
  val_size = int(0.2 * len(dataset))
@@ -320,8 +320,8 @@ def main(args):
320
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
321
  print(f"Using device: {device}")
322
  # Num_labels is based on the updated 'labels' list
323
- # model = LayoutLMv3CRF("microsoft/layoutlmv3-base", num_labels=len(labels)).to(device)
324
- model = LayoutLMv3CRF(MODEL_ID, num_labels=len(labels)).to(device)
325
  ckpt_path = "checkpoints/layoutlmv3_crf_passage.pth"
326
  os.makedirs("checkpoints", exist_ok=True)
327
  if os.path.exists(ckpt_path):
 
203
  class LayoutLMv3CRF(nn.Module):
204
  def __init__(self, model_name, num_labels):
205
  super().__init__()
206
+ self.layoutlm = LayoutLMv3Model.from_pretrained(model_name)
207
+ # self.layoutlm = LayoutLMv3Model.from_pretrained("heerjtdev/edugenius")
208
  self.dropout = nn.Dropout(0.1)
209
  self.classifier = nn.Linear(self.layoutlm.config.hidden_size, num_labels)
210
  self.crf = CRF(num_labels)
 
302
 
303
  # 3. Load and split augmented dataset
304
  print("\n--- START PHASE: MODEL/DATASET SETUP ---")
305
+ #MODEL_ID = "heerjtdev/edugenius"
306
+ tokenizer = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base")
307
+ #tokenizer = LayoutLMv3TokenizerFast.from_pretrained(MODEL_ID)
308
 
309
  dataset = LayoutDataset(final_data_path, tokenizer, label2id, max_len=args.max_len)
310
  val_size = int(0.2 * len(dataset))
 
320
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
321
  print(f"Using device: {device}")
322
  # Num_labels is based on the updated 'labels' list
323
+ model = LayoutLMv3CRF("microsoft/layoutlmv3-base", num_labels=len(labels)).to(device)
324
+ # model = LayoutLMv3CRF(MODEL_ID, num_labels=len(labels)).to(device)
325
  ckpt_path = "checkpoints/layoutlmv3_crf_passage.pth"
326
  os.makedirs("checkpoints", exist_ok=True)
327
  if os.path.exists(ckpt_path):