train again
Browse files- scripts/train.py +9 -5
scripts/train.py
CHANGED
|
@@ -112,13 +112,17 @@ def main():
|
|
| 112 |
batch["output_router_logits"] = True
|
| 113 |
return batch
|
| 114 |
|
| 115 |
-
# Fixed CustomTrainer class
|
| 116 |
class CustomTrainer(Trainer):
|
| 117 |
-
def compute_loss(self, model, inputs, return_outputs=False):
|
| 118 |
-
# Remove
|
| 119 |
-
inputs.
|
| 120 |
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
outputs = model(**inputs)
|
| 123 |
loss = outputs.loss
|
| 124 |
|
|
|
|
| 112 |
batch["output_router_logits"] = True
|
| 113 |
return batch
|
| 114 |
|
| 115 |
+
# Fixed CustomTrainer class that handles all possible arguments
|
| 116 |
class CustomTrainer(Trainer):
|
| 117 |
+
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
|
| 118 |
+
# Remove any unexpected arguments
|
| 119 |
+
inputs = {k: v for k, v in inputs.items() if k not in ['num_items_in_batch']}
|
| 120 |
|
| 121 |
+
# Ensure we're in training mode
|
| 122 |
+
model.train()
|
| 123 |
+
|
| 124 |
+
# Forward pass with gradients
|
| 125 |
+
with torch.set_grad_enabled(True):
|
| 126 |
outputs = model(**inputs)
|
| 127 |
loss = outputs.loss
|
| 128 |
|