Charlie81 commited on
Commit
45d6e50
·
1 Parent(s): 580eff8

train agaaa

Browse files
Files changed (1) hide show
  1. scripts/train.py +9 -5
scripts/train.py CHANGED
@@ -112,13 +112,17 @@ def main():
112
  batch["output_router_logits"] = True
113
  return batch
114
 
115
- # Fixed CustomTrainer class
116
  class CustomTrainer(Trainer):
117
- def compute_loss(self, model, inputs, return_outputs=False):
118
- # Remove num_items_in_batch from inputs if present
119
- inputs.pop('num_items_in_batch', None)
120
 
121
- with torch.set_grad_enabled(True): # Ensure gradients are enabled
 
 
 
 
122
  outputs = model(**inputs)
123
  loss = outputs.loss
124
 
 
112
  batch["output_router_logits"] = True
113
  return batch
114
 
115
+ # Fixed CustomTrainer class that handles all possible arguments
116
  class CustomTrainer(Trainer):
117
+ def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
118
+ # Remove any unexpected arguments
119
+ inputs = {k: v for k, v in inputs.items() if k not in ['num_items_in_batch']}
120
 
121
+ # Ensure we're in training mode
122
+ model.train()
123
+
124
+ # Forward pass with gradients
125
+ with torch.set_grad_enabled(True):
126
  outputs = model(**inputs)
127
  loss = outputs.loss
128