LimmeDev committed on
Commit
3ecafdf
·
verified ·
1 Parent(s): 4cbf142

Fix mask type (convert to bool) and handle CPU fallback

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -94,7 +94,7 @@ def train_model(batch_size, learning_rate, num_epochs):
94
  model = model.to(device)
95
 
96
  optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
97
- scaler = torch.amp.GradScaler(enabled=True)
98
 
99
  scheduler = CurriculumScheduler()
100
  logs = []
@@ -116,9 +116,12 @@ def train_model(batch_size, learning_rate, num_epochs):
116
  train_loss = 0
117
  for batch in train_loader:
118
  batch = {k: v.to(device) for k, v in batch.items()}
 
 
 
119
 
120
- with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
121
- outputs = model(batch["features"], mask=batch.get("mask"), active_components=stage_config.get("components"))
122
  loss, _ = compute_total_loss(outputs, {"labels": batch["labels"]}, stage_config["losses"], global_step)
123
 
124
  scaler.scale(loss).backward()
@@ -140,7 +143,10 @@ def train_model(batch_size, learning_rate, num_epochs):
140
  with torch.no_grad():
141
  for batch in val_loader:
142
  batch = {k: v.to(device) for k, v in batch.items()}
143
- outputs = model(batch["features"], mask=batch.get("mask"), active_components=stage_config.get("components"))
 
 
 
144
  loss, _ = compute_total_loss(outputs, {"labels": batch["labels"]}, stage_config["losses"])
145
  val_loss += loss.item()
146
  if "predicted_class" in outputs:
 
94
  model = model.to(device)
95
 
96
  optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
97
+ scaler = torch.amp.GradScaler(enabled=torch.cuda.is_available())
98
 
99
  scheduler = CurriculumScheduler()
100
  logs = []
 
116
  train_loss = 0
117
  for batch in train_loader:
118
  batch = {k: v.to(device) for k, v in batch.items()}
119
+ mask = batch.get("mask")
120
+ if mask is not None:
121
+ mask = mask.bool()
122
 
123
+ with torch.amp.autocast(device_type='cuda', dtype=torch.float16, enabled=torch.cuda.is_available()):
124
+ outputs = model(batch["features"], mask=mask, active_components=stage_config.get("components"))
125
  loss, _ = compute_total_loss(outputs, {"labels": batch["labels"]}, stage_config["losses"], global_step)
126
 
127
  scaler.scale(loss).backward()
 
143
  with torch.no_grad():
144
  for batch in val_loader:
145
  batch = {k: v.to(device) for k, v in batch.items()}
146
+ mask = batch.get("mask")
147
+ if mask is not None:
148
+ mask = mask.bool()
149
+ outputs = model(batch["features"], mask=mask, active_components=stage_config.get("components"))
150
  loss, _ = compute_total_loss(outputs, {"labels": batch["labels"]}, stage_config["losses"])
151
  val_loss += loss.item()
152
  if "predicted_class" in outputs: