Spaces:
Sleeping
Sleeping
adds local and remote training monitors to config
Browse files- scripts/training/train_gpt_oss.py +47 -1
- src/monitoring.py +19 -18
- src/trainer.py +62 -29
scripts/training/train_gpt_oss.py
CHANGED
|
@@ -19,6 +19,11 @@ except Exception: # pragma: no cover - optional import depending on TRL version
|
|
| 19 |
DPOTrainer = None
|
| 20 |
from datasets import load_dataset
|
| 21 |
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
# Ensure project root and config package are importable for configs that do `from config...` imports
|
| 24 |
project_root = Path(__file__).resolve().parents[2]
|
|
@@ -876,6 +881,23 @@ def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer
|
|
| 876 |
# Setup Trackio tracking
|
| 877 |
trackio_client = setup_trackio_tracking(config)
|
| 878 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 879 |
# Create SFT configuration
|
| 880 |
sft_config = create_sft_config(config, output_dir)
|
| 881 |
|
|
@@ -949,6 +971,10 @@ def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer
|
|
| 949 |
if "packing" in sft_params:
|
| 950 |
sft_kwargs["packing"] = getattr(config, 'packing', False)
|
| 951 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 952 |
# Remove any None values
|
| 953 |
sft_kwargs = {k: v for k, v in sft_kwargs.items() if v is not None}
|
| 954 |
|
|
@@ -959,7 +985,15 @@ def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer
|
|
| 959 |
|
| 960 |
# Start training
|
| 961 |
print("Starting GPT-OSS training...")
|
| 962 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 963 |
|
| 964 |
# Save model
|
| 965 |
print("Saving trained model...")
|
|
@@ -970,6 +1004,18 @@ def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer
|
|
| 970 |
print("Pushing model to Hugging Face Hub...")
|
| 971 |
trainer.push_to_hub(dataset_name="HuggingFaceH4/Multilingual-Thinking")
|
| 972 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 973 |
print("GPT-OSS training completed successfully!")
|
| 974 |
|
| 975 |
return trainer
|
|
|
|
| 19 |
DPOTrainer = None
|
| 20 |
from datasets import load_dataset
|
| 21 |
from pathlib import Path
|
| 22 |
+
# Import monitoring utilities from project src for persistent logging
|
| 23 |
+
try:
|
| 24 |
+
from src.monitoring import create_monitor_from_config # type: ignore
|
| 25 |
+
except Exception:
|
| 26 |
+
create_monitor_from_config = None # type: ignore
|
| 27 |
|
| 28 |
# Ensure project root and config package are importable for configs that do `from config...` imports
|
| 29 |
project_root = Path(__file__).resolve().parents[2]
|
|
|
|
| 881 |
# Setup Trackio tracking
|
| 882 |
trackio_client = setup_trackio_tracking(config)
|
| 883 |
|
| 884 |
+
# Initialize project monitor (HF Datasets + Trackio Space if configured)
|
| 885 |
+
monitor = None
|
| 886 |
+
monitor_callback = None
|
| 887 |
+
if create_monitor_from_config is not None:
|
| 888 |
+
try:
|
| 889 |
+
monitor = create_monitor_from_config(config, experiment_name=experiment_name)
|
| 890 |
+
# Persist configuration immediately
|
| 891 |
+
try:
|
| 892 |
+
cfg_dict = {k: v for k, v in config.__dict__.items() if not k.startswith('_')}
|
| 893 |
+
monitor.log_config(cfg_dict)
|
| 894 |
+
except Exception:
|
| 895 |
+
pass
|
| 896 |
+
# Create callback for SFTTrainer
|
| 897 |
+
monitor_callback = monitor.create_monitoring_callback()
|
| 898 |
+
except Exception:
|
| 899 |
+
monitor = None
|
| 900 |
+
|
| 901 |
# Create SFT configuration
|
| 902 |
sft_config = create_sft_config(config, output_dir)
|
| 903 |
|
|
|
|
| 971 |
if "packing" in sft_params:
|
| 972 |
sft_kwargs["packing"] = getattr(config, 'packing', False)
|
| 973 |
|
| 974 |
+
# Attach monitoring callback if supported
|
| 975 |
+
if "callbacks" in sft_params:
|
| 976 |
+
sft_kwargs["callbacks"] = ([monitor_callback] if monitor_callback is not None else [])
|
| 977 |
+
|
| 978 |
# Remove any None values
|
| 979 |
sft_kwargs = {k: v for k, v in sft_kwargs.items() if v is not None}
|
| 980 |
|
|
|
|
| 985 |
|
| 986 |
# Start training
|
| 987 |
print("Starting GPT-OSS training...")
|
| 988 |
+
try:
|
| 989 |
+
trainer.train()
|
| 990 |
+
finally:
|
| 991 |
+
# Ensure periodic metrics are flushed at the end even if interrupted
|
| 992 |
+
try:
|
| 993 |
+
if monitor is not None:
|
| 994 |
+
monitor._save_to_hf_dataset({'status': 'running'})
|
| 995 |
+
except Exception:
|
| 996 |
+
pass
|
| 997 |
|
| 998 |
# Save model
|
| 999 |
print("Saving trained model...")
|
|
|
|
| 1004 |
print("Pushing model to Hugging Face Hub...")
|
| 1005 |
trainer.push_to_hub(dataset_name="HuggingFaceH4/Multilingual-Thinking")
|
| 1006 |
|
| 1007 |
+
# Log training summary and close monitor
|
| 1008 |
+
try:
|
| 1009 |
+
if monitor is not None:
|
| 1010 |
+
summary = {
|
| 1011 |
+
'output_dir': output_dir,
|
| 1012 |
+
'model_name': getattr(config, 'model_name', 'unknown'),
|
| 1013 |
+
}
|
| 1014 |
+
monitor.log_training_summary(summary)
|
| 1015 |
+
monitor.close()
|
| 1016 |
+
except Exception:
|
| 1017 |
+
pass
|
| 1018 |
+
|
| 1019 |
print("GPT-OSS training completed successfully!")
|
| 1020 |
|
| 1021 |
return trainer
|
src/monitoring.py
CHANGED
|
@@ -50,6 +50,11 @@ class SmolLM3Monitor:
|
|
| 50 |
self.log_artifacts = log_artifacts
|
| 51 |
self.log_metrics_enabled = log_metrics # Rename to avoid conflict
|
| 52 |
self.log_config_enabled = log_config # Rename to avoid conflict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
# HF Datasets configuration
|
| 55 |
self.hf_token = hf_token or os.environ.get('HF_TOKEN')
|
|
@@ -343,12 +348,12 @@ class SmolLM3Monitor:
|
|
| 343 |
|
| 344 |
def log_configuration(self, config: Dict[str, Any]):
|
| 345 |
"""Log experiment configuration"""
|
| 346 |
-
if not self.
|
| 347 |
return
|
| 348 |
|
| 349 |
try:
|
| 350 |
# Log configuration as parameters
|
| 351 |
-
if self.trackio_client:
|
| 352 |
try:
|
| 353 |
result = self.trackio_client.log_parameters(
|
| 354 |
experiment_id=self.experiment_id,
|
|
@@ -390,7 +395,7 @@ class SmolLM3Monitor:
|
|
| 390 |
- throughput, step_time, batch_size, seq_len
|
| 391 |
- token_acc, train/gate_ortho, train/center, etc.
|
| 392 |
"""
|
| 393 |
-
if not self.
|
| 394 |
return
|
| 395 |
|
| 396 |
try:
|
|
@@ -400,7 +405,7 @@ class SmolLM3Monitor:
|
|
| 400 |
metrics['step'] = step
|
| 401 |
|
| 402 |
# Log to Trackio (if available)
|
| 403 |
-
if self.trackio_client:
|
| 404 |
try:
|
| 405 |
result = self.trackio_client.log_metrics(
|
| 406 |
experiment_id=self.experiment_id,
|
|
@@ -418,8 +423,8 @@ class SmolLM3Monitor:
|
|
| 418 |
# Store locally
|
| 419 |
self.metrics_history.append(metrics)
|
| 420 |
|
| 421 |
-
# Save to HF Dataset periodically
|
| 422 |
-
if len(self.metrics_history) %
|
| 423 |
self._save_to_hf_dataset({'metrics': self.metrics_history})
|
| 424 |
|
| 425 |
logger.debug("Metrics logged: %s", metrics)
|
|
@@ -429,7 +434,7 @@ class SmolLM3Monitor:
|
|
| 429 |
|
| 430 |
def log_model_checkpoint(self, checkpoint_path: str, step: Optional[int] = None):
|
| 431 |
"""Log model checkpoint"""
|
| 432 |
-
if not self.
|
| 433 |
return
|
| 434 |
|
| 435 |
try:
|
|
@@ -441,7 +446,7 @@ class SmolLM3Monitor:
|
|
| 441 |
"checkpoint_size": os.path.getsize(checkpoint_path) if os.path.exists(checkpoint_path) else 0
|
| 442 |
}
|
| 443 |
|
| 444 |
-
if self.trackio_client:
|
| 445 |
result = self.trackio_client.log_parameters(
|
| 446 |
experiment_id=self.experiment_id,
|
| 447 |
parameters=checkpoint_info
|
|
@@ -453,6 +458,11 @@ class SmolLM3Monitor:
|
|
| 453 |
logger.error("Failed to log checkpoint to Trackio: %s", result)
|
| 454 |
|
| 455 |
self.artifacts.append(checkpoint_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 456 |
logger.info("Checkpoint logged: %s", checkpoint_path)
|
| 457 |
|
| 458 |
except Exception as e:
|
|
@@ -460,9 +470,6 @@ class SmolLM3Monitor:
|
|
| 460 |
|
| 461 |
def log_evaluation_results(self, results: Dict[str, Any], step: Optional[int] = None):
|
| 462 |
"""Log evaluation results"""
|
| 463 |
-
if not self.enable_tracking:
|
| 464 |
-
return
|
| 465 |
-
|
| 466 |
try:
|
| 467 |
# Add evaluation prefix to metrics
|
| 468 |
eval_metrics = {f"eval_{k}": v for k, v in results.items()}
|
|
@@ -485,9 +492,6 @@ class SmolLM3Monitor:
|
|
| 485 |
|
| 486 |
def log_system_metrics(self, step: Optional[int] = None):
|
| 487 |
"""Log system metrics (GPU, memory, etc.)"""
|
| 488 |
-
if not self.enable_tracking:
|
| 489 |
-
return
|
| 490 |
-
|
| 491 |
try:
|
| 492 |
system_metrics = {}
|
| 493 |
|
|
@@ -513,9 +517,6 @@ class SmolLM3Monitor:
|
|
| 513 |
|
| 514 |
def log_training_summary(self, summary: Dict[str, Any]):
|
| 515 |
"""Log training summary at the end"""
|
| 516 |
-
if not self.enable_tracking:
|
| 517 |
-
return
|
| 518 |
-
|
| 519 |
try:
|
| 520 |
# Add experiment duration
|
| 521 |
end_time = datetime.now()
|
|
@@ -524,7 +525,7 @@ class SmolLM3Monitor:
|
|
| 524 |
summary['experiment_duration_hours'] = duration / 3600
|
| 525 |
|
| 526 |
# Log final summary to Trackio
|
| 527 |
-
if self.trackio_client:
|
| 528 |
result = self.trackio_client.log_parameters(
|
| 529 |
experiment_id=self.experiment_id,
|
| 530 |
parameters=summary
|
|
|
|
| 50 |
self.log_artifacts = log_artifacts
|
| 51 |
self.log_metrics_enabled = log_metrics # Rename to avoid conflict
|
| 52 |
self.log_config_enabled = log_config # Rename to avoid conflict
|
| 53 |
+
# Flush interval for dataset persistence (metrics)
|
| 54 |
+
try:
|
| 55 |
+
self.flush_interval = int(os.environ.get('TRACKIO_FLUSH_INTERVAL', '10'))
|
| 56 |
+
except Exception:
|
| 57 |
+
self.flush_interval = 10
|
| 58 |
|
| 59 |
# HF Datasets configuration
|
| 60 |
self.hf_token = hf_token or os.environ.get('HF_TOKEN')
|
|
|
|
| 348 |
|
| 349 |
def log_configuration(self, config: Dict[str, Any]):
|
| 350 |
"""Log experiment configuration"""
|
| 351 |
+
if not self.log_config_enabled:
|
| 352 |
return
|
| 353 |
|
| 354 |
try:
|
| 355 |
# Log configuration as parameters
|
| 356 |
+
if self.enable_tracking and self.trackio_client:
|
| 357 |
try:
|
| 358 |
result = self.trackio_client.log_parameters(
|
| 359 |
experiment_id=self.experiment_id,
|
|
|
|
| 395 |
- throughput, step_time, batch_size, seq_len
|
| 396 |
- token_acc, train/gate_ortho, train/center, etc.
|
| 397 |
"""
|
| 398 |
+
if not self.log_metrics_enabled:
|
| 399 |
return
|
| 400 |
|
| 401 |
try:
|
|
|
|
| 405 |
metrics['step'] = step
|
| 406 |
|
| 407 |
# Log to Trackio (if available)
|
| 408 |
+
if self.enable_tracking and self.trackio_client:
|
| 409 |
try:
|
| 410 |
result = self.trackio_client.log_metrics(
|
| 411 |
experiment_id=self.experiment_id,
|
|
|
|
| 423 |
# Store locally
|
| 424 |
self.metrics_history.append(metrics)
|
| 425 |
|
| 426 |
+
# Save to HF Dataset periodically (configurable)
|
| 427 |
+
if self.flush_interval > 0 and (len(self.metrics_history) % self.flush_interval == 0):
|
| 428 |
self._save_to_hf_dataset({'metrics': self.metrics_history})
|
| 429 |
|
| 430 |
logger.debug("Metrics logged: %s", metrics)
|
|
|
|
| 434 |
|
| 435 |
def log_model_checkpoint(self, checkpoint_path: str, step: Optional[int] = None):
|
| 436 |
"""Log model checkpoint"""
|
| 437 |
+
if not self.log_artifacts:
|
| 438 |
return
|
| 439 |
|
| 440 |
try:
|
|
|
|
| 446 |
"checkpoint_size": os.path.getsize(checkpoint_path) if os.path.exists(checkpoint_path) else 0
|
| 447 |
}
|
| 448 |
|
| 449 |
+
if self.enable_tracking and self.trackio_client:
|
| 450 |
result = self.trackio_client.log_parameters(
|
| 451 |
experiment_id=self.experiment_id,
|
| 452 |
parameters=checkpoint_info
|
|
|
|
| 458 |
logger.error("Failed to log checkpoint to Trackio: %s", result)
|
| 459 |
|
| 460 |
self.artifacts.append(checkpoint_path)
|
| 461 |
+
# Also preserve checkpoint info in HF dataset
|
| 462 |
+
try:
|
| 463 |
+
self._save_to_hf_dataset({'artifacts': [checkpoint_path], **checkpoint_info})
|
| 464 |
+
except Exception:
|
| 465 |
+
pass
|
| 466 |
logger.info("Checkpoint logged: %s", checkpoint_path)
|
| 467 |
|
| 468 |
except Exception as e:
|
|
|
|
| 470 |
|
| 471 |
def log_evaluation_results(self, results: Dict[str, Any], step: Optional[int] = None):
|
| 472 |
"""Log evaluation results"""
|
|
|
|
|
|
|
|
|
|
| 473 |
try:
|
| 474 |
# Add evaluation prefix to metrics
|
| 475 |
eval_metrics = {f"eval_{k}": v for k, v in results.items()}
|
|
|
|
| 492 |
|
| 493 |
def log_system_metrics(self, step: Optional[int] = None):
|
| 494 |
"""Log system metrics (GPU, memory, etc.)"""
|
|
|
|
|
|
|
|
|
|
| 495 |
try:
|
| 496 |
system_metrics = {}
|
| 497 |
|
|
|
|
| 517 |
|
| 518 |
def log_training_summary(self, summary: Dict[str, Any]):
|
| 519 |
"""Log training summary at the end"""
|
|
|
|
|
|
|
|
|
|
| 520 |
try:
|
| 521 |
# Add experiment duration
|
| 522 |
end_time = datetime.now()
|
|
|
|
| 525 |
summary['experiment_duration_hours'] = duration / 3600
|
| 526 |
|
| 527 |
# Log final summary to Trackio
|
| 528 |
+
if self.enable_tracking and self.trackio_client:
|
| 529 |
result = self.trackio_client.log_parameters(
|
| 530 |
experiment_id=self.experiment_id,
|
| 531 |
parameters=summary
|
src/trainer.py
CHANGED
|
@@ -78,6 +78,7 @@ class SmolLM3Trainer:
|
|
| 78 |
# Add simple console callback for basic monitoring
|
| 79 |
from transformers import TrainerCallback
|
| 80 |
|
|
|
|
| 81 |
class SimpleConsoleCallback(TrainerCallback):
|
| 82 |
def on_init_end(self, args, state, control, **kwargs):
|
| 83 |
"""Called when training initialization is complete"""
|
|
@@ -99,6 +100,16 @@ class SmolLM3Trainer:
|
|
| 99 |
else:
|
| 100 |
lr_str = str(lr)
|
| 101 |
print(f"Step {step}: loss={loss_str}, lr={lr_str}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
def on_train_begin(self, args, state, control, **kwargs):
|
| 104 |
print("🚀 Training started!")
|
|
@@ -109,28 +120,40 @@ class SmolLM3Trainer:
|
|
| 109 |
def on_save(self, args, state, control, **kwargs):
|
| 110 |
step = state.global_step if hasattr(state, 'global_step') else 'unknown'
|
| 111 |
print(f"💾 Checkpoint saved at step {step}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
|
| 114 |
if metrics and isinstance(metrics, dict):
|
| 115 |
step = state.global_step if hasattr(state, 'global_step') else 'unknown'
|
| 116 |
eval_loss = metrics.get('eval_loss', 'N/A')
|
| 117 |
print(f"📊 Evaluation at step {step}: eval_loss={eval_loss}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
# Add console callback
|
| 120 |
callbacks.append(SimpleConsoleCallback())
|
| 121 |
logger.info("Added simple console monitoring callback")
|
| 122 |
|
| 123 |
-
# Add
|
| 124 |
-
if self.monitor
|
| 125 |
try:
|
| 126 |
trackio_callback = self.monitor.create_monitoring_callback()
|
| 127 |
if trackio_callback:
|
| 128 |
callbacks.append(trackio_callback)
|
| 129 |
-
logger.info("Added
|
| 130 |
else:
|
| 131 |
-
logger.warning("Failed to create
|
| 132 |
except Exception as e:
|
| 133 |
-
logger.error("Error creating
|
| 134 |
logger.info("Continuing with console monitoring only")
|
| 135 |
|
| 136 |
logger.info("Total callbacks: %d", len(callbacks))
|
|
@@ -220,16 +243,20 @@ class SmolLM3Trainer:
|
|
| 220 |
"""Start training"""
|
| 221 |
logger.info("Starting training")
|
| 222 |
|
| 223 |
-
# Log configuration to Trackio
|
| 224 |
-
if self.monitor
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
|
| 234 |
# Load checkpoint if resuming
|
| 235 |
if self.init_from == "resume":
|
|
@@ -251,17 +278,20 @@ class SmolLM3Trainer:
|
|
| 251 |
with open(os.path.join(self.output_dir, "train_results.json"), "w") as f:
|
| 252 |
json.dump(train_result.metrics, f, indent=2)
|
| 253 |
|
| 254 |
-
# Log training summary to Trackio
|
| 255 |
-
if self.monitor
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
|
|
|
|
|
|
|
|
|
| 265 |
|
| 266 |
# Finish trackio experiment
|
| 267 |
try:
|
|
@@ -276,9 +306,12 @@ class SmolLM3Trainer:
|
|
| 276 |
|
| 277 |
except Exception as e:
|
| 278 |
logger.error("Training failed: %s", e)
|
| 279 |
-
# Close monitoring on error
|
| 280 |
-
if self.monitor
|
| 281 |
-
|
|
|
|
|
|
|
|
|
|
| 282 |
|
| 283 |
# Finish trackio experiment on error
|
| 284 |
try:
|
|
|
|
| 78 |
# Add simple console callback for basic monitoring
|
| 79 |
from transformers import TrainerCallback
|
| 80 |
|
| 81 |
+
outer = self
|
| 82 |
class SimpleConsoleCallback(TrainerCallback):
|
| 83 |
def on_init_end(self, args, state, control, **kwargs):
|
| 84 |
"""Called when training initialization is complete"""
|
|
|
|
| 100 |
else:
|
| 101 |
lr_str = str(lr)
|
| 102 |
print(f"Step {step}: loss={loss_str}, lr={lr_str}")
|
| 103 |
+
|
| 104 |
+
# Persist metrics via our monitor when Trackio callback isn't active
|
| 105 |
+
try:
|
| 106 |
+
if outer.monitor:
|
| 107 |
+
# Avoid double logging when Trackio callback is used
|
| 108 |
+
if not outer.monitor.enable_tracking:
|
| 109 |
+
outer.monitor.log_metrics(dict(logs), step if isinstance(step, int) else None)
|
| 110 |
+
outer.monitor.log_system_metrics(step if isinstance(step, int) else None)
|
| 111 |
+
except Exception as e:
|
| 112 |
+
logger.warning("SimpleConsoleCallback metrics persistence failed: %s", e)
|
| 113 |
|
| 114 |
def on_train_begin(self, args, state, control, **kwargs):
|
| 115 |
print("🚀 Training started!")
|
|
|
|
| 120 |
def on_save(self, args, state, control, **kwargs):
|
| 121 |
step = state.global_step if hasattr(state, 'global_step') else 'unknown'
|
| 122 |
print(f"💾 Checkpoint saved at step {step}")
|
| 123 |
+
try:
|
| 124 |
+
if outer.monitor and not outer.monitor.enable_tracking:
|
| 125 |
+
checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{step}")
|
| 126 |
+
if os.path.exists(checkpoint_path):
|
| 127 |
+
outer.monitor.log_model_checkpoint(checkpoint_path, step if isinstance(step, int) else None)
|
| 128 |
+
except Exception as e:
|
| 129 |
+
logger.warning("SimpleConsoleCallback checkpoint persistence failed: %s", e)
|
| 130 |
|
| 131 |
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
|
| 132 |
if metrics and isinstance(metrics, dict):
|
| 133 |
step = state.global_step if hasattr(state, 'global_step') else 'unknown'
|
| 134 |
eval_loss = metrics.get('eval_loss', 'N/A')
|
| 135 |
print(f"📊 Evaluation at step {step}: eval_loss={eval_loss}")
|
| 136 |
+
try:
|
| 137 |
+
if outer.monitor and not outer.monitor.enable_tracking:
|
| 138 |
+
outer.monitor.log_evaluation_results(dict(metrics), step if isinstance(step, int) else None)
|
| 139 |
+
except Exception as e:
|
| 140 |
+
logger.warning("SimpleConsoleCallback eval persistence failed: %s", e)
|
| 141 |
|
| 142 |
# Add console callback
|
| 143 |
callbacks.append(SimpleConsoleCallback())
|
| 144 |
logger.info("Added simple console monitoring callback")
|
| 145 |
|
| 146 |
+
# Add monitoring callback if available (always attach; it persists to dataset even if Trackio is disabled)
|
| 147 |
+
if self.monitor:
|
| 148 |
try:
|
| 149 |
trackio_callback = self.monitor.create_monitoring_callback()
|
| 150 |
if trackio_callback:
|
| 151 |
callbacks.append(trackio_callback)
|
| 152 |
+
logger.info("Added monitoring callback")
|
| 153 |
else:
|
| 154 |
+
logger.warning("Failed to create monitoring callback")
|
| 155 |
except Exception as e:
|
| 156 |
+
logger.error("Error creating monitoring callback: %s", e)
|
| 157 |
logger.info("Continuing with console monitoring only")
|
| 158 |
|
| 159 |
logger.info("Total callbacks: %d", len(callbacks))
|
|
|
|
| 243 |
"""Start training"""
|
| 244 |
logger.info("Starting training")
|
| 245 |
|
| 246 |
+
# Log configuration (always persist to dataset; Trackio if enabled)
|
| 247 |
+
if self.monitor:
|
| 248 |
+
try:
|
| 249 |
+
config_dict = {k: v for k, v in self.config.__dict__.items() if not k.startswith('_')}
|
| 250 |
+
self.monitor.log_config(config_dict)
|
| 251 |
+
except Exception as e:
|
| 252 |
+
logger.warning("Failed to log configuration: %s", e)
|
| 253 |
+
# Log experiment URL only if available
|
| 254 |
+
try:
|
| 255 |
+
experiment_url = self.monitor.get_experiment_url()
|
| 256 |
+
if experiment_url:
|
| 257 |
+
logger.info("Trackio experiment URL: %s", experiment_url)
|
| 258 |
+
except Exception:
|
| 259 |
+
pass
|
| 260 |
|
| 261 |
# Load checkpoint if resuming
|
| 262 |
if self.init_from == "resume":
|
|
|
|
| 278 |
with open(os.path.join(self.output_dir, "train_results.json"), "w") as f:
|
| 279 |
json.dump(train_result.metrics, f, indent=2)
|
| 280 |
|
| 281 |
+
# Log training summary (always persist to dataset; Trackio if enabled)
|
| 282 |
+
if self.monitor:
|
| 283 |
+
try:
|
| 284 |
+
summary = {
|
| 285 |
+
'final_loss': train_result.metrics.get('train_loss', 0),
|
| 286 |
+
'total_steps': train_result.metrics.get('train_runtime', 0),
|
| 287 |
+
'training_time': train_result.metrics.get('train_runtime', 0),
|
| 288 |
+
'output_dir': self.output_dir,
|
| 289 |
+
'model_name': getattr(self.config, 'model_name', 'unknown'),
|
| 290 |
+
}
|
| 291 |
+
self.monitor.log_training_summary(summary)
|
| 292 |
+
self.monitor.close()
|
| 293 |
+
except Exception as e:
|
| 294 |
+
logger.warning("Failed to log training summary: %s", e)
|
| 295 |
|
| 296 |
# Finish trackio experiment
|
| 297 |
try:
|
|
|
|
| 306 |
|
| 307 |
except Exception as e:
|
| 308 |
logger.error("Training failed: %s", e)
|
| 309 |
+
# Close monitoring on error (still persist final status to dataset)
|
| 310 |
+
if self.monitor:
|
| 311 |
+
try:
|
| 312 |
+
self.monitor.close(final_status="failed")
|
| 313 |
+
except Exception:
|
| 314 |
+
pass
|
| 315 |
|
| 316 |
# Finish trackio experiment on error
|
| 317 |
try:
|