lemousehunter commited on
Commit
e63fc4e
·
1 Parent(s): f0b7136

feat: comprehensive W&B logging + retry=0 hardening

Browse files

Changes:
- evaluation.py: Add pre_rank (rank at 1 trace) and max_rank (worst rank)
- mtl_trainer.py: Log per-byte ranks to W&B (ranks/byte_X/{pre,final,min,max,@500,@1000})
- trainer.py: Add GradientNormLogger to single-byte trainer, add pre_rank/max_rank
- schemas.py: Change JobCreate max_retries default from 2 to 0
- database.py: Change DB max_retries default from 2 to 0
- tq.py: Change CLI max_retries defaults from 2 to 0
- agent.py: Change worker max_retries defaults from 3/2 to 0

All changes are additive - backward compatible with existing training runs.

orchestrator/cli/tq.py CHANGED
@@ -314,7 +314,7 @@ def add(config_file: str):
314
  "name": config.get("name", os.path.basename(config_file)),
315
  "config": config.get("config", config),
316
  "priority": config.get("priority", 0),
317
- "max_retries": config.get("max_retries", 2),
318
  "tags": config.get("tags", []),
319
  }
320
 
@@ -344,7 +344,7 @@ def batch(jobs_file: str):
344
  "name": j.get("name", f"job-{i}"),
345
  "config": j.get("config", j),
346
  "priority": j.get("priority", 0),
347
- "max_retries": j.get("max_retries", 2),
348
  "tags": j.get("tags", []),
349
  }
350
  for i, j in enumerate(jobs_list)
 
314
  "name": config.get("name", os.path.basename(config_file)),
315
  "config": config.get("config", config),
316
  "priority": config.get("priority", 0),
317
+ "max_retries": config.get("max_retries", 0),
318
  "tags": config.get("tags", []),
319
  }
320
 
 
344
  "name": j.get("name", f"job-{i}"),
345
  "config": j.get("config", j),
346
  "priority": j.get("priority", 0),
347
+ "max_retries": j.get("max_retries", 0),
348
  "tags": j.get("tags", []),
349
  }
350
  for i, j in enumerate(jobs_list)
orchestrator/server/database.py CHANGED
@@ -61,7 +61,7 @@ def init_db() -> None:
61
  completed_at TEXT,
62
  result TEXT,
63
  retry_count INTEGER NOT NULL DEFAULT 0,
64
- max_retries INTEGER NOT NULL DEFAULT 2,
65
  tags TEXT NOT NULL DEFAULT '[]'
66
  );
67
 
 
61
  completed_at TEXT,
62
  result TEXT,
63
  retry_count INTEGER NOT NULL DEFAULT 0,
64
+ max_retries INTEGER NOT NULL DEFAULT 0,
65
  tags TEXT NOT NULL DEFAULT '[]'
66
  );
67
 
orchestrator/server/schemas.py CHANGED
@@ -16,7 +16,7 @@ class JobCreate(BaseModel):
16
  name: str = Field(..., description="Human-readable job name")
17
  config: Dict[str, Any] = Field(..., description="Training configuration")
18
  priority: int = Field(default=0, description="Job priority (higher = first)")
19
- max_retries: int = Field(default=2, description="Max training retries")
20
  tags: List[str] = Field(default_factory=list, description="Tags for filtering")
21
 
22
 
 
16
  name: str = Field(..., description="Human-readable job name")
17
  config: Dict[str, Any] = Field(..., description="Training configuration")
18
  priority: int = Field(default=0, description="Job priority (higher = first)")
19
+ max_retries: int = Field(default=0, description="Max training retries (0 = no retries)")
20
  tags: List[str] = Field(default_factory=list, description="Tags for filtering")
21
 
22
 
orchestrator/worker/agent.py CHANGED
@@ -829,7 +829,7 @@ class WorkerAgent:
829
  "--desync", str(config.get("desync", 0)),
830
  "--variant", str(variant or model_type or "hps"),
831
  "--seed", str(config.get("seed", 42)),
832
- "--max-retries", str(config.get("max_retries", 3)),
833
  "--data-dir", self.data_dir,
834
  "--output-dir", f"/root/jobs/{job_id[:8]}",
835
  ]
@@ -911,7 +911,7 @@ class WorkerAgent:
911
  "--byte", str(config.get("target_byte", 0)),
912
  "--desync", str(config.get("desync", 0)),
913
  "--seed", str(config.get("seed", 42)),
914
- "--max-retries", str(config.get("max_retries", 2)),
915
  "--data-dir", self.data_dir,
916
  "--output-dir", f"/root/jobs/{job_id[:8]}",
917
  ]
 
829
  "--desync", str(config.get("desync", 0)),
830
  "--variant", str(variant or model_type or "hps"),
831
  "--seed", str(config.get("seed", 42)),
