Claude committed
Commit f703ff1 · unverified · 1 parent: 46bfd81

Add volume verification, fsync, and stdout fallback for training outputs


- Verify volume is mounted and writable at startup (canary file) before
expensive training begins — fails fast with a clear error message
- Add fsync after all critical file writes (logs, JSON, report) to ensure
data is flushed to the volume before container termination
- Print full report to stdout after saving so it's always visible in logs
- Save training JSON incrementally after each step (not just at the end)

https://claude.ai/code/session_01DPirJ78YYN4fJUvUFJ5D6V
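
A note on the flush-then-fsync pattern this commit applies after each write: `f.flush()` only drains Python's userspace buffer into the OS page cache, while `os.fsync(f.fileno())` asks the kernel to push that data to the storage device, which is what matters when the container may be killed moments later. A minimal sketch of the pattern as a standalone helper (the `durable_write` name is hypothetical, not part of this commit):

    import os

    def durable_write(path: str, text: str) -> None:
        """Hypothetical helper: write text and block until it reaches the device."""
        with open(path, "w") as f:
            f.write(text)
            f.flush()              # drain Python's userspace buffer to the OS
            os.fsync(f.fileno())   # ask the kernel to persist the data to disk

On POSIX systems, making a newly created file fully durable also requires an fsync on its parent directory; for the container-shutdown scenario targeted here, syncing the file contents is the main concern.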

Files changed (2)
  1. layer1/train.py +43 -0
  2. layer1/training_logger.py +11 -0
layer1/train.py CHANGED

@@ -22,6 +22,7 @@ import json
 import logging
 import sys
 import os
+from datetime import datetime
 
 # Auto-load .env for HF_TOKEN
 from dotenv import load_dotenv
@@ -40,6 +41,29 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(message)s
 logger = logging.getLogger(__name__)
 
 
+def verify_volume_mount(paths_cfg: dict) -> None:
+    """Write a canary file at startup to verify the volume is mounted and writable."""
+    output_dirs = [
+        paths_cfg.get("output_dir", ""),
+        paths_cfg.get("log_dir", ""),
+    ]
+    for d in output_dirs:
+        if not d:
+            continue
+        os.makedirs(d, exist_ok=True)
+        canary = os.path.join(d, ".volume_check")
+        try:
+            with open(canary, "w") as f:
+                f.write(f"volume check {datetime.now().isoformat()}\n")
+                f.flush()
+                os.fsync(f.fileno())
+            logger.info("Volume check OK: %s", d)
+        except OSError as e:
+            logger.error("VOLUME WRITE FAILED for %s: %s", d, e)
+            print(f"\n*** WARNING: Cannot write to {d} — volume may not be mounted! ***\n")
+            raise
+
+
 def load_evaluator(
     hf_token: str | None = None,
     gen_cfg: dict | None = None,
@@ -117,6 +141,13 @@ def _print_config_banner(config: GRPOConfig, report_cfg: dict, paths_cfg: dict):
 def run_train(config: GRPOConfig, report_cfg: dict, paths_cfg: dict, hf_token: str | None, gen_cfg: dict | None = None, personas_cfg: dict | None = None):
     """Run GRPO training."""
     _print_config_banner(config, report_cfg, paths_cfg)
+
+    # Verify volume is mounted before doing any expensive work
+    all_paths = dict(paths_cfg)
+    if report_cfg.get("enabled") and report_cfg.get("output_dir"):
+        all_paths["report_dir"] = report_cfg["output_dir"]
+    verify_volume_mount(all_paths)
+
     evaluator = load_evaluator(hf_token, gen_cfg=gen_cfg, personas_cfg=personas_cfg)
     training_logger = TrainingLogger(
         log_dir=paths_cfg["log_dir"], total_steps=config.num_training_steps
@@ -149,6 +180,18 @@ def run_train(config: GRPOConfig, report_cfg: dict, paths_cfg: dict, hf_token: s
     )
     print(f"\nReport saved to {report_path}")
 
+    # Print report to stdout as fallback (always visible in logs)
+    try:
+        with open(report_path, "r") as f:
+            report_content = f.read()
+        print(f"\n{'='*60}")
+        print("REPORT CONTENT (stdout fallback)")
+        print(f"{'='*60}")
+        print(report_content)
+        print(f"{'='*60}")
+    except OSError:
+        print("WARNING: Could not re-read report from disk")
+
 
 def run_eval(hf_token: str | None, prompt: str, episodes: int):
     """Evaluate a single prompt."""
layer1/training_logger.py CHANGED

@@ -36,6 +36,8 @@ class TrainingLogger:
         with open(self.log_path, "w") as f:
             f.write(f"Training Log — {self._start_time.isoformat()}\n")
             f.write(f"{'=' * 60}\n\n")
+            f.flush()
+            os.fsync(f.fileno())
 
     def log_iteration(self, step: int, prompt: str, eval_result: dict[str, Any]):
         """Log a single training iteration (one prompt evaluated)."""
@@ -59,6 +61,11 @@
             f.write(f"Min/Max: {entry['min_reward']:.1f} / {entry['max_reward']:.1f}\n")
             f.write(f"Episodes: {entry['num_episodes']}\n")
             f.write(f"---\n\n")
+            f.flush()
+            os.fsync(f.fileno())
+
+        # Incremental save — persist JSON after every step so data survives crashes
+        self.save_json()
 
         logger.info("Logged step %d: mean_reward=%.1f", step, entry["mean_reward"])
 
@@ -76,6 +83,8 @@
         }
         with open(self.json_path, "w") as f:
            json.dump(data, f, indent=2, default=str)
+            f.flush()
+            os.fsync(f.fileno())
         logger.info("Training data saved to %s", self.json_path)
 
     def get_checkpoint_indices(self) -> list[int]:
@@ -418,3 +427,5 @@ class ReportGenerator:
 
         with open(report_path, "w") as f:
             f.write("\n".join(lines))
+            f.flush()
+            os.fsync(f.fileno())
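
The incremental `save_json()` call above rewrites the full JSON file in place on every step. A common hardening of that pattern, not part of this commit, is to write to a temporary file and atomically rename it over the old one, so a crash mid-write can never leave a truncated JSON behind:

    import json
    import os

    def save_json_atomic(path: str, data: dict) -> None:
        """Hypothetical variant: temp file + fsync + atomic rename."""
        tmp = path + ".tmp"
        with open(tmp, "w") as f:
            json.dump(data, f, indent=2, default=str)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp, path)  # atomic on POSIX filesystems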