"""Train a pix2tex image-to-LaTeX model on CPU using settings from train.yaml."""

import os

import torch

from pix2tex.dataset.latex_dataset import Im2LatexDataset
from pix2tex.models import get_model
from pix2tex.tokenizer import LatexTokenizer
from pix2tex.trainer import build_trainer
from pix2tex.utils import get_config, set_seed

# Load the training configuration.
config = get_config("train.yaml")

# Force CPU execution unconditionally: whatever device the config file named
# is overridden here, so no CUDA availability check is performed.
config["device"] = "cpu"

# Make float32 the default floating-point dtype. This replaces the deprecated
# torch.set_default_tensor_type('torch.FloatTensor'); the default device is
# left untouched, which is CPU unless something else changed it earlier.
torch.set_default_dtype(torch.float32)

# Seed from the config, falling back to 42 when no "seed" entry is present.
set_seed(config.get("seed", 42))

# Tokenizer, built from the path given in the config.
tokenizer = LatexTokenizer(config["tokenizer_path"])

# Training and validation datasets; both share the tokenizer and config, and
# the validation split is flagged with is_val=True.
trainset = Im2LatexDataset(config["train_csv"], tokenizer, config)
valset = Im2LatexDataset(config["val_csv"], tokenizer, config, is_val=True)

# Model.
model = get_model(config, tokenizer)

# Trainer: wire the model, tokenizer and both datasets together, then train.
trainer = build_trainer(
    model,
    tokenizer,
    config,
    trainset=trainset,
    valset=valset,
)
trainer.train()