Project / train.py
MasteredUltraInstinct's picture
Update train.py
99f4765 verified
raw
history blame contribute delete
836 Bytes
import os
import torch
from pix2tex.dataset.latex_dataset import Im2LatexDataset
from pix2tex.models import get_model
from pix2tex.trainer import build_trainer
from pix2tex.utils import set_seed, get_config
from pix2tex.tokenizer import LatexTokenizer
# Training entry point: load the run configuration, build the tokenizer,
# the train/validation datasets and the model, then run the trainer.

# Load the run configuration from "train.yaml" via get_config.
config = get_config("train.yaml")

# Force CPU execution unconditionally: whatever "device" value the config
# file provided is overwritten here, regardless of CUDA availability.
config["device"] = "cpu"

# Make float32 the default dtype for newly created floating-point tensors.
# torch.set_default_tensor_type() is deprecated since PyTorch 2.1;
# torch.set_default_dtype() is the supported replacement and, with the CPU
# already being the default device, has the same effect for this script.
torch.set_default_dtype(torch.float32)

# Apply the configured seed (42 when the config has no "seed" entry).
set_seed(config.get("seed", 42))

# Tokenizer, built from the path stored under config["tokenizer_path"].
tokenizer = LatexTokenizer(config["tokenizer_path"])

# Datasets: the training set is built from config["train_csv"], the
# validation set from config["val_csv"] with is_val=True.
trainset = Im2LatexDataset(config["train_csv"], tokenizer, config)
valset = Im2LatexDataset(config["val_csv"], tokenizer, config, is_val=True)

# Model, constructed from the config and the tokenizer.
model = get_model(config, tokenizer)

# Trainer: wires the model, tokenizer, config and both datasets together,
# then starts training.
trainer = build_trainer(model, tokenizer, config, trainset=trainset, valset=valset)
trainer.train()