luohoa97 committed on
Commit
b93d3aa
·
verified ·
1 Parent(s): a17fae6

Deploy BitNet-Transformer Trainer

Browse files
Files changed (1) hide show
  1. scripts/train_ai_model.py +9 -2
scripts/train_ai_model.py CHANGED
@@ -40,6 +40,9 @@ HF_DATASET_ID = "luohoa97/BitFin" # User's dataset repo
40
  HF_TOKEN = os.getenv("HF_TOKEN")
41
 
42
  def train():
 
 
 
43
  # 1. Load Dataset
44
  if not os.path.exists("data/trading_dataset.pt"):
45
  logger.info("Dataset not found locally. Searching on HF Hub...")
@@ -68,6 +71,7 @@ def train():
68
  # 3. Create Model
69
  input_dim = X.shape[2]
70
  model = create_model(input_dim=input_dim, hidden_dim=HIDDEN_DIM, layers=LAYERS, seq_len=SEQ_LEN)
 
71
 
72
  total_params = sum(p.numel() for p in model.parameters())
73
  logger.info(f"Model Architecture: BitNet-Transformer ({LAYERS} layers, {HIDDEN_DIM} hidden)")
@@ -87,6 +91,7 @@ def train():
87
  total = 0
88
 
89
  for batch_X, batch_y in train_loader:
 
90
  optimizer.zero_grad()
91
  outputs = model(batch_X)
92
  loss = criterion(outputs, batch_y)
@@ -109,6 +114,7 @@ def train():
109
  val_total = 0
110
  with torch.no_grad():
111
  for batch_X, batch_y in val_loader:
 
112
  outputs = model(batch_X)
113
  loss = criterion(outputs, batch_y)
114
  val_loss += loss.item()
@@ -142,10 +148,11 @@ def train():
142
 
143
  with torch.no_grad():
144
  for xb, yb in val_loader:
 
145
  outputs = model(xb)
146
  preds = torch.argmax(outputs, dim=-1)
147
- all_preds.extend(preds.numpy())
148
- all_true.extend(yb.numpy())
149
 
150
  target_names = ["HOLD", "BUY", "SELL"]
151
  report = classification_report(all_true, all_preds, target_names=target_names)
 
40
  HF_TOKEN = os.getenv("HF_TOKEN")
41
 
42
  def train():
43
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
44
+ logger.info(f"Using device: {device}")
45
+
46
  # 1. Load Dataset
47
  if not os.path.exists("data/trading_dataset.pt"):
48
  logger.info("Dataset not found locally. Searching on HF Hub...")
 
71
  # 3. Create Model
72
  input_dim = X.shape[2]
73
  model = create_model(input_dim=input_dim, hidden_dim=HIDDEN_DIM, layers=LAYERS, seq_len=SEQ_LEN)
74
+ model.to(device)
75
 
76
  total_params = sum(p.numel() for p in model.parameters())
77
  logger.info(f"Model Architecture: BitNet-Transformer ({LAYERS} layers, {HIDDEN_DIM} hidden)")
 
91
  total = 0
92
 
93
  for batch_X, batch_y in train_loader:
94
+ batch_X, batch_y = batch_X.to(device), batch_y.to(device)
95
  optimizer.zero_grad()
96
  outputs = model(batch_X)
97
  loss = criterion(outputs, batch_y)
 
114
  val_total = 0
115
  with torch.no_grad():
116
  for batch_X, batch_y in val_loader:
117
+ batch_X, batch_y = batch_X.to(device), batch_y.to(device)
118
  outputs = model(batch_X)
119
  loss = criterion(outputs, batch_y)
120
  val_loss += loss.item()
 
148
 
149
  with torch.no_grad():
150
  for xb, yb in val_loader:
151
+ xb, yb = xb.to(device), yb.to(device)
152
  outputs = model(xb)
153
  preds = torch.argmax(outputs, dim=-1)
154
+ all_preds.extend(preds.cpu().numpy())
155
+ all_true.extend(yb.cpu().numpy())
156
 
157
  target_names = ["HOLD", "BUY", "SELL"]
158
  report = classification_report(all_true, all_preds, target_names=target_names)