File size: 836 Bytes
99f4765
 
 
 
 
 
 
ccca011
99f4765
 
ccca011
99f4765
 
 
ccca011
99f4765
ccca011
99f4765
 
ccca011
99f4765
 
 
ccca011
99f4765
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import argparse
import os

import torch

from pix2tex.dataset.latex_dataset import Im2LatexDataset
from pix2tex.models import get_model
from pix2tex.tokenizer import LatexTokenizer
from pix2tex.trainer import build_trainer
from pix2tex.utils import get_config, set_seed

def parse_args(argv=None):
    """Parse command-line options for the training script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` whose ``config`` attribute is the path of
        the training config file (defaults to ``"train.yaml"``).
    """
    parser = argparse.ArgumentParser(description="Train a pix2tex model on CPU.")
    parser.add_argument(
        "--config",
        default="train.yaml",
        help="path to the training config file (default: %(default)s)",
    )
    return parser.parse_args(argv)


def main(config_path="train.yaml"):
    """Load the config, build tokenizer/datasets/model/trainer and train on CPU.

    Args:
        config_path: Path handed to ``get_config`` to load the training
            configuration. The config must provide the keys
            ``tokenizer_path``, ``train_csv`` and ``val_csv``; ``seed`` is
            optional and defaults to 42.
    """
    config = get_config(config_path)

    # This script always trains on CPU: the device is forced to "cpu"
    # regardless of whether CUDA is available.
    config["device"] = "cpu"
    # Replacement for the deprecated
    # torch.set_default_tensor_type('torch.FloatTensor'): default tensors are
    # float32, and the default device is left at CPU.
    torch.set_default_dtype(torch.float32)

    set_seed(config.get("seed", 42))

    tokenizer = LatexTokenizer(config["tokenizer_path"])

    trainset = Im2LatexDataset(config["train_csv"], tokenizer, config)
    valset = Im2LatexDataset(config["val_csv"], tokenizer, config, is_val=True)

    model = get_model(config, tokenizer)

    trainer = build_trainer(
        model, tokenizer, config, trainset=trainset, valset=valset
    )
    trainer.train()


# Guard the entry point so that merely importing this module (including the
# re-import of the main module done by spawn-based multiprocessing, should
# the trainer use worker processes) does not start another training run.
if __name__ == "__main__":
    main(parse_args().config)