Spaces:
Runtime error
Runtime error
# Training entry script: reads the run configuration, builds the tokenizer,
# the train/validation datasets and the model, then runs the trainer on CPU.
import os

import torch
from pix2tex.dataset.latex_dataset import Im2LatexDataset
from pix2tex.models import get_model
from pix2tex.trainer import build_trainer
from pix2tex.utils import set_seed, get_config
from pix2tex.tokenizer import LatexTokenizer

# Load config
config = get_config("train.yaml")

# Force CPU execution unconditionally: whatever device the config file names
# is overridden here, without checking CUDA availability.
config["device"] = "cpu"

# Keep float32 as the default floating-point dtype for new tensors.
# torch.set_default_tensor_type() is deprecated since PyTorch 2.1;
# torch.set_default_dtype() is its documented replacement.
torch.set_default_dtype(torch.float32)

# Seed value comes from the config; 42 is used when no "seed" key is present.
set_seed(config.get("seed", 42))

# Tokenizer
tokenizer = LatexTokenizer(config["tokenizer_path"])

# Datasets (the validation set is built with is_val=True)
trainset = Im2LatexDataset(config["train_csv"], tokenizer, config)
valset = Im2LatexDataset(config["val_csv"], tokenizer, config, is_val=True)

# Model
model = get_model(config, tokenizer)

# Trainer
trainer = build_trainer(
    model,
    tokenizer,
    config,
    trainset=trainset,
    valset=valset,
)
trainer.train()