Spaces:
Runtime error
Runtime error
Update train.py
Browse files
train.py
CHANGED
|
@@ -1,30 +1,30 @@
|
|
| 1 |
-
|
| 2 |
-
import
|
| 3 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
-
|
| 6 |
-
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
except Exception as e:
|
| 12 |
-
print("⚠️ Failed to load config:", str(e))
|
| 13 |
-
return f"❌ Failed to load config: {str(e)}"
|
| 14 |
|
| 15 |
-
|
| 16 |
-
epochs = config.get("training", {}).get("epochs", 5)
|
| 17 |
-
lr = config.get("training", {}).get("learning_rate", 0.001)
|
| 18 |
-
batch_size = config.get("training", {}).get("batch_size", 32)
|
| 19 |
-
device = config.get("training", {}).get("device", "cpu")
|
| 20 |
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
print(f"π Epochs: {epochs}, Batch Size: {batch_size}, Learning Rate: {lr}")
|
| 24 |
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from pix2tex.dataset.latex_dataset import Im2LatexDataset
|
| 4 |
+
from pix2tex.models import get_model
|
| 5 |
+
from pix2tex.trainer import build_trainer
|
| 6 |
+
from pix2tex.utils import set_seed, get_config
|
| 7 |
+
from pix2tex.tokenizer import LatexTokenizer
|
| 8 |
|
| 9 |
+
# Training script body: executed top to bottom when this file is run.
# Relies on the imports at the top of the file (torch and the pix2tex helpers).

# Load config (presumably parsed from the YAML file by get_config -- confirm
# against pix2tex.utils).
config = get_config("train.yaml")

# Force CPU execution. This is unconditional: no CUDA availability check is
# performed here, whatever the loaded config says.
config["device"] = "cpu"
# torch.set_default_tensor_type() is deprecated in recent PyTorch releases.
# Setting the default floating-point dtype is the supported way to get
# float32 tensors, and this call is also available in older releases.
# PyTorch's default device is already the CPU unless changed elsewhere.
torch.set_default_dtype(torch.float32)

# Seed the run; 42 is used when the config has no "seed" entry.
set_seed(config.get("seed", 42))

# Tokenizer (raises KeyError if "tokenizer_path" is missing from the config)
tokenizer = LatexTokenizer(config["tokenizer_path"])

# Datasets: training split and validation split
trainset = Im2LatexDataset(config["train_csv"], tokenizer, config)
valset = Im2LatexDataset(config["val_csv"], tokenizer, config, is_val=True)

# Model
model = get_model(config, tokenizer)

# Trainer: built from the model, tokenizer, config and both datasets, then run
trainer = build_trainer(model, tokenizer, config, trainset=trainset, valset=valset)
trainer.train()
|