Fix mask type (convert to bool) and handle CPU fallback
app.py CHANGED

@@ -94,7 +94,7 @@ def train_model(batch_size, learning_rate, num_epochs):
     model = model.to(device)
 
     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
-    scaler = torch.amp.GradScaler(enabled=
+    scaler = torch.amp.GradScaler(enabled=torch.cuda.is_available())
 
     scheduler = CurriculumScheduler()
     logs = []
@@ -116,9 +116,12 @@ def train_model(batch_size, learning_rate, num_epochs):
         train_loss = 0
         for batch in train_loader:
             batch = {k: v.to(device) for k, v in batch.items()}
+            mask = batch.get("mask")
+            if mask is not None:
+                mask = mask.bool()
 
-            with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
-                outputs = model(batch["features"], mask=
+            with torch.amp.autocast(device_type='cuda', dtype=torch.float16, enabled=torch.cuda.is_available()):
+                outputs = model(batch["features"], mask=mask, active_components=stage_config.get("components"))
                 loss, _ = compute_total_loss(outputs, {"labels": batch["labels"]}, stage_config["losses"], global_step)
 
             scaler.scale(loss).backward()
@@ -140,7 +143,10 @@ def train_model(batch_size, learning_rate, num_epochs):
         with torch.no_grad():
            for batch in val_loader:
                 batch = {k: v.to(device) for k, v in batch.items()}
-                outputs = model(batch["features"], mask=
+                mask = batch.get("mask")
+                if mask is not None:
+                    mask = mask.bool()
+                outputs = model(batch["features"], mask=mask, active_components=stage_config.get("components"))
                 loss, _ = compute_total_loss(outputs, {"labels": batch["labels"]}, stage_config["losses"])
                 val_loss += loss.item()
                 if "predicted_class" in outputs:
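Why the `.bool()` conversion matters: collate functions typically stack 0/1 padding masks as integer tensors, and PyTorch's masking ops are strict about dtype. `masked_fill` rejects non-boolean masks outright, and attention layers treat a float mask as an additive bias rather than a selector, so an integer mask can either raise or quietly shift the logits. A minimal standalone sketch of the failure and the fix (the tensors below are illustrative, not taken from app.py):

import torch

# Collate functions commonly emit 0/1 padding masks as integer tensors.
mask = torch.tensor([[0, 0, 0, 1]])   # 1 = padding; dtype is torch.int64, not torch.bool
scores = torch.randn(1, 4)

try:
    scores.masked_fill(mask, float("-inf"))   # integer mask: raises
except RuntimeError as err:
    print(err)                                # masked_fill only supports boolean masks

# The commit's fix: convert once, before the tensor reaches the model.
print(scores.masked_fill(mask.bool(), float("-inf")))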
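The CPU fallback works because both AMP entry points degrade to no-ops when disabled: with `enabled=False`, `autocast` leaves dtypes untouched, `scaler.scale(loss)` returns the loss unchanged, `scaler.step(optimizer)` reduces to a plain `optimizer.step()`, and `scaler.update()` does nothing. One loop therefore serves both a GPU Space and a CPU-only one. A minimal standalone sketch of the pattern (the linear model, shapes, and data are illustrative, not from app.py):

import torch

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

model = torch.nn.Linear(8, 2).to(device)              # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.amp.GradScaler(enabled=use_cuda)       # no-op on CPU

x = torch.randn(4, 8, device=device)
y = torch.randint(0, 2, (4,), device=device)

# With enabled=False this context changes nothing, so the same code
# path runs in float32 on CPU and in float16 under AMP on GPU.
with torch.amp.autocast(device_type="cuda", dtype=torch.float16, enabled=use_cuda):
    loss = torch.nn.functional.cross_entropy(model(x), y)

scaler.scale(loss).backward()   # plain loss.backward() when disabled
scaler.step(optimizer)          # plain optimizer.step() when disabled
scaler.update()                 # no-op when disabled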