Charlie81 committed
Commit 1f3825f · 1 Parent(s): 325d2d0
Files changed (1)
  1. scripts/train.py +32 -9
scripts/train.py CHANGED

```diff
@@ -79,7 +79,7 @@ def main():
         save_steps=1000,
         save_total_limit=2,
         bf16=True,
-        gradient_checkpointing=False,  # Disabled for now to debug
+        gradient_checkpointing=False,  # Disabled for now
         report_to="tensorboard",
         optim="adamw_torch",
         lr_scheduler_type="cosine",
@@ -112,27 +112,50 @@ def main():
         batch["output_router_logits"] = True
         return batch
 
-    # Custom trainer class to handle gradient flow
+    # Fixed CustomTrainer class
     class CustomTrainer(Trainer):
         def compute_loss(self, model, inputs, return_outputs=False):
-            outputs = model(**inputs)
-            loss = outputs.loss
+            # Remove num_items_in_batch from inputs if present
+            inputs.pop('num_items_in_batch', None)
 
-            # Ensure we have gradients
-            if loss.requires_grad:
+            with torch.set_grad_enabled(True):  # Ensure gradients are enabled
+                outputs = model(**inputs)
+                loss = outputs.loss
+
+            if not loss.requires_grad:
+                raise RuntimeError("Loss doesn't require gradients. Check model parameters.")
+
             return (loss, outputs) if return_outputs else loss
-            else:
-                raise RuntimeError("Loss doesn't require gradients. Check model parameters.")
 
     # Initialize trainer
     trainer = CustomTrainer(
         model=model,
         args=training_args,
         train_dataset=tokenized_dataset,
-        tokenizer=tokenizer,
         data_collator=data_collator,
     )
 
+    # Test forward/backward pass before training
+    print("Testing gradient flow...")
+    test_batch = next(iter(DataLoader(tokenized_dataset, batch_size=1)))
+    test_batch = {k: v.to(model.device) for k, v in test_batch.items()}
+
+    model.train()
+    outputs = model(**test_batch)
+    loss = outputs.loss
+    print(f"Initial loss: {loss.item()}")
+
+    loss.backward()
+    print("Gradients computed successfully")
+
+    # Check which parameters received gradients
+    for name, param in model.named_parameters():
+        if param.grad is not None:
+            print(f"Parameter {name} received gradients")
+
+    # Reset gradients
+    model.zero_grad()
+
     # Train
     print("Starting training...")
     trainer.train()
```
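For anyone hitting the same `num_items_in_batch` error, here is a minimal, self-contained sketch of the override this commit lands. The imports and the `**kwargs` catch-all are my assumptions, not part of the commit: recent `transformers` releases pass `num_items_in_batch` to `compute_loss` as a keyword argument rather than inside `inputs`, so the catch-all covers both paths.

```python
# Sketch only, not the commit's exact code; assumes transformers' Trainer API.
import torch
from transformers import Trainer


class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # The commit strips num_items_in_batch out of `inputs`; the **kwargs
        # catch-all (my assumption) absorbs it when it arrives as a keyword.
        inputs.pop("num_items_in_batch", None)

        with torch.set_grad_enabled(True):  # guard against a caller disabling grads
            outputs = model(**inputs)
            loss = outputs.loss

        if not loss.requires_grad:
            raise RuntimeError("Loss doesn't require gradients. Check model parameters.")

        return (loss, outputs) if return_outputs else loss
```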
 
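The smoke test added in the second hunk calls `DataLoader` without showing its import, and a Hugging Face `datasets` dataset generally needs tensor formatting and a collator before `DataLoader` can batch it. Below is a standalone sketch of the same check; the helper name `check_gradient_flow` and the collator choice are illustrative assumptions, not from the commit.

```python
from torch.utils.data import DataLoader
from transformers import default_data_collator


def check_gradient_flow(model, dataset):
    """One forward/backward pass on a single example to verify the graph is intact."""
    loader = DataLoader(dataset.with_format("torch"), batch_size=1,
                        collate_fn=default_data_collator)
    batch = {k: v.to(model.device) for k, v in next(iter(loader)).items()}

    model.train()
    loss = model(**batch).loss
    print(f"Initial loss: {loss.item()}")

    loss.backward()  # raises if the loss is detached from the parameters
    n = sum(p.grad is not None for p in model.parameters())
    print(f"{n} parameter tensors received gradients")

    model.zero_grad()  # leave the model clean for the real training run
```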
 
 
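The commit also drops `tokenizer=tokenizer` from the `Trainer` constructor. My best guess at the motivation, stated as an assumption: `transformers` deprecated that argument in favor of `processing_class` (around v4.46), so if the tokenizer still needs to reach the trainer (e.g. so checkpoints save it), something like this sketch would restore that:

```python
# Assumption: transformers >= 4.46, where `processing_class` replaces the
# deprecated `tokenizer=` Trainer argument. Not part of this commit.
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
    processing_class=tokenizer,  # lets Trainer save the tokenizer with checkpoints
)
```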