Spaces:
Runtime error
Runtime error
Deploy BitNet-Transformer Trainer
Browse files
- scripts/train_ai_model.py +11 -4
scripts/train_ai_model.py
CHANGED
|
@@ -43,10 +43,17 @@ def train():
|
|
| 43 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 44 |
logger.info(f"Using device: {device}")
|
| 45 |
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
|
|
|
|
| 48 |
dtype = torch.bfloat16 if use_bf16 else torch.float16
|
| 49 |
-
|
|
|
|
|
|
|
| 50 |
|
| 51 |
# 1. Load Dataset
|
| 52 |
if not os.path.exists("data/trading_dataset.pt"):
|
|
@@ -100,7 +107,7 @@ def train():
|
|
| 100 |
optimizer.zero_grad()
|
| 101 |
|
| 102 |
# Using Mixed Precision (AMP)
|
| 103 |
-
with torch.
|
| 104 |
outputs = model(batch_X)
|
| 105 |
loss = criterion(outputs, batch_y)
|
| 106 |
|
|
@@ -128,7 +135,7 @@ def train():
|
|
| 128 |
with torch.no_grad():
|
| 129 |
for batch_X, batch_y in val_loader:
|
| 130 |
batch_X, batch_y = batch_X.to(device), batch_y.to(device)
|
| 131 |
-
with torch.
|
| 132 |
outputs = model(batch_X)
|
| 133 |
loss = criterion(outputs, batch_y)
|
| 134 |
val_loss += loss.item()
|
|
|
|
| 43 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 44 |
logger.info(f"Using device: {device}")
|
| 45 |
|
| 46 |
+
if device.type == "cpu":
|
| 47 |
+
logger.warning("⚠️ WARNING: CUDA is NOT available. Training on CPU will be EXTREMELY slow.")
|
| 48 |
+
logger.warning("👉 In Google Colab, go to 'Runtime' > 'Change runtime type' and select 'T4 GPU'.")
|
| 49 |
+
|
| 50 |
+
# Modern torch.amp API
|
| 51 |
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
|
| 52 |
+
device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 53 |
dtype = torch.bfloat16 if use_bf16 else torch.float16
|
| 54 |
+
|
| 55 |
+
# Scaler only needed for FP16 on CUDA
|
| 56 |
+
scaler = torch.amp.GradScaler('cuda', enabled=(device.type == 'cuda' and not use_bf16))
|
| 57 |
|
| 58 |
# 1. Load Dataset
|
| 59 |
if not os.path.exists("data/trading_dataset.pt"):
|
|
|
|
| 107 |
optimizer.zero_grad()
|
| 108 |
|
| 109 |
# Using Mixed Precision (AMP)
|
| 110 |
+
with torch.amp.autocast(device_type=device_type, dtype=dtype, enabled=(device.type == 'cuda')):
|
| 111 |
outputs = model(batch_X)
|
| 112 |
loss = criterion(outputs, batch_y)
|
| 113 |
|
|
|
|
| 135 |
with torch.no_grad():
|
| 136 |
for batch_X, batch_y in val_loader:
|
| 137 |
batch_X, batch_y = batch_X.to(device), batch_y.to(device)
|
| 138 |
+
with torch.amp.autocast(device_type=device_type, dtype=dtype, enabled=(device.type == 'cuda')):
|
| 139 |
outputs = model(batch_X)
|
| 140 |
loss = criterion(outputs, batch_y)
|
| 141 |
val_loss += loss.item()
|