Spaces:
Runtime error
Runtime error
Deploy BitNet-Transformer Trainer
Browse files
scripts/train_ai_model.py
CHANGED
|
@@ -46,7 +46,7 @@ def get_max_batch_size(model, input_dim, seq_len, device, start_batch=128):
|
|
| 46 |
if device.type == 'cpu':
|
| 47 |
return 64
|
| 48 |
|
| 49 |
-
print("🔍 Searching for optimal batch size for your GPU...")
|
| 50 |
batch_size = start_batch
|
| 51 |
last_success = batch_size
|
| 52 |
|
|
@@ -74,7 +74,7 @@ def get_max_batch_size(model, input_dim, seq_len, device, start_batch=128):
|
|
| 74 |
except RuntimeError as e:
|
| 75 |
pbar.close()
|
| 76 |
if "out of memory" in str(e).lower():
|
| 77 |
-
print(f"💡 GPU Hit limit at {batch_size}. Using {last_success} as optimal batch.")
|
| 78 |
torch.cuda.empty_cache()
|
| 79 |
else:
|
| 80 |
raise e
|
|
@@ -112,7 +112,7 @@ def train():
|
|
| 112 |
logger.info("π Starting on-the-fly dataset generation (10 years, 70 symbols)...")
|
| 113 |
build_dataset()
|
| 114 |
|
| 115 |
-
print("📂 Loading dataset from data/trading_dataset.pt...")
|
| 116 |
data = torch.load("data/trading_dataset.pt")
|
| 117 |
X, y = data["X"], data["y"]
|
| 118 |
|
|
@@ -142,7 +142,7 @@ def train():
|
|
| 142 |
logger.info("Starting training on %d samples (%d features)...", len(X), input_dim)
|
| 143 |
|
| 144 |
# 5. Start Training
|
| 145 |
-
print(f"🚀 Starting training loop (Batch Size: {batch_size})...")
|
| 146 |
best_val_loss = float('inf')
|
| 147 |
|
| 148 |
for epoch in range(EPOCHS):
|
|
|
|
| 46 |
if device.type == 'cpu':
|
| 47 |
return 64
|
| 48 |
|
| 49 |
+
print("🔍 Searching for optimal batch size for your GPU...", flush=True)
|
| 50 |
batch_size = start_batch
|
| 51 |
last_success = batch_size
|
| 52 |
|
|
|
|
| 74 |
except RuntimeError as e:
|
| 75 |
pbar.close()
|
| 76 |
if "out of memory" in str(e).lower():
|
| 77 |
+
print(f"💡 GPU Hit limit at {batch_size}. Using {last_success} as optimal batch.", flush=True)
|
| 78 |
torch.cuda.empty_cache()
|
| 79 |
else:
|
| 80 |
raise e
|
|
|
|
| 112 |
logger.info("π Starting on-the-fly dataset generation (10 years, 70 symbols)...")
|
| 113 |
build_dataset()
|
| 114 |
|
| 115 |
+
print("📂 Loading dataset from data/trading_dataset.pt...", flush=True)
|
| 116 |
data = torch.load("data/trading_dataset.pt")
|
| 117 |
X, y = data["X"], data["y"]
|
| 118 |
|
|
|
|
| 142 |
logger.info("Starting training on %d samples (%d features)...", len(X), input_dim)
|
| 143 |
|
| 144 |
# 5. Start Training
|
| 145 |
+
print(f"🚀 Starting training loop (Batch Size: {batch_size})...", flush=True)
|
| 146 |
best_val_loss = float('inf')
|
| 147 |
|
| 148 |
for epoch in range(EPOCHS):
|