832
+ "--max-retries", str(config.get("max_retries", 0)),
833
  "--data-dir", self.data_dir,
834
  "--output-dir", f"/root/jobs/{job_id[:8]}",
835
  ]
 
911
  "--byte", str(config.get("target_byte", 0)),
912
  "--desync", str(config.get("desync", 0)),
913
  "--seed", str(config.get("seed", 42)),
914
+ "--max-retries", str(config.get("max_retries", 0)),
915
  "--data-dir", self.data_dir,
916
  "--output-dir", f"/root/jobs/{job_id[:8]}",
917
  ]
src/evaluation.py CHANGED
@@ -178,7 +178,8 @@ def evaluate_model(
178
 
179
  Returns:
180
  Dictionary with evaluation results:
181
- 'final_rank', 'ranks', 'min_rank', 'rank_at_500', 'rank_at_1000'.
 
182
  """
183
  if cached_predictions is not None:
184
  raw_predictions = cached_predictions
@@ -238,19 +239,25 @@ def evaluate_model(
238
  )
239
 
240
  min_rank = int(np.min(ranks_array[:, 1])) if len(ranks_array) > 0 else 256
 
 
241
  rank_at_500 = _get_rank_at_n(ranks_array, 500)
242
  rank_at_1000 = _get_rank_at_n(ranks_array, 1000)
243
 
244
  result = {
245
  "final_rank": final_rank,
246
  "ranks": ranks_array,
 
247
  "min_rank": min_rank,
 
248
  "rank_at_500": rank_at_500,
249
  "rank_at_1000": rank_at_1000,
250
  }
251
 
252
  logger.info(
253
- "Byte %d: final_rank=%d, min_rank=%d, rank@500=%d, rank@1000=%d",
254
- target_byte, final_rank, min_rank, rank_at_500, rank_at_1000,
 
 
255
  )
256
  return result
 
178
 
179
  Returns:
180
  Dictionary with evaluation results:
181
+ 'final_rank', 'ranks', 'pre_rank', 'min_rank', 'max_rank',
182
+ 'rank_at_500', 'rank_at_1000'.
183
  """
184
  if cached_predictions is not None:
185
  raw_predictions = cached_predictions
 
239
  )
240
 
241
  min_rank = int(np.min(ranks_array[:, 1])) if len(ranks_array) > 0 else 256
242
+ max_rank = int(np.max(ranks_array[:, 1])) if len(ranks_array) > 0 else 256
243
+ pre_rank = int(ranks_array[0, 1]) if len(ranks_array) > 0 else 256
244
  rank_at_500 = _get_rank_at_n(ranks_array, 500)
245
  rank_at_1000 = _get_rank_at_n(ranks_array, 1000)
246
 
247
  result = {
248
  "final_rank": final_rank,
249
  "ranks": ranks_array,
250
+ "pre_rank": pre_rank,
251
  "min_rank": min_rank,
252
+ "max_rank": max_rank,
253
  "rank_at_500": rank_at_500,
254
  "rank_at_1000": rank_at_1000,
255
  }
256
 
257
  logger.info(
258
+ "Byte %d: pre_rank=%d, final_rank=%d, min_rank=%d, max_rank=%d, "
259
+ "rank@500=%d, rank@1000=%d",
260
+ target_byte, pre_rank, final_rank, min_rank, max_rank,
261
+ rank_at_500, rank_at_1000,
262
  )
263
  return result
src/training/mtl_trainer.py CHANGED
@@ -967,9 +967,12 @@ class MTLTrainer:
967
  "all_final_ranks": eval_results["all_final_ranks"],
968
  "per_byte_results": {
969
  str(k): {
 
970
  "final_rank": v["final_rank"],
971
  "min_rank": v["min_rank"],
 
972
  "rank_at_500": v.get("rank_at_500", -1),
 
973
  }
974
  for k, v in eval_results["byte_results"].items()
975
  },
@@ -1005,13 +1008,22 @@ class MTLTrainer:
1005
  if self.wandb_project:
1006
  try:
1007
  import wandb
1008
- wandb.log({
1009
  "max_final_rank": eval_results["max_final_rank"],
1010
  "mean_final_rank": eval_results["mean_final_rank"],
1011
  "num_rank0": eval_results["num_rank0"],
1012
  "final_train_loss": result["final_train_loss"],
1013
  "final_val_loss": result["final_val_loss"],
1014
- })
 
 
 
 
 
 
 
 
 
1015
  wandb.finish()
1016
  except Exception:
1017
  pass
 
967
  "all_final_ranks": eval_results["all_final_ranks"],
968
  "per_byte_results": {
969
  str(k): {
970
+ "pre_rank": v.get("pre_rank", 256),
971
  "final_rank": v["final_rank"],
972
  "min_rank": v["min_rank"],
973
+ "max_rank": v.get("max_rank", 256),
974
  "rank_at_500": v.get("rank_at_500", -1),
975
+ "rank_at_1000": v.get("rank_at_1000", -1),
976
  }
977
  for k, v in eval_results["byte_results"].items()
978
  },
 
1008
  if self.wandb_project:
1009
  try:
1010
  import wandb
1011
+ rank_metrics = {
1012
  "max_final_rank": eval_results["max_final_rank"],
1013
  "mean_final_rank": eval_results["mean_final_rank"],
1014
  "num_rank0": eval_results["num_rank0"],
1015
  "final_train_loss": result["final_train_loss"],
1016
  "final_val_loss": result["final_val_loss"],
1017
+ }
1018
+ # Log per-byte rank metrics
1019
+ for byte_idx, byte_res in eval_results["byte_results"].items():
1020
+ rank_metrics[f"ranks/byte_{byte_idx}_pre"] = byte_res.get("pre_rank", 256)
1021
+ rank_metrics[f"ranks/byte_{byte_idx}_final"] = byte_res["final_rank"]
1022
+ rank_metrics[f"ranks/byte_{byte_idx}_min"] = byte_res["min_rank"]
1023
+ rank_metrics[f"ranks/byte_{byte_idx}_max"] = byte_res.get("max_rank", 256)
1024
+ rank_metrics[f"ranks/byte_{byte_idx}_at500"] = byte_res.get("rank_at_500", -1)
1025
+ rank_metrics[f"ranks/byte_{byte_idx}_at1000"] = byte_res.get("rank_at_1000", -1)
1026
+ wandb.log(rank_metrics)
1027
  wandb.finish()
1028
  except Exception:
1029
  pass
src/training/trainer.py CHANGED
@@ -173,6 +173,21 @@ class Trainer:
173
  reinit=True,
174
  )
175
  callbacks.append(WandbMetricsLogger())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  except ImportError:
177
  logger.warning("wandb not installed; skipping logging.")
178
 
@@ -216,8 +231,10 @@ class Trainer:
216
  "final_train_accuracy": float(history.history["accuracy"][-1]),
217
  "final_val_loss": float(history.history["val_loss"][-1]),
218
  "final_val_accuracy": float(history.history["val_accuracy"][-1]),
 
219
  "post_train_final_rank": eval_result["final_rank"],
220
  "post_train_min_rank": eval_result["min_rank"],
 
221
  "post_train_rank_at_500": eval_result["rank_at_500"],
222
  "post_train_rank_at_1000": eval_result["rank_at_1000"],
223
  }
@@ -227,10 +244,12 @@ class Trainer:
227
  try:
228
  import wandb
229
  wandb.log({
230
- "final_rank": eval_result["final_rank"],
231
- "min_rank": eval_result["min_rank"],
232
- "rank_at_500": eval_result["rank_at_500"],
233
- "rank_at_1000": eval_result["rank_at_1000"],
 
 
234
  })
235
  wandb.finish()
236
  except Exception:
 
173
  reinit=True,
174
  )
175
  callbacks.append(WandbMetricsLogger())
176
+
177
+ # Gradient norm logging to W&B
178
+ from ..gradient_logger import GradientNormLogger
179
+ grad_logger = GradientNormLogger(
180
+ val_data=(
181
+ data["atk_traces_reshaped"],
182
+ data["atk_labels"],
183
+ ),
184
+ log_every_n_epochs=1,
185
+ batch_size=128,
186
+ )
187
+ callbacks.append(grad_logger)
188
+ logger.info(
189
+ "Gradient norm logging ENABLED for single-byte trainer."
190
+ )
191
  except ImportError:
192
  logger.warning("wandb not installed; skipping logging.")
193
 
 
231
  "final_train_accuracy": float(history.history["accuracy"][-1]),
232
  "final_val_loss": float(history.history["val_loss"][-1]),
233
  "final_val_accuracy": float(history.history["val_accuracy"][-1]),
234
+ "post_train_pre_rank": eval_result.get("pre_rank", 256),
235
  "post_train_final_rank": eval_result["final_rank"],
236
  "post_train_min_rank": eval_result["min_rank"],
237
+ "post_train_max_rank": eval_result.get("max_rank", 256),
238
  "post_train_rank_at_500": eval_result["rank_at_500"],
239
  "post_train_rank_at_1000": eval_result["rank_at_1000"],
240
  }
 
244
  try:
245
  import wandb
246
  wandb.log({
247
+ "ranks/pre_rank": eval_result.get("pre_rank", 256),
248
+ "ranks/final_rank": eval_result["final_rank"],
249
+ "ranks/min_rank": eval_result["min_rank"],
250
+ "ranks/max_rank": eval_result.get("max_rank", 256),
251
+ "ranks/rank_at_500": eval_result["rank_at_500"],
252
+ "ranks/rank_at_1000": eval_result["rank_at_1000"],
253
  })
254
  wandb.finish()
255
  except Exception: