Soumyajit Ghosh commited on
Commit
f0e14bb
·
unverified ·
1 Parent(s): b99270c

Use environment variable for SROIE data path

Browse files
Files changed (1) hide show
  1. train_layoutlm.py +3 -2
train_layoutlm.py CHANGED
@@ -6,6 +6,7 @@ from PIL import Image
6
  from tqdm import tqdm
7
  from seqeval.metrics import f1_score, precision_score, recall_score
8
  from pathlib import Path
 
9
 
10
  # --- 1. Global Configuration & Label Mapping ---
11
  print("Setting up configuration...")
@@ -15,7 +16,7 @@ label2id = {label: idx for idx, label in enumerate(label_list)}
15
  id2label = {idx: label for idx, label in enumerate(label_list)}
16
 
17
  MODEL_CHECKPOINT = "microsoft/layoutlmv3-base"
18
- SROIE_DATA_PATH = "C:\\Users\\Soumyajit Ghosh\\Downloads\\sroie\\sroie" # Make sure this path is correct
19
 
20
  # --- 2. PyTorch Dataset Class ---
21
  class SROIEDataset(Dataset):
@@ -182,4 +183,4 @@ def train():
182
 
183
 
184
  if __name__ == '__main__':
185
- train()
 
6
  from tqdm import tqdm
7
  from seqeval.metrics import f1_score, precision_score, recall_score
8
  from pathlib import Path
9
+ import os
10
 
11
  # --- 1. Global Configuration & Label Mapping ---
12
  print("Setting up configuration...")
 
16
  id2label = {idx: label for idx, label in enumerate(label_list)}
17
 
18
  MODEL_CHECKPOINT = "microsoft/layoutlmv3-base"
19
+ SROIE_DATA_PATH = os.getenv("SROIE_DATA_PATH", os.path.join("data", "sroie"))
20
 
21
  # --- 2. PyTorch Dataset Class ---
22
  class SROIEDataset(Dataset):
 
183
 
184
 
185
  if __name__ == '__main__':
186
+ train()