train aga
Browse files
- scripts/train.py +6 -2
scripts/train.py
CHANGED
|
@@ -137,8 +137,12 @@ def main():
|
|
| 137 |
|
| 138 |
# Test forward/backward pass before training
|
| 139 |
print("Testing gradient flow...")
|
| 140 |
-
|
| 141 |
-
test_batch =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
model.train()
|
| 144 |
outputs = model(**test_batch)
|
|
|
|
| 137 |
|
| 138 |
# Test forward/backward pass before training
|
| 139 |
print("Testing gradient flow...")
|
| 140 |
+
test_loader = DataLoader(tokenized_dataset, batch_size=1, collate_fn=data_collator)
|
| 141 |
+
test_batch = next(iter(test_loader))
|
| 142 |
+
|
| 143 |
+
# Move batch to model's device
|
| 144 |
+
device = next(model.parameters()).device
|
| 145 |
+
test_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in test_batch.items()}
|
| 146 |
|
| 147 |
model.train()
|
| 148 |
outputs = model(**test_batch)
|