Spaces:
Sleeping
Sleeping
Soumyajit Ghosh
committed on
Use environment variable for SROIE data path
Browse files - train_layoutlm.py +3 -2
train_layoutlm.py
CHANGED
|
@@ -6,6 +6,7 @@ from PIL import Image
|
|
| 6 |
from tqdm import tqdm
|
| 7 |
from seqeval.metrics import f1_score, precision_score, recall_score
|
| 8 |
from pathlib import Path
|
|
|
|
| 9 |
|
| 10 |
# --- 1. Global Configuration & Label Mapping ---
|
| 11 |
print("Setting up configuration...")
|
|
@@ -15,7 +16,7 @@ label2id = {label: idx for idx, label in enumerate(label_list)}
|
|
| 15 |
id2label = {idx: label for idx, label in enumerate(label_list)}
|
| 16 |
|
| 17 |
MODEL_CHECKPOINT = "microsoft/layoutlmv3-base"
|
| 18 |
-
SROIE_DATA_PATH = "
|
| 19 |
|
| 20 |
# --- 2. PyTorch Dataset Class ---
|
| 21 |
class SROIEDataset(Dataset):
|
|
@@ -182,4 +183,4 @@ def train():
|
|
| 182 |
|
| 183 |
|
| 184 |
if __name__ == '__main__':
|
| 185 |
-
train()
|
|
|
|
| 6 |
from tqdm import tqdm
|
| 7 |
from seqeval.metrics import f1_score, precision_score, recall_score
|
| 8 |
from pathlib import Path
|
| 9 |
+
import os
|
| 10 |
|
| 11 |
# --- 1. Global Configuration & Label Mapping ---
|
| 12 |
print("Setting up configuration...")
|
|
|
|
| 16 |
id2label = {idx: label for idx, label in enumerate(label_list)}
|
| 17 |
|
| 18 |
MODEL_CHECKPOINT = "microsoft/layoutlmv3-base"
|
| 19 |
+
SROIE_DATA_PATH = os.getenv("SROIE_DATA_PATH", os.path.join("data", "sroie"))
|
| 20 |
|
| 21 |
# --- 2. PyTorch Dataset Class ---
|
| 22 |
class SROIEDataset(Dataset):
|
|
|
|
| 183 |
|
| 184 |
|
| 185 |
if __name__ == '__main__':
|
| 186 |
+
train()
|