MasteredUltraInstinct committed on
Commit
99f4765
·
verified ·
1 Parent(s): 54c5571

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +24 -24
train.py CHANGED
@@ -1,30 +1,30 @@
1
- # train.py
2
- import time
3
- import yaml
 
 
 
 
4
 
5
def train_model(config_file: str = "train.yaml") -> str:
    """Run a dummy (simulated) training loop driven by a YAML config file.

    Args:
        config_file: Path to a YAML config. Expected keys (all optional):
            model.name, training.{epochs, learning_rate, batch_size, device}.

    Returns:
        A human-readable status string: a success message on completion, or
        an error message if the config could not be loaded (the function
        deliberately does not raise — callers display the returned string).
    """
    print("🧠 Starting training...")

    try:
        with open(config_file, "r") as f:
            # safe_load returns None for an empty file; fall back to {} so
            # the .get() chains below don't blow up with AttributeError.
            config = yaml.safe_load(f) or {}
    except Exception as e:
        # Best-effort by design: report the failure instead of raising.
        print("⚠️ Failed to load config:", str(e))
        return f"❌ Failed to load config: {str(e)}"

    # Hoist the shared sub-dict lookups; every value has a sane default.
    model_cfg = config.get("model", {})
    training_cfg = config.get("training", {})
    model_name = model_cfg.get("name", "default_model")
    epochs = training_cfg.get("epochs", 5)
    lr = training_cfg.get("learning_rate", 0.001)
    batch_size = training_cfg.get("batch_size", 32)
    device = training_cfg.get("device", "cpu")

    print(f"📦 Model: {model_name}")
    print(f"🔧 Device: {device}")
    print(f"📚 Epochs: {epochs}, Batch Size: {batch_size}, Learning Rate: {lr}")

    for epoch in range(1, epochs + 1):
        print(f"🌀 Epoch {epoch}/{epochs} ...")
        time.sleep(1)  # Simulate work

    print("✅ Training complete.")
    return f"✅ Dummy training for `{model_name}` finished on `{device}`!"
 
 
 
 
 
1
import os
import torch
from pix2tex.dataset.latex_dataset import Im2LatexDataset
from pix2tex.models import get_model
from pix2tex.trainer import build_trainer
from pix2tex.utils import set_seed, get_config
from pix2tex.tokenizer import LatexTokenizer

# Load the training configuration from the project's YAML file.
config = get_config("train.yaml")

# Force CPU execution unconditionally, overriding any device in the config.
config["device"] = "cpu"
# torch.set_default_tensor_type('torch.FloatTensor') is deprecated (PyTorch >= 2.1).
# set_default_dtype keeps the same float32 default; the default device is
# already CPU, and the config override above pins it there anyway.
torch.set_default_dtype(torch.float32)

# Seed all RNGs for reproducibility; fall back to 42 when the config omits it.
set_seed(config.get("seed", 42))

# Tokenizer mapping LaTeX markup to/from token ids.
tokenizer = LatexTokenizer(config["tokenizer_path"])

# Train/validation splits. NOTE(review): is_val=True presumably switches the
# dataset to evaluation behavior — confirm against Im2LatexDataset.
trainset = Im2LatexDataset(config["train_csv"], tokenizer, config)
valset = Im2LatexDataset(config["val_csv"], tokenizer, config, is_val=True)

# Build the model from the config, then hand everything to the trainer.
model = get_model(config, tokenizer)
trainer = build_trainer(model, tokenizer, config, trainset=trainset, valset=valset)
trainer.train()