Spaces:
Runtime error
Runtime error
Deploy BitNet-Transformer Trainer
Browse files
scripts/train_ai_model.py
CHANGED
|
@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
|
|
| 28 |
|
| 29 |
# Hyperparameters
|
| 30 |
EPOCHS = 100
|
| 31 |
- BATCH_SIZE =
|
| 32 |
LR = 0.0003
|
| 33 |
HIDDEN_DIM = 512
|
| 34 |
LAYERS = 8
|
|
@@ -77,8 +77,8 @@ def train():
|
|
| 77 |
val_size = len(dataset) - train_size
|
| 78 |
train_ds, val_ds = random_split(dataset, [train_size, val_size])
|
| 79 |
|
| 80 |
- train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)
|
| 81 |
- val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, pin_memory=True)
|
| 82 |
|
| 83 |
# 3. Create Model
|
| 84 |
input_dim = X.shape[2]
|
|
|
|
| 28 |
|
| 29 |
# Hyperparameters
|
| 30 |
EPOCHS = 100
|
| 31 |
+ BATCH_SIZE = 1024 # Significant increase for T4/A100 utilization
|
| 32 |
LR = 0.0003
|
| 33 |
HIDDEN_DIM = 512
|
| 34 |
LAYERS = 8
|
|
|
|
| 77 |
val_size = len(dataset) - train_size
|
| 78 |
train_ds, val_ds = random_split(dataset, [train_size, val_size])
|
| 79 |
|
| 80 |
+ train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True, num_workers=2)
|
| 81 |
+ val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, pin_memory=True, num_workers=2)
|
| 82 |
|
| 83 |
# 3. Create Model
|
| 84 |
input_dim = X.shape[2]
|