Spaces:
Running
Running
removes callbacks
Browse files- monitoring.py +4 -0
- trainer.py +43 -4
monitoring.py
CHANGED
|
@@ -309,6 +309,10 @@ class SmolLM3Monitor:
|
|
| 309 |
self.monitor.close()
|
| 310 |
except Exception as e:
|
| 311 |
logger.error(f"Error in on_train_end: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
|
| 313 |
return TrackioCallback(self)
|
| 314 |
|
|
|
|
| 309 |
self.monitor.close()
|
| 310 |
except Exception as e:
|
| 311 |
logger.error(f"Error in on_train_end: {e}")
|
| 312 |
+
|
| 313 |
+
def __call__(self, *args, **kwargs):
|
| 314 |
+
"""Make the callback callable to avoid any issues"""
|
| 315 |
+
return self
|
| 316 |
|
| 317 |
return TrackioCallback(self)
|
| 318 |
|
trainer.py
CHANGED
|
@@ -61,12 +61,51 @@ class SmolLM3Trainer:
|
|
| 61 |
# Get data collator
|
| 62 |
data_collator = self.dataset.get_data_collator()
|
| 63 |
|
| 64 |
-
# Add monitoring callback
|
| 65 |
callbacks = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
if self.monitor and self.monitor.enable_tracking:
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
if self.use_sft_trainer:
|
| 72 |
# Use SFTTrainer for supervised fine-tuning
|
|
|
|
| 61 |
# Get data collator
|
| 62 |
data_collator = self.dataset.get_data_collator()
|
| 63 |
|
| 64 |
+
# Add monitoring callback - temporarily disabled to debug
|
| 65 |
callbacks = []
|
| 66 |
+
|
| 67 |
+
# Simple console callback for basic monitoring
|
| 68 |
+
class SimpleConsoleCallback:
|
| 69 |
+
def on_log(self, args, state, control, logs=None, **kwargs):
|
| 70 |
+
"""Log metrics to console"""
|
| 71 |
+
if logs and isinstance(logs, dict):
|
| 72 |
+
step = state.global_step if hasattr(state, 'global_step') else 'unknown'
|
| 73 |
+
loss = logs.get('loss', 'N/A')
|
| 74 |
+
lr = logs.get('learning_rate', 'N/A')
|
| 75 |
+
print(f"Step {step}: loss={loss:.4f}, lr={lr}")
|
| 76 |
+
|
| 77 |
+
def on_train_begin(self, args, state, control, **kwargs):
|
| 78 |
+
print("🚀 Training started!")
|
| 79 |
+
|
| 80 |
+
def on_train_end(self, args, state, control, **kwargs):
|
| 81 |
+
print("✅ Training completed!")
|
| 82 |
+
|
| 83 |
+
def on_save(self, args, state, control, **kwargs):
|
| 84 |
+
step = state.global_step if hasattr(state, 'global_step') else 'unknown'
|
| 85 |
+
print(f"💾 Checkpoint saved at step {step}")
|
| 86 |
+
|
| 87 |
+
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
|
| 88 |
+
if metrics and isinstance(metrics, dict):
|
| 89 |
+
step = state.global_step if hasattr(state, 'global_step') else 'unknown'
|
| 90 |
+
eval_loss = metrics.get('eval_loss', 'N/A')
|
| 91 |
+
print(f"📊 Evaluation at step {step}: eval_loss={eval_loss}")
|
| 92 |
+
|
| 93 |
+
# Add simple console callback
|
| 94 |
+
callbacks.append(SimpleConsoleCallback())
|
| 95 |
+
logger.info("Added simple console monitoring callback")
|
| 96 |
+
|
| 97 |
+
# Try to add Trackio callback if available
|
| 98 |
if self.monitor and self.monitor.enable_tracking:
|
| 99 |
+
try:
|
| 100 |
+
trackio_callback = self.monitor.create_monitoring_callback()
|
| 101 |
+
if trackio_callback:
|
| 102 |
+
callbacks.append(trackio_callback)
|
| 103 |
+
logger.info("Added Trackio monitoring callback")
|
| 104 |
+
else:
|
| 105 |
+
logger.warning("Failed to create Trackio callback")
|
| 106 |
+
except Exception as e:
|
| 107 |
+
logger.error(f"Error creating Trackio callback: {e}")
|
| 108 |
+
logger.info("Continuing with console monitoring only")
|
| 109 |
|
| 110 |
if self.use_sft_trainer:
|
| 111 |
# Use SFTTrainer for supervised fine-tuning
|