luohoa97 committed on
Commit
20b4890
·
verified ·
1 Parent(s): be03d5f

Deploy BitNet-Transformer Trainer

Browse files
Files changed (1) hide show
  1. scripts/train_ai_model.py +11 -4
scripts/train_ai_model.py CHANGED
@@ -43,10 +43,17 @@ def train():
43
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
44
  logger.info(f"Using device: {device}")
45
 
46
- # Use BFloat16 if supported (Ampere+ GPUs like A100/H100), otherwise FP16
 
 
 
 
47
  use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
 
48
  dtype = torch.bfloat16 if use_bf16 else torch.float16
49
- scaler = torch.cuda.amp.GradScaler(enabled=(not use_bf16)) # Scaler only needed for FP16
 
 
50
 
51
  # 1. Load Dataset
52
  if not os.path.exists("data/trading_dataset.pt"):
@@ -100,7 +107,7 @@ def train():
100
  optimizer.zero_grad()
101
 
102
  # Using Mixed Precision (AMP)
103
- with torch.cuda.amp.autocast(dtype=dtype):
104
  outputs = model(batch_X)
105
  loss = criterion(outputs, batch_y)
106
 
@@ -128,7 +135,7 @@ def train():
128
  with torch.no_grad():
129
  for batch_X, batch_y in val_loader:
130
  batch_X, batch_y = batch_X.to(device), batch_y.to(device)
131
- with torch.cuda.amp.autocast(dtype=dtype):
132
  outputs = model(batch_X)
133
  loss = criterion(outputs, batch_y)
134
  val_loss += loss.item()
 
43
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
44
  logger.info(f"Using device: {device}")
45
 
46
+ if device.type == "cpu":
47
+ logger.warning("⚠️ WARNING: CUDA is NOT available. Training on CPU will be EXTREMELY slow.")
48
+ logger.warning("👉 In Google Colab, go to 'Runtime' > 'Change runtime type' and select 'T4 GPU'.")
49
+
50
+ # Modern torch.amp API
51
  use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
52
+ device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
53
  dtype = torch.bfloat16 if use_bf16 else torch.float16
54
+
55
+ # Scaler only needed for FP16 on CUDA
56
+ scaler = torch.amp.GradScaler('cuda', enabled=(device.type == 'cuda' and not use_bf16))
57
 
58
  # 1. Load Dataset
59
  if not os.path.exists("data/trading_dataset.pt"):
 
107
  optimizer.zero_grad()
108
 
109
  # Using Mixed Precision (AMP)
110
+ with torch.amp.autocast(device_type=device_type, dtype=dtype, enabled=(device.type == 'cuda')):
111
  outputs = model(batch_X)
112
  loss = criterion(outputs, batch_y)
113
 
 
135
  with torch.no_grad():
136
  for batch_X, batch_y in val_loader:
137
  batch_X, batch_y = batch_X.to(device), batch_y.to(device)
138
+ with torch.amp.autocast(device_type=device_type, dtype=dtype, enabled=(device.type == 'cuda')):
139
  outputs = model(batch_X)
140
  loss = criterion(outputs, batch_y)
141
  val_loss += loss.item()