Spaces:
Running
Running
Commit
·
ee11b08
1
Parent(s):
caca294
correction
Browse files
HF_LayoutLM_with_Passage.py
CHANGED
|
@@ -203,8 +203,8 @@ class LayoutDataset(Dataset):
|
|
| 203 |
class LayoutLMv3CRF(nn.Module):
|
| 204 |
def __init__(self, model_name, num_labels):
|
| 205 |
super().__init__()
|
| 206 |
-
|
| 207 |
-
self.layoutlm = LayoutLMv3Model.from_pretrained("heerjtdev/edugenius")
|
| 208 |
self.dropout = nn.Dropout(0.1)
|
| 209 |
self.classifier = nn.Linear(self.layoutlm.config.hidden_size, num_labels)
|
| 210 |
self.crf = CRF(num_labels)
|
|
@@ -302,9 +302,9 @@ def main(args):
|
|
| 302 |
|
| 303 |
# 3. Load and split augmented dataset
|
| 304 |
print("\n--- START PHASE: MODEL/DATASET SETUP ---")
|
| 305 |
-
MODEL_ID = "heerjtdev/edugenius"
|
| 306 |
-
|
| 307 |
-
tokenizer = LayoutLMv3TokenizerFast.from_pretrained(MODEL_ID)
|
| 308 |
|
| 309 |
dataset = LayoutDataset(final_data_path, tokenizer, label2id, max_len=args.max_len)
|
| 310 |
val_size = int(0.2 * len(dataset))
|
|
@@ -320,8 +320,8 @@ def main(args):
|
|
| 320 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 321 |
print(f"Using device: {device}")
|
| 322 |
# Num_labels is based on the updated 'labels' list
|
| 323 |
-
|
| 324 |
-
model = LayoutLMv3CRF(MODEL_ID, num_labels=len(labels)).to(device)
|
| 325 |
ckpt_path = "checkpoints/layoutlmv3_crf_passage.pth"
|
| 326 |
os.makedirs("checkpoints", exist_ok=True)
|
| 327 |
if os.path.exists(ckpt_path):
|
|
|
|
| 203 |
class LayoutLMv3CRF(nn.Module):
|
| 204 |
def __init__(self, model_name, num_labels):
|
| 205 |
super().__init__()
|
| 206 |
+
self.layoutlm = LayoutLMv3Model.from_pretrained(model_name)
|
| 207 |
+
# self.layoutlm = LayoutLMv3Model.from_pretrained("heerjtdev/edugenius")
|
| 208 |
self.dropout = nn.Dropout(0.1)
|
| 209 |
self.classifier = nn.Linear(self.layoutlm.config.hidden_size, num_labels)
|
| 210 |
self.crf = CRF(num_labels)
|
|
|
|
| 302 |
|
| 303 |
# 3. Load and split augmented dataset
|
| 304 |
print("\n--- START PHASE: MODEL/DATASET SETUP ---")
|
| 305 |
+
#MODEL_ID = "heerjtdev/edugenius"
|
| 306 |
+
tokenizer = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base")
|
| 307 |
+
#tokenizer = LayoutLMv3TokenizerFast.from_pretrained(MODEL_ID)
|
| 308 |
|
| 309 |
dataset = LayoutDataset(final_data_path, tokenizer, label2id, max_len=args.max_len)
|
| 310 |
val_size = int(0.2 * len(dataset))
|
|
|
|
| 320 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 321 |
print(f"Using device: {device}")
|
| 322 |
# Num_labels is based on the updated 'labels' list
|
| 323 |
+
model = LayoutLMv3CRF("microsoft/layoutlmv3-base", num_labels=len(labels)).to(device)
|
| 324 |
+
# model = LayoutLMv3CRF(MODEL_ID, num_labels=len(labels)).to(device)
|
| 325 |
ckpt_path = "checkpoints/layoutlmv3_crf_passage.pth"
|
| 326 |
os.makedirs("checkpoints", exist_ok=True)
|
| 327 |
if os.path.exists(ckpt_path):
|