luohoa97 commited on
Commit
68e57b2
Β·
verified Β·
1 Parent(s): 362e261

Deploy BitNet-Transformer Trainer

Browse files
Files changed (1) hide show
  1. scripts/train_ai_model.py +4 -4
scripts/train_ai_model.py CHANGED
@@ -46,7 +46,7 @@ def get_max_batch_size(model, input_dim, seq_len, device, start_batch=128):
46
  if device.type == 'cpu':
47
  return 64
48
 
49
- print("πŸ” Searching for optimal batch size for your GPU...")
50
  batch_size = start_batch
51
  last_success = batch_size
52
 
@@ -74,7 +74,7 @@ def get_max_batch_size(model, input_dim, seq_len, device, start_batch=128):
74
  except RuntimeError as e:
75
  pbar.close()
76
  if "out of memory" in str(e).lower():
77
- print(f"πŸ’‘ GPU Hit limit at {batch_size}. Using {last_success} as optimal batch.")
78
  torch.cuda.empty_cache()
79
  else:
80
  raise e
@@ -112,7 +112,7 @@ def train():
112
  logger.info("πŸš€ Starting on-the-fly dataset generation (10 years, 70 symbols)...")
113
  build_dataset()
114
 
115
- print("πŸš€ Loading dataset from data/trading_dataset.pt...")
116
  data = torch.load("data/trading_dataset.pt")
117
  X, y = data["X"], data["y"]
118
 
@@ -142,7 +142,7 @@ def train():
142
  logger.info("Starting training on %d samples (%d features)...", len(X), input_dim)
143
 
144
  # 5. Start Training
145
- print(f"πŸš€ Starting training loop (Batch Size: {batch_size})...")
146
  best_val_loss = float('inf')
147
 
148
  for epoch in range(EPOCHS):
 
46
  if device.type == 'cpu':
47
  return 64
48
 
49
+ print("πŸ” Searching for optimal batch size for your GPU...", flush=True)
50
  batch_size = start_batch
51
  last_success = batch_size
52
 
 
74
  except RuntimeError as e:
75
  pbar.close()
76
  if "out of memory" in str(e).lower():
77
+ print(f"πŸ’‘ GPU Hit limit at {batch_size}. Using {last_success} as optimal batch.", flush=True)
78
  torch.cuda.empty_cache()
79
  else:
80
  raise e
 
112
  logger.info("πŸš€ Starting on-the-fly dataset generation (10 years, 70 symbols)...")
113
  build_dataset()
114
 
115
+ print("πŸš€ Loading dataset from data/trading_dataset.pt...", flush=True)
116
  data = torch.load("data/trading_dataset.pt")
117
  X, y = data["X"], data["y"]
118
 
 
142
  logger.info("Starting training on %d samples (%d features)...", len(X), input_dim)
143
 
144
  # 5. Start Training
145
+ print(f"πŸš€ Starting training loop (Batch Size: {batch_size})...", flush=True)
146
  best_val_loss = float('inf')
147
 
148
  for epoch in range(EPOCHS):