luohoa97 committed on
Commit
26f4391
·
verified ·
1 Parent(s): 2abfd60

Deploy BitNet-Transformer Trainer

Browse files
Files changed (1) hide show
  1. scripts/train_ai_model.py +40 -3
scripts/train_ai_model.py CHANGED
@@ -39,6 +39,40 @@ HF_REPO_ID = os.getenv("HF_REPO_ID", "luohoa97/BitFin") # User's model repo
39
  HF_DATASET_ID = "luohoa97/BitFin" # User's dataset repo
40
  HF_TOKEN = os.getenv("HF_TOKEN")
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  def train():
43
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
44
  logger.info(f"Using device: {device}")
@@ -77,14 +111,17 @@ def train():
77
  val_size = len(dataset) - train_size
78
  train_ds, val_ds = random_split(dataset, [train_size, val_size])
79
 
80
- train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True, num_workers=2)
81
- val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, pin_memory=True, num_workers=2)
82
-
83
  # 3. Create Model
84
  input_dim = X.shape[2]
85
  model = create_model(input_dim=input_dim, hidden_dim=HIDDEN_DIM, layers=LAYERS, seq_len=SEQ_LEN)
86
  model.to(device)
87
 
 
 
 
 
 
 
88
  total_params = sum(p.numel() for p in model.parameters())
89
  logger.info(f"Model Architecture: BitNet-Transformer ({LAYERS} layers, {HIDDEN_DIM} hidden)")
90
  logger.info(f"Total Parameters: {total_params:,}")
 
39
  HF_DATASET_ID = "luohoa97/BitFin" # User's dataset repo
40
  HF_TOKEN = os.getenv("HF_TOKEN")
41
 
42
def _batch_fits_in_memory(model, criterion, batch_size, seq_len, input_dim, num_classes, device):
    """Run one mock forward/backward pass and report whether it fit in memory.

    Returns True when the pass completes, False when it raises a RuntimeError
    whose message contains "out of memory". Any other RuntimeError propagates.
    """
    try:
        # Allocate the mock batch directly on the target device so no
        # throw-away CPU copy is made first.
        mock_X = torch.randn(batch_size, seq_len, input_dim, device=device)
        mock_y = torch.randint(0, num_classes, (batch_size,), device=device)

        loss = criterion(model(mock_X), mock_y)
        loss.backward()
        return True
    except RuntimeError as e:
        if "out of memory" not in str(e).lower():
            raise
        return False
    finally:
        # Drop gradients from the probe (including a partially completed
        # backward pass) so they neither leak into training nor hold memory.
        model.zero_grad(set_to_none=True)


def get_max_batch_size(model, input_dim, seq_len, device, start_batch=128,
                       num_classes=3, max_batch=16384):
    """Automatically find the largest batch size that fits in VRAM.

    On CPU no probing is done and 64 is returned. Otherwise the batch size is
    doubled from ``start_batch`` (clamped to ``[1, max_batch]``) while a mock
    forward/backward pass succeeds. If the very first probe runs out of
    memory, the size is halved until one fits, so the returned value has
    always been probed successfully (the original returned ``start_batch``
    untested in that case).

    Args:
        model: Module already placed on ``device``; called as ``model(x)`` with
            ``x`` of shape ``(batch, seq_len, input_dim)``.
        input_dim: Size of the last dimension of the mock input.
        seq_len: Size of the middle dimension of the mock input.
        device: ``torch.device`` the probe tensors are created on.
        start_batch: First batch size to try.
        num_classes: Exclusive upper bound of the random integer labels fed to
            ``nn.CrossEntropyLoss``.
        max_batch: Largest batch size that will be probed.

    Returns:
        The largest probed batch size that completed a forward/backward pass,
        or 1 (with a warning) if even a single-sample batch ran out of memory.

    Raises:
        RuntimeError: Any non-OOM RuntimeError raised by the probe.

    Note:
        Only a forward/backward pass is probed; no optimizer step is run, so
        memory needed by optimizer state is not accounted for.
    """
    if device.type == 'cpu':
        return 64

    logger.info("🔍 Searching for optimal batch size for your GPU...")
    criterion = nn.CrossEntropyLoss()
    batch_size = max(1, min(start_batch, max_batch))
    last_success = None
    growing = True  # becomes False once the first probe fails and we shrink

    while 1 <= batch_size <= max_batch:
        fits = _batch_fits_in_memory(
            model, criterion, batch_size, seq_len, input_dim, num_classes, device
        )
        # Release cached blocks here, after the helper has returned: inside an
        # `except` handler the traceback still references the probe tensors,
        # so emptying the cache there frees very little.
        torch.cuda.empty_cache()

        if fits:
            last_success = batch_size
            if not growing:
                logger.info(f"💡 GPU Hit limit at {batch_size * 2}. Using {last_success} as optimal batch.")
                break
            batch_size *= 2
        elif last_success is not None:
            logger.info(f"💡 GPU Hit limit at {batch_size}. Using {last_success} as optimal batch.")
            break
        else:
            # Nothing has fit yet: search downward instead of giving up.
            growing = False
            batch_size //= 2

    if last_success is None:
        logger.warning("⚠️ Even batch size 1 ran out of GPU memory; falling back to 1.")
        return 1

    return last_success
75
+
76
  def train():
77
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
78
  logger.info(f"Using device: {device}")
 
111
  val_size = len(dataset) - train_size
112
  train_ds, val_ds = random_split(dataset, [train_size, val_size])
113
 
 
 
 
114
  # 3. Create Model
115
  input_dim = X.shape[2]
116
  model = create_model(input_dim=input_dim, hidden_dim=HIDDEN_DIM, layers=LAYERS, seq_len=SEQ_LEN)
117
  model.to(device)
118
 
119
+ # 4. Dynamic Batch Sizing
120
+ batch_size = get_max_batch_size(model, input_dim, SEQ_LEN, device)
121
+
122
+ train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=2)
123
+ val_loader = DataLoader(val_ds, batch_size=batch_size, pin_memory=True, num_workers=2)
124
+
125
  total_params = sum(p.numel() for p in model.parameters())
126
  logger.info(f"Model Architecture: BitNet-Transformer ({LAYERS} layers, {HIDDEN_DIM} hidden)")
127
  logger.info(f"Total Parameters: {total_params:,}")