Spaces:
Runtime error
Runtime error
Deploy BitNet-Transformer Trainer
Browse files- scripts/train_ai_model.py +13 -1
scripts/train_ai_model.py
CHANGED
|
@@ -10,6 +10,8 @@ import torch
|
|
| 10 |
import torch.nn as nn
|
| 11 |
import torch.optim as optim
|
| 12 |
from torch.utils.data import DataLoader, TensorDataset, random_split
|
|
|
|
|
|
|
| 13 |
import logging
|
| 14 |
from safetensors.torch import save_file, load_file
|
| 15 |
from huggingface_hub import HfApi, create_repo, hf_hub_download
|
|
@@ -103,6 +105,7 @@ def train():
|
|
| 103 |
logger.info("π Starting on-the-fly dataset generation (10 years, 70 symbols)...")
|
| 104 |
build_dataset()
|
| 105 |
|
|
|
|
| 106 |
data = torch.load("data/trading_dataset.pt")
|
| 107 |
X, y = data["X"], data["y"]
|
| 108 |
|
|
@@ -131,6 +134,8 @@ def train():
|
|
| 131 |
|
| 132 |
logger.info("Starting training on %d samples (%d features)...", len(X), input_dim)
|
| 133 |
|
|
|
|
|
|
|
| 134 |
best_val_loss = float('inf')
|
| 135 |
|
| 136 |
for epoch in range(EPOCHS):
|
|
@@ -139,7 +144,8 @@ def train():
|
|
| 139 |
correct = 0
|
| 140 |
total = 0
|
| 141 |
|
| 142 |
-
|
|
|
|
| 143 |
batch_X, batch_y = batch_X.to(device), batch_y.to(device)
|
| 144 |
optimizer.zero_grad()
|
| 145 |
|
|
@@ -164,6 +170,12 @@ def train():
|
|
| 164 |
total += batch_y.size(0)
|
| 165 |
correct += predicted.eq(batch_y).sum().item()
|
| 166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
# Validation
|
| 168 |
model.eval()
|
| 169 |
val_loss = 0
|
|
|
|
| 10 |
import torch.nn as nn
|
| 11 |
import torch.optim as optim
|
| 12 |
from torch.utils.data import DataLoader, TensorDataset, random_split
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
import logging
|
| 16 |
from safetensors.torch import save_file, load_file
|
| 17 |
from huggingface_hub import HfApi, create_repo, hf_hub_download
|
|
|
|
| 105 |
logger.info("π Starting on-the-fly dataset generation (10 years, 70 symbols)...")
|
| 106 |
build_dataset()
|
| 107 |
|
| 108 |
+
logger.info("π Loading dataset from data/trading_dataset.pt...")
|
| 109 |
data = torch.load("data/trading_dataset.pt")
|
| 110 |
X, y = data["X"], data["y"]
|
| 111 |
|
|
|
|
| 134 |
|
| 135 |
logger.info("Starting training on %d samples (%d features)...", len(X), input_dim)
|
| 136 |
|
| 137 |
+
# 5. Start Training
|
| 138 |
+
logger.info("π Starting training loop...")
|
| 139 |
best_val_loss = float('inf')
|
| 140 |
|
| 141 |
for epoch in range(EPOCHS):
|
|
|
|
| 144 |
correct = 0
|
| 145 |
total = 0
|
| 146 |
|
| 147 |
+
pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
|
| 148 |
+
for batch_X, batch_y in pbar:
|
| 149 |
batch_X, batch_y = batch_X.to(device), batch_y.to(device)
|
| 150 |
optimizer.zero_grad()
|
| 151 |
|
|
|
|
| 170 |
total += batch_y.size(0)
|
| 171 |
correct += predicted.eq(batch_y).sum().item()
|
| 172 |
|
| 173 |
+
# Update progress bar
|
| 174 |
+
pbar.set_postfix({
|
| 175 |
+
"loss": f"{loss.item():.4f}",
|
| 176 |
+
"acc": f"{100.*correct/total:.1f}%"
|
| 177 |
+
})
|
| 178 |
+
|
| 179 |
# Validation
|
| 180 |
model.eval()
|
| 181 |
val_loss = 0
|