lemousehunter commited on
Commit ·
e63fc4e
1
Parent(s): f0b7136
feat: comprehensive W&B logging + retry=0 hardening
Browse filesChanges:
- evaluation.py: Add pre_rank (rank at 1 trace) and max_rank (worst rank)
- mtl_trainer.py: Log per-byte ranks to W&B (ranks/byte_X/{pre,final,min,max,@500,@1000})
- trainer.py: Add GradientNormLogger to single-byte trainer, add pre_rank/max_rank
- schemas.py: Change JobCreate max_retries default from 2 to 0
- database.py: Change DB max_retries default from 2 to 0
- tq.py: Change CLI max_retries defaults from 2 to 0
- agent.py: Change worker max_retries defaults from 3/2 to 0
All changes are additive - backward compatible with existing training runs.
- orchestrator/cli/tq.py +2 -2
- orchestrator/server/database.py +1 -1
- orchestrator/server/schemas.py +1 -1
- orchestrator/worker/agent.py +2 -2
- src/evaluation.py +10 -3
- src/training/mtl_trainer.py +14 -2
- src/training/trainer.py +23 -4
orchestrator/cli/tq.py
CHANGED
|
@@ -314,7 +314,7 @@ def add(config_file: str):
|
|
| 314 |
"name": config.get("name", os.path.basename(config_file)),
|
| 315 |
"config": config.get("config", config),
|
| 316 |
"priority": config.get("priority", 0),
|
| 317 |
-
"max_retries": config.get("max_retries",
|
| 318 |
"tags": config.get("tags", []),
|
| 319 |
}
|
| 320 |
|
|
@@ -344,7 +344,7 @@ def batch(jobs_file: str):
|
|
| 344 |
"name": j.get("name", f"job-{i}"),
|
| 345 |
"config": j.get("config", j),
|
| 346 |
"priority": j.get("priority", 0),
|
| 347 |
-
"max_retries": j.get("max_retries",
|
| 348 |
"tags": j.get("tags", []),
|
| 349 |
}
|
| 350 |
for i, j in enumerate(jobs_list)
|
|
|
|
| 314 |
"name": config.get("name", os.path.basename(config_file)),
|
| 315 |
"config": config.get("config", config),
|
| 316 |
"priority": config.get("priority", 0),
|
| 317 |
+
"max_retries": config.get("max_retries", 0),
|
| 318 |
"tags": config.get("tags", []),
|
| 319 |
}
|
| 320 |
|
|
|
|
| 344 |
"name": j.get("name", f"job-{i}"),
|
| 345 |
"config": j.get("config", j),
|
| 346 |
"priority": j.get("priority", 0),
|
| 347 |
+
"max_retries": j.get("max_retries", 0),
|
| 348 |
"tags": j.get("tags", []),
|
| 349 |
}
|
| 350 |
for i, j in enumerate(jobs_list)
|
orchestrator/server/database.py
CHANGED
|
@@ -61,7 +61,7 @@ def init_db() -> None:
|
|
| 61 |
completed_at TEXT,
|
| 62 |
result TEXT,
|
| 63 |
retry_count INTEGER NOT NULL DEFAULT 0,
|
| 64 |
-
max_retries INTEGER NOT NULL DEFAULT
|
| 65 |
tags TEXT NOT NULL DEFAULT '[]'
|
| 66 |
);
|
| 67 |
|
|
|
|
| 61 |
completed_at TEXT,
|
| 62 |
result TEXT,
|
| 63 |
retry_count INTEGER NOT NULL DEFAULT 0,
|
| 64 |
+
max_retries INTEGER NOT NULL DEFAULT 0,
|
| 65 |
tags TEXT NOT NULL DEFAULT '[]'
|
| 66 |
);
|
| 67 |
|
orchestrator/server/schemas.py
CHANGED
|
@@ -16,7 +16,7 @@ class JobCreate(BaseModel):
|
|
| 16 |
name: str = Field(..., description="Human-readable job name")
|
| 17 |
config: Dict[str, Any] = Field(..., description="Training configuration")
|
| 18 |
priority: int = Field(default=0, description="Job priority (higher = first)")
|
| 19 |
-
max_retries: int = Field(default=
|
| 20 |
tags: List[str] = Field(default_factory=list, description="Tags for filtering")
|
| 21 |
|
| 22 |
|
|
|
|
| 16 |
name: str = Field(..., description="Human-readable job name")
|
| 17 |
config: Dict[str, Any] = Field(..., description="Training configuration")
|
| 18 |
priority: int = Field(default=0, description="Job priority (higher = first)")
|
| 19 |
+
max_retries: int = Field(default=0, description="Max training retries (0 = no retries)")
|
| 20 |
tags: List[str] = Field(default_factory=list, description="Tags for filtering")
|
| 21 |
|
| 22 |
|
orchestrator/worker/agent.py
CHANGED
|
@@ -829,7 +829,7 @@ class WorkerAgent:
|
|
| 829 |
"--desync", str(config.get("desync", 0)),
|
| 830 |
"--variant", str(variant or model_type or "hps"),
|
| 831 |
"--seed", str(config.get("seed", 42)),
|
| 832 |
-
"--max-retries", str(config.get("max_retries",
|
| 833 |
"--data-dir", self.data_dir,
|
| 834 |
"--output-dir", f"/root/jobs/{job_id[:8]}",
|
| 835 |
]
|
|
@@ -911,7 +911,7 @@ class WorkerAgent:
|
|
| 911 |
"--byte", str(config.get("target_byte", 0)),
|
| 912 |
"--desync", str(config.get("desync", 0)),
|
| 913 |
"--seed", str(config.get("seed", 42)),
|
| 914 |
-
"--max-retries", str(config.get("max_retries",
|
| 915 |
"--data-dir", self.data_dir,
|
| 916 |
"--output-dir", f"/root/jobs/{job_id[:8]}",
|
| 917 |
]
|
|
|
|
| 829 |
"--desync", str(config.get("desync", 0)),
|
| 830 |
"--variant", str(variant or model_type or "hps"),
|
| 831 |
"--seed", str(config.get("seed", 42)),
|
| 832 |
+
"--max-retries", str(config.get("max_retries", 0)),
|
| 833 |
"--data-dir", self.data_dir,
|
| 834 |
"--output-dir", f"/root/jobs/{job_id[:8]}",
|
| 835 |
]
|
|
|
|
| 911 |
"--byte", str(config.get("target_byte", 0)),
|
| 912 |
"--desync", str(config.get("desync", 0)),
|
| 913 |
"--seed", str(config.get("seed", 42)),
|
| 914 |
+
"--max-retries", str(config.get("max_retries", 0)),
|
| 915 |
"--data-dir", self.data_dir,
|
| 916 |
"--output-dir", f"/root/jobs/{job_id[:8]}",
|
| 917 |
]
|
src/evaluation.py
CHANGED
|
@@ -178,7 +178,8 @@ def evaluate_model(
|
|
| 178 |
|
| 179 |
Returns:
|
| 180 |
Dictionary with evaluation results:
|
| 181 |
-
'final_rank', 'ranks', '
|
|
|
|
| 182 |
"""
|
| 183 |
if cached_predictions is not None:
|
| 184 |
raw_predictions = cached_predictions
|
|
@@ -238,19 +239,25 @@ def evaluate_model(
|
|
| 238 |
)
|
| 239 |
|
| 240 |
min_rank = int(np.min(ranks_array[:, 1])) if len(ranks_array) > 0 else 256
|
|
|
|
|
|
|
| 241 |
rank_at_500 = _get_rank_at_n(ranks_array, 500)
|
| 242 |
rank_at_1000 = _get_rank_at_n(ranks_array, 1000)
|
| 243 |
|
| 244 |
result = {
|
| 245 |
"final_rank": final_rank,
|
| 246 |
"ranks": ranks_array,
|
|
|
|
| 247 |
"min_rank": min_rank,
|
|
|
|
| 248 |
"rank_at_500": rank_at_500,
|
| 249 |
"rank_at_1000": rank_at_1000,
|
| 250 |
}
|
| 251 |
|
| 252 |
logger.info(
|
| 253 |
-
"Byte %d:
|
| 254 |
-
|
|
|
|
|
|
|
| 255 |
)
|
| 256 |
return result
|
|
|
|
| 178 |
|
| 179 |
Returns:
|
| 180 |
Dictionary with evaluation results:
|
| 181 |
+
'final_rank', 'ranks', 'pre_rank', 'min_rank', 'max_rank',
|
| 182 |
+
'rank_at_500', 'rank_at_1000'.
|
| 183 |
"""
|
| 184 |
if cached_predictions is not None:
|
| 185 |
raw_predictions = cached_predictions
|
|
|
|
| 239 |
)
|
| 240 |
|
| 241 |
min_rank = int(np.min(ranks_array[:, 1])) if len(ranks_array) > 0 else 256
|
| 242 |
+
max_rank = int(np.max(ranks_array[:, 1])) if len(ranks_array) > 0 else 256
|
| 243 |
+
pre_rank = int(ranks_array[0, 1]) if len(ranks_array) > 0 else 256
|
| 244 |
rank_at_500 = _get_rank_at_n(ranks_array, 500)
|
| 245 |
rank_at_1000 = _get_rank_at_n(ranks_array, 1000)
|
| 246 |
|
| 247 |
result = {
|
| 248 |
"final_rank": final_rank,
|
| 249 |
"ranks": ranks_array,
|
| 250 |
+
"pre_rank": pre_rank,
|
| 251 |
"min_rank": min_rank,
|
| 252 |
+
"max_rank": max_rank,
|
| 253 |
"rank_at_500": rank_at_500,
|
| 254 |
"rank_at_1000": rank_at_1000,
|
| 255 |
}
|
| 256 |
|
| 257 |
logger.info(
|
| 258 |
+
"Byte %d: pre_rank=%d, final_rank=%d, min_rank=%d, max_rank=%d, "
|
| 259 |
+
"rank@500=%d, rank@1000=%d",
|
| 260 |
+
target_byte, pre_rank, final_rank, min_rank, max_rank,
|
| 261 |
+
rank_at_500, rank_at_1000,
|
| 262 |
)
|
| 263 |
return result
|
src/training/mtl_trainer.py
CHANGED
|
@@ -967,9 +967,12 @@ class MTLTrainer:
|
|
| 967 |
"all_final_ranks": eval_results["all_final_ranks"],
|
| 968 |
"per_byte_results": {
|
| 969 |
str(k): {
|
|
|
|
| 970 |
"final_rank": v["final_rank"],
|
| 971 |
"min_rank": v["min_rank"],
|
|
|
|
| 972 |
"rank_at_500": v.get("rank_at_500", -1),
|
|
|
|
| 973 |
}
|
| 974 |
for k, v in eval_results["byte_results"].items()
|
| 975 |
},
|
|
@@ -1005,13 +1008,22 @@ class MTLTrainer:
|
|
| 1005 |
if self.wandb_project:
|
| 1006 |
try:
|
| 1007 |
import wandb
|
| 1008 |
-
|
| 1009 |
"max_final_rank": eval_results["max_final_rank"],
|
| 1010 |
"mean_final_rank": eval_results["mean_final_rank"],
|
| 1011 |
"num_rank0": eval_results["num_rank0"],
|
| 1012 |
"final_train_loss": result["final_train_loss"],
|
| 1013 |
"final_val_loss": result["final_val_loss"],
|
| 1014 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1015 |
wandb.finish()
|
| 1016 |
except Exception:
|
| 1017 |
pass
|
|
|
|
| 967 |
"all_final_ranks": eval_results["all_final_ranks"],
|
| 968 |
"per_byte_results": {
|
| 969 |
str(k): {
|
| 970 |
+
"pre_rank": v.get("pre_rank", 256),
|
| 971 |
"final_rank": v["final_rank"],
|
| 972 |
"min_rank": v["min_rank"],
|
| 973 |
+
"max_rank": v.get("max_rank", 256),
|
| 974 |
"rank_at_500": v.get("rank_at_500", -1),
|
| 975 |
+
"rank_at_1000": v.get("rank_at_1000", -1),
|
| 976 |
}
|
| 977 |
for k, v in eval_results["byte_results"].items()
|
| 978 |
},
|
|
|
|
| 1008 |
if self.wandb_project:
|
| 1009 |
try:
|
| 1010 |
import wandb
|
| 1011 |
+
rank_metrics = {
|
| 1012 |
"max_final_rank": eval_results["max_final_rank"],
|
| 1013 |
"mean_final_rank": eval_results["mean_final_rank"],
|
| 1014 |
"num_rank0": eval_results["num_rank0"],
|
| 1015 |
"final_train_loss": result["final_train_loss"],
|
| 1016 |
"final_val_loss": result["final_val_loss"],
|
| 1017 |
+
}
|
| 1018 |
+
# Log per-byte rank metrics
|
| 1019 |
+
for byte_idx, byte_res in eval_results["byte_results"].items():
|
| 1020 |
+
rank_metrics[f"ranks/byte_{byte_idx}_pre"] = byte_res.get("pre_rank", 256)
|
| 1021 |
+
rank_metrics[f"ranks/byte_{byte_idx}_final"] = byte_res["final_rank"]
|
| 1022 |
+
rank_metrics[f"ranks/byte_{byte_idx}_min"] = byte_res["min_rank"]
|
| 1023 |
+
rank_metrics[f"ranks/byte_{byte_idx}_max"] = byte_res.get("max_rank", 256)
|
| 1024 |
+
rank_metrics[f"ranks/byte_{byte_idx}_at500"] = byte_res.get("rank_at_500", -1)
|
| 1025 |
+
rank_metrics[f"ranks/byte_{byte_idx}_at1000"] = byte_res.get("rank_at_1000", -1)
|
| 1026 |
+
wandb.log(rank_metrics)
|
| 1027 |
wandb.finish()
|
| 1028 |
except Exception:
|
| 1029 |
pass
|
src/training/trainer.py
CHANGED
|
@@ -173,6 +173,21 @@ class Trainer:
|
|
| 173 |
reinit=True,
|
| 174 |
)
|
| 175 |
callbacks.append(WandbMetricsLogger())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
except ImportError:
|
| 177 |
logger.warning("wandb not installed; skipping logging.")
|
| 178 |
|
|
@@ -216,8 +231,10 @@ class Trainer:
|
|
| 216 |
"final_train_accuracy": float(history.history["accuracy"][-1]),
|
| 217 |
"final_val_loss": float(history.history["val_loss"][-1]),
|
| 218 |
"final_val_accuracy": float(history.history["val_accuracy"][-1]),
|
|
|
|
| 219 |
"post_train_final_rank": eval_result["final_rank"],
|
| 220 |
"post_train_min_rank": eval_result["min_rank"],
|
|
|
|
| 221 |
"post_train_rank_at_500": eval_result["rank_at_500"],
|
| 222 |
"post_train_rank_at_1000": eval_result["rank_at_1000"],
|
| 223 |
}
|
|
@@ -227,10 +244,12 @@ class Trainer:
|
|
| 227 |
try:
|
| 228 |
import wandb
|
| 229 |
wandb.log({
|
| 230 |
-
"
|
| 231 |
-
"
|
| 232 |
-
"
|
| 233 |
-
"
|
|
|
|
|
|
|
| 234 |
})
|
| 235 |
wandb.finish()
|
| 236 |
except Exception:
|
|
|
|
| 173 |
reinit=True,
|
| 174 |
)
|
| 175 |
callbacks.append(WandbMetricsLogger())
|
| 176 |
+
|
| 177 |
+
# Gradient norm logging to W&B
|
| 178 |
+
from ..gradient_logger import GradientNormLogger
|
| 179 |
+
grad_logger = GradientNormLogger(
|
| 180 |
+
val_data=(
|
| 181 |
+
data["atk_traces_reshaped"],
|
| 182 |
+
data["atk_labels"],
|
| 183 |
+
),
|
| 184 |
+
log_every_n_epochs=1,
|
| 185 |
+
batch_size=128,
|
| 186 |
+
)
|
| 187 |
+
callbacks.append(grad_logger)
|
| 188 |
+
logger.info(
|
| 189 |
+
"Gradient norm logging ENABLED for single-byte trainer."
|
| 190 |
+
)
|
| 191 |
except ImportError:
|
| 192 |
logger.warning("wandb not installed; skipping logging.")
|
| 193 |
|
|
|
|
| 231 |
"final_train_accuracy": float(history.history["accuracy"][-1]),
|
| 232 |
"final_val_loss": float(history.history["val_loss"][-1]),
|
| 233 |
"final_val_accuracy": float(history.history["val_accuracy"][-1]),
|
| 234 |
+
"post_train_pre_rank": eval_result.get("pre_rank", 256),
|
| 235 |
"post_train_final_rank": eval_result["final_rank"],
|
| 236 |
"post_train_min_rank": eval_result["min_rank"],
|
| 237 |
+
"post_train_max_rank": eval_result.get("max_rank", 256),
|
| 238 |
"post_train_rank_at_500": eval_result["rank_at_500"],
|
| 239 |
"post_train_rank_at_1000": eval_result["rank_at_1000"],
|
| 240 |
}
|
|
|
|
| 244 |
try:
|
| 245 |
import wandb
|
| 246 |
wandb.log({
|
| 247 |
+
"ranks/pre_rank": eval_result.get("pre_rank", 256),
|
| 248 |
+
"ranks/final_rank": eval_result["final_rank"],
|
| 249 |
+
"ranks/min_rank": eval_result["min_rank"],
|
| 250 |
+
"ranks/max_rank": eval_result.get("max_rank", 256),
|
| 251 |
+
"ranks/rank_at_500": eval_result["rank_at_500"],
|
| 252 |
+
"ranks/rank_at_1000": eval_result["rank_at_1000"],
|
| 253 |
})
|
| 254 |
wandb.finish()
|
| 255 |
except Exception:
